Add option to load text pairs with CSVClassificationCorpus #2149

Merged (4 commits, Mar 13, 2021). Changes from all commits are shown below.
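This PR teaches CSVClassificationCorpus to build DataPair data points: any column mapped to "pair" in column_name_map is read as a second text alongside the "text" column(s). A minimal usage sketch, assuming a tab-separated file with label, text, and pair columns (the path, column layout, and label type name below are illustrative, not taken from this diff):

    from flair.datasets import CSVClassificationCorpus

    # hypothetical layout: column 0 = label, column 1 = first text, column 2 = paired text
    corpus = CSVClassificationCorpus(
        data_folder="resources/tasks/pair_task",             # assumed path
        column_name_map={0: "label", 1: "text", 2: "pair"},  # "pair" is the new mapping
        label_type="entailment",                             # assumed label type name
        delimiter="\t",                                      # forwarded to csv.reader
    )

    # with a "pair" column mapped, each data point is a DataPair of two Sentences
    print(corpus.train[0])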
86 changes: 52 additions & 34 deletions flair/datasets/document_classification.py
@@ -11,7 +11,7 @@
     Corpus,
     Token,
     FlairDataset,
-    Tokenizer
+    Tokenizer, DataPair
 )
 from flair.tokenization import SegtokTokenizer, SpaceTokenizer
 from flair.datasets.base import find_train_dev_test_files
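DataPair, newly imported above from flair.data, is Flair's container for two-sentence data points: it holds the pair as first and second and can carry labels of its own. A rough illustration (the sentences and label name are invented):

    from flair.data import DataPair, Sentence

    # a DataPair bundles two Sentences into a single labeled data point
    pair = DataPair(first=Sentence("A man is eating."), second=Sentence("Someone eats food."))
    pair.add_label("entailment", "ENTAILMENT")  # hypothetical label type and value
    print(pair.first, pair.second, pair.labels)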
@@ -454,9 +454,12 @@ def __init__(
 
         # most data sets have the token text in the first column, if not, pass 'text' as column
         self.text_columns: List[int] = []
+        self.pair_columns: List[int] = []
         for column in column_name_map:
             if column_name_map[column] == "text":
                 self.text_columns.append(column)
+            if column_name_map[column] == "pair":
+                self.pair_columns.append(column)
 
         with open(self.path_to_file, encoding=encoding) as csv_file:

@@ -488,33 +491,61 @@ def __init__(
 
                 if self.in_memory:
 
-                    text = " ".join(
-                        [row[text_column] for text_column in self.text_columns]
-                    )
-
-                    if self.max_chars_per_doc > 0:
-                        text = text[: self.max_chars_per_doc]
-
-                    sentence = Sentence(text, use_tokenizer=self.tokenizer)
-
-                    for column in self.column_name_map:
-                        column_value = row[column]
-                        if (
-                            self.column_name_map[column].startswith("label")
-                            and column_value
-                        ):
-                            if column_value != self.no_class_label:
-                                sentence.add_label(label_type, column_value)
+                    sentence = self._make_labeled_data_point(row)
 
                     if 0 < self.max_tokens_per_doc < len(sentence):
                         sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
                     self.sentences.append(sentence)
 
                 else:
                     self.raw_data.append(row)
 
                 self.total_sentence_count += 1
 
+    def _make_labeled_data_point(self, row):
+
+        # make sentence from text (and filter for length)
+        text = " ".join(
+            [row[text_column] for text_column in self.text_columns]
+        )
+
+        if self.max_chars_per_doc > 0:
+            text = text[: self.max_chars_per_doc]
+
+        sentence = Sentence(text, use_tokenizer=self.tokenizer)
+
+        if 0 < self.max_tokens_per_doc < len(sentence):
+            sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+
+        # if a pair column is defined, make a sentence pair object
+        if len(self.pair_columns) > 0:
+
+            text = " ".join(
+                [row[pair_column] for pair_column in self.pair_columns]
+            )
+
+            if self.max_chars_per_doc > 0:
+                text = text[: self.max_chars_per_doc]
+
+            pair = Sentence(text, use_tokenizer=self.tokenizer)
+
+            if 0 < self.max_tokens_per_doc < len(pair):
+                pair.tokens = pair.tokens[: self.max_tokens_per_doc]
+
+            data_point = DataPair(first=sentence, second=pair)
+
+        else:
+            data_point = sentence
+
+        for column in self.column_name_map:
+            column_value = row[column]
+            if (
+                self.column_name_map[column].startswith("label")
+                and column_value
+            ):
+                if column_value != self.no_class_label:
+                    data_point.add_label(self.label_type, column_value)
+
+        return data_point
 
     def is_in_memory(self) -> bool:
         return self.in_memory

@@ -527,20 +558,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         else:
             row = self.raw_data[index]
 
-            text = " ".join([row[text_column] for text_column in self.text_columns])
-
-            if self.max_chars_per_doc > 0:
-                text = text[: self.max_chars_per_doc]
-
-            sentence = Sentence(text, use_tokenizer=self.tokenizer)
-            for column in self.column_name_map:
-                column_value = row[column]
-                if self.column_name_map[column].startswith("label") and column_value:
-                    if column_value != self.no_class_label:
-                        sentence.add_label(self.label_type, column_value)
-
-            if 0 < self.max_tokens_per_doc < len(sentence):
-                sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+            sentence = self._make_labeled_data_point(row)
 
         return sentence

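Note that both the eager path (__init__ with in_memory=True) and the lazy path (__getitem__) now funnel through _make_labeled_data_point, so pair rows behave identically either way. Downstream code that must handle both shapes can branch on the data point type; a small sketch, assuming a corpus built as in the earlier example:

    from flair.data import DataPair

    for data_point in corpus.train:  # corpus from the earlier sketch
        if isinstance(data_point, DataPair):
            print(data_point.first, "<->", data_point.second, data_point.labels)
        else:
            print(data_point, data_point.labels)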
5 changes: 2 additions & 3 deletions flair/datasets/sequence_labeling.py
@@ -3526,7 +3526,6 @@ def __init__(
         self,
         base_path: Union[str, Path] = None,
         in_memory: bool = True,
-        document_as_sequence: bool = False,
         **corpusargs,
     ):
         """
@@ -3571,7 +3570,7 @@ def __init__(
 
                     for row in posts:  # Go through all the post titles
 
-                        txtout.writelines("-DOCSTART-\n")  # Start each post with a -DOCSTART- token
+                        txtout.writelines("-DOCSTART-\n\n")  # Start each post with a -DOCSTART- token
 
                         # Keep track of how many and which entity mentions does a given post title have
                         link_annots = []  # [start pos, end pos, wiki page title] of an entity mention
@@ -3643,7 +3642,7 @@ def __init__(
             train_file=corpus_file_name,
             column_delimiter="\t",
             in_memory=in_memory,
-            document_separator_token=None if not document_as_sequence else "-DOCSTART-",
+            document_separator_token="-DOCSTART-",
             **corpusargs,
         )
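Taken together, the sequence_labeling.py changes make each post start with a "-DOCSTART-" line followed by a blank line (so the marker forms its own sentence), and always pass document_separator_token="-DOCSTART-" to the underlying column corpus, retiring the document_as_sequence flag. A sketch of the convention this relies on, with assumed path, column layout, and file name:

    from flair.datasets import ColumnCorpus

    # "-DOCSTART-" on a line of its own, followed by a blank line, marks a document boundary
    corpus = ColumnCorpus(
        data_folder="resources/tasks/reddit_el",  # assumed path
        column_format={0: "text", 1: "ner"},      # assumed column layout
        train_file="reddit_el_gold.txt",          # assumed file name
        column_delimiter="\t",
        document_separator_token="-DOCSTART-",
    )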
