Skip to content

Commit

Permalink
Apply splitting only once and use comprehension to create ids
Browse files (browse the repository at this point in the history)
Signed-off-by: Romain Keramitas <romain.keramitas@mppteam.com>
  • Loading branch information
RomainMPP committed Apr 8, 2021
1 parent 243877d commit 3fd43ca
Showing 1 changed file with 11 additions and 43 deletions.
54 changes: 11 additions & 43 deletions flair/trainers/language_model_trainer.py
Expand Up @@ -64,58 +64,26 @@ def __getitem__(self, index=0) -> torch.tensor:

with self.files[index].open("r", encoding="utf-8") as fin:
lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc)
lines = list(map(self.random_casechange, lines))
lines = map(self.random_casechange, lines)
lines = list(map(list if self.split_on_char else str.split, lines))

log.info(f"read text file with {len(lines)} lines")

if self.shuffle:
random.shuffle(lines)
log.info(f"shuffled")

tokens = 0
for line in lines:

if self.split_on_char:
chars = list(line)
else:
chars = line.split()

tokens += len(chars)

# Add chars to the dictionary
if self.expand_vocab:
if self.expand_vocab:
for chars in lines:
for char in chars:
self.dictionary.add_item(char)

ids = torch.zeros(tokens, dtype=torch.long)
if self.forward:
# charsplit file content
token = 0
for line in lines:
if self.split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens:
break
ids[token] = self.dictionary.get_idx_for_item(char)
token += 1
else:
# charsplit file content
token = tokens - 1
for line in lines:
if self.split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens:
break
ids[token] = self.dictionary.get_idx_for_item(char)
token -= 1

ids = torch.tensor(
[self.dictionary.get_idx_for_item(char) for chars in lines for char in chars],
dtype=torch.long
)
if not self.forward:
ids = ids.flip(0)
return ids

def random_casechange(self, line: str) -> str:
Expand Down

0 comments on commit 3fd43ca

Please sign in to comment.