Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for non-string label types to dataset_reader. #1639

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 35 additions & 19 deletions deeppavlov/dataset_readers/basic_classification_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.



from logging import getLogger
from pathlib import Path
from collections import defaultdict

import pandas as pd
from tqdm import tqdm

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader
Expand All @@ -31,15 +34,15 @@ class BasicClassificationDatasetReader(DatasetReader):
Class provides reading dataset in .csv format
"""

@overrides
def read(self, data_path: str, url: str = None,
format: str = "csv", class_sep: str = None,
*args, **kwargs) -> dict:
label_type: str = "str", *args, **kwargs) -> dict:
"""
Read dataset from data_path directory.
Files to read are named as each data type plus the extension
(i.e. for data_types=["train", "valid"], files "train.csv" and "valid.csv" from
data_path will be read)

Args:
data_path: directory with files
url: download data files if data_path not exists or empty
Expand All @@ -50,11 +53,17 @@ def read(self, data_path: str, url: str = None,
names (array): list of column names to use
orient (str): indication of expected JSON string format
lines (boolean): read the file as a json object per line. Default: ``False``

label_type (str): expected type of labels, one of ``"int"``, ``"str"`` or ``"float"``. Default: ``"str"``
Returns:
dictionary with types from data_types.
Each field of dictionary is a list of tuples (x_i, y_i)
"""
def row_list_process(row, y):
    """Turn the label cell of one row into a list of typed labels.

    ``class_sep`` and ``label_type`` are taken from the enclosing scope.

    Args:
        row: a single data row, indexable by column name
        y: name of the column that holds the labels

    Returns:
        an empty list when the cell is missing (``pd.isna``); otherwise the
        cell's string form split on ``class_sep`` with every piece passed
        through ``label_type``
    """
    cell = row[y]
    # A missing cell means the sample carries no labels at all.
    if pd.isna(cell):
        return []
    pieces = str(cell).split(class_sep)
    return [label_type(piece) for piece in pieces]

data_types = ["train", "valid", "test"]

train_file = kwargs.get('train', 'train.csv')
Expand All @@ -70,6 +79,13 @@ def read(self, data_path: str, url: str = None,
data = {"train": [],
"valid": [],
"test": []}

supported_label_types = ['int','str','float']
error_msg = f'Wrong label type {label_type} given! Needs to be one of the built-in Python types'
if label_type not in supported_label_types:
raise Exception(error_msg)
label_type = eval(label_type)
data=defaultdict(list)
for data_type in data_types:
file_name = kwargs.get(data_type, '{}.{}'.format(data_type, format))
if file_name is None:
Expand All @@ -80,6 +96,7 @@ def read(self, data_path: str, url: str = None,
if format == 'csv':
keys = ('sep', 'header', 'names')
options = {k: kwargs[k] for k in keys if k in kwargs}
print(file)
df = pd.read_csv(file, **options)
elif format == 'json':
keys = ('orient', 'lines')
Expand All @@ -90,22 +107,21 @@ def read(self, data_path: str, url: str = None,

x = kwargs.get("x", "text")
y = kwargs.get('y', 'labels')
if isinstance(x, list):
if class_sep is None:
# each sample is a tuple ("text", "label")
data[data_type] = [([row[x_] for x_ in x], str(row[y]))
for _, row in df.iterrows()]
else:
# each sample is a tuple ("text", ["label", "label", ...])
data[data_type] = [([row[x_] for x_ in x], str(row[y]).split(class_sep))
for _, row in df.iterrows()]
else:
if class_sep is None:
# each sample is a tuple ("text", "label")
data[data_type] = [(row[x], str(row[y])) for _, row in df.iterrows()]
else:
# each sample is a tuple ("text", ["label", "label", ...])
data[data_type] = [(row[x], str(row[y]).split(class_sep)) for _, row in df.iterrows()]

for _, row in tqdm(df.iterrows()):
try:
if isinstance(x, list):
x_text = [row[x_] for x_ in x]
else:
x_text = row[x]
if class_sep is None:
y_label = label_type(row[y])
else:
y_label = row_list_process(row, y)
data[data_type].append((x_text, y_label))
except Exception as e:
print(f'Error processing {row}: {e}')
raise e
else:
log.warning("Cannot find {} file".format(file))

Expand Down