Skip to content

Commit

Permalink
Apply splitting only once and use comprehension to create ids
Browse files: browse the repository at this point in the history
Signed-off-by: Romain Keramitas <r.keramitas@gmail.com>
  • Loading branch information
r0mainK committed Apr 8, 2021
1 parent 5f4748b commit cd99988
Showing 1 changed file with 10 additions and 43 deletions.
53 changes: 10 additions & 43 deletions flair/trainers/language_model_trainer.py
Expand Up @@ -66,58 +66,25 @@ def __getitem__(self, index=0) -> torch.tensor:
lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc)
if self.random_case_flip:
lines = map(self.random_casechange, lines)
lines = list(lines)
lines = list(map(list if self.split_on_char else str.split, lines))

log.info(f"read text file with {len(lines)} lines")

if self.shuffle:
random.shuffle(lines)
log.info(f"shuffled")

tokens = 0
for line in lines:

if self.split_on_char:
chars = list(line)
else:
chars = line.split()

tokens += len(chars)

# Add chars to the dictionary
if self.expand_vocab:
if self.expand_vocab:
for chars in lines:
for char in chars:
self.dictionary.add_item(char)

ids = torch.zeros(tokens, dtype=torch.long)
if self.forward:
# charsplit file content
token = 0
for line in lines:
if self.split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens:
break
ids[token] = self.dictionary.get_idx_for_item(char)
token += 1
else:
# charsplit file content
token = tokens - 1
for line in lines:
if self.split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens:
break
ids[token] = self.dictionary.get_idx_for_item(char)
token -= 1

ids = torch.tensor(
[self.dictionary.get_idx_for_item(char) for chars in lines for char in chars],
dtype=torch.long
)
if not self.forward:
ids = ids.flip(0)
return ids

@staticmethod
Expand Down

0 comments on commit cd99988

Please sign in to comment.