From 55f13db7b6c626bde53b0dc574b6fb743b3d2d27 Mon Sep 17 00:00:00 2001 From: dimakarp1996 Date: Thu, 6 Apr 2023 21:57:38 +0300 Subject: [PATCH 1/2] Added supported non-string label types to dataset_reader. --- .../basic_classification_reader.py | 55 ++++++++++++------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/deeppavlov/dataset_readers/basic_classification_reader.py b/deeppavlov/dataset_readers/basic_classification_reader.py index 8ef767b368..cce098b826 100644 --- a/deeppavlov/dataset_readers/basic_classification_reader.py +++ b/deeppavlov/dataset_readers/basic_classification_reader.py @@ -13,10 +13,14 @@ # limitations under the License. + from logging import getLogger from pathlib import Path +from collections import defaultdict import pandas as pd +from tqdm import tqdm +from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader @@ -31,15 +35,15 @@ class BasicClassificationDatasetReader(DatasetReader): Class provides reading dataset in .csv format """ + @overrides def read(self, data_path: str, url: str = None, format: str = "csv", class_sep: str = None, - *args, **kwargs) -> dict: + label_type: str = "str", *args, **kwargs) -> dict: """ Read dataset from data_path directory. Reading files are all data_types + extension (i.e for data_types=["train", "valid"] files "train.csv" and "valid.csv" form data_path will be read) - Args: data_path: directory with files url: download data files if data_path not exists or empty @@ -50,11 +54,17 @@ def read(self, data_path: str, url: str = None, names (array): list of column names to use orient (str): indication of expected JSON string format lines (boolean): read the file as a json object per line. Default: ``False`` - + label_type(str): expected type of labels. Default: ``"str"`` Returns: dictionary with types from data_types. Each field of dictionary is a list of tuples (x_i, y_i) """ + def row_list_process(row, y): + if pd.isna(row[y]): + return [] + else: + return [label_type(label) for label in str(row[y]).split(class_sep)] + data_types = ["train", "valid", "test"] train_file = kwargs.get('train', 'train.csv') @@ -70,6 +80,13 @@ def read(self, data_path: str, url: str = None, data = {"train": [], "valid": [], "test": []} + + supported_label_types = ['int','str','float'] + error_msg = f'Wrong label type {label_type} given! Needs to be one of the built-in Python types' + if label_type not in supported_label_types: + raise Exception(error_msg) + label_type = eval(label_type) + data=defaultdict(list) for data_type in data_types: file_name = kwargs.get(data_type, '{}.{}'.format(data_type, format)) if file_name is None: @@ -80,6 +97,7 @@ def read(self, data_path: str, url: str = None, if format == 'csv': keys = ('sep', 'header', 'names') options = {k: kwargs[k] for k in keys if k in kwargs} + print(file) df = pd.read_csv(file, **options) elif format == 'json': keys = ('orient', 'lines') @@ -90,22 +108,21 @@ def read(self, data_path: str, url: str = None, x = kwargs.get("x", "text") y = kwargs.get('y', 'labels') - if isinstance(x, list): - if class_sep is None: - # each sample is a tuple ("text", "label") - data[data_type] = [([row[x_] for x_ in x], str(row[y])) - for _, row in df.iterrows()] - else: - # each sample is a tuple ("text", ["label", "label", ...]) - data[data_type] = [([row[x_] for x_ in x], str(row[y]).split(class_sep)) - for _, row in df.iterrows()] - else: - if class_sep is None: - # each sample is a tuple ("text", "label") - data[data_type] = [(row[x], str(row[y])) for _, row in df.iterrows()] - else: - # each sample is a tuple ("text", ["label", "label", ...]) - data[data_type] = [(row[x], str(row[y]).split(class_sep)) for _, row in df.iterrows()] + + for _, row in tqdm(df.iterrows()): + try: + if isinstance(x, list): + x_text = [row[x_] for x_ in x] + else: + x_text = row[x] + if class_sep is None: + y_label = label_type(row[y]) + else: + y_label = row_list_process(row, y) + data[data_type].append((x_text, y_label)) + except Exception as e: + print(f'Error processing {row}: {e}') + raise e else: log.warning("Cannot find {} file".format(file)) From 56c713fd6bff390a075334c2b8c929d4bc7cfa71 Mon Sep 17 00:00:00 2001 From: dimakarp1996 Date: Wed, 3 May 2023 16:57:26 +0300 Subject: [PATCH 2/2] Update basic_classification_reader.py --- deeppavlov/dataset_readers/basic_classification_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deeppavlov/dataset_readers/basic_classification_reader.py b/deeppavlov/dataset_readers/basic_classification_reader.py index cce098b826..5b116c5cf3 100644 --- a/deeppavlov/dataset_readers/basic_classification_reader.py +++ b/deeppavlov/dataset_readers/basic_classification_reader.py @@ -20,7 +20,6 @@ import pandas as pd from tqdm import tqdm -from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader