Skip to content

Commit

Permalink
Apply random case changes before expanding vocab and counting tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Romain Keramitas <romain.keramitas@mppteam.com>
  • Loading branch information
RomainMPP committed Apr 8, 2021
1 parent bd29e22 commit 243877d
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions flair/trainers/language_model_trainer.py
@@ -62,8 +62,9 @@ def __getitem__(self, index=0) -> torch.tensor:
 self.files[index] = Path(self.files[index])
 assert self.files[index].exists()

-lines = [doc + self.document_delimiter
-for doc in open(self.files[index], "r", encoding="utf-8").read().split(self.document_delimiter) if doc]
+with self.files[index].open("r", encoding="utf-8") as fin:
+lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc)
+lines = list(map(self.random_casechange, lines))

 log.info(f"read text file with {len(lines)} lines")
 if self.shuffle:
@@ -90,9 +91,6 @@ def __getitem__(self, index=0) -> torch.tensor:
 # charsplit file content
 token = 0
 for line in lines:
-if self.random_case_flip:
-line = self.random_casechange(line)
-
 if self.split_on_char:
 chars = list(line)
 else:
@@ -107,9 +105,6 @@ def __getitem__(self, index=0) -> torch.tensor:
 # charsplit file content
 token = tokens - 1
 for line in lines:
-if self.random_case_flip:
-line = self.random_casechange(line)
-
 if self.split_on_char:
 chars = list(line)
 else:
@@ -123,13 +118,13 @@ def __getitem__(self, index=0) -> torch.tensor:

 return ids

-@staticmethod
-def random_casechange(line: str) -> str:
-no = random.randint(0, 99)
-if no == 0:
-line = line.lower()
-if no == 1:
-line = line.upper()
+def random_casechange(self, line: str) -> str:
+if self.random_case_flip:
+no = random.randint(0, 99)
+if no == 0:
+line = line.lower()
+if no == 1:
+line = line.upper()
 return line


Expand Down

0 comments on commit 243877d

Please sign in to comment.