From 243877d066122a36684f6333bcba0dcb12178950 Mon Sep 17 00:00:00 2001 From: Romain Keramitas Date: Thu, 8 Apr 2021 14:58:17 +0200 Subject: [PATCH] Apply random case changes before expanding vocab and counting tokens Signed-off-by: Romain Keramitas --- flair/trainers/language_model_trainer.py | 25 ++++++++++-------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index 3c999b5557..e6c924a4fc 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -62,8 +62,9 @@ def __getitem__(self, index=0) -> torch.tensor: self.files[index] = Path(self.files[index]) assert self.files[index].exists() - lines = [doc + self.document_delimiter - for doc in open(self.files[index], "r", encoding="utf-8").read().split(self.document_delimiter) if doc] + with self.files[index].open("r", encoding="utf-8") as fin: + lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc) + lines = list(map(self.random_casechange, lines)) log.info(f"read text file with {len(lines)} lines") if self.shuffle: @@ -90,9 +91,6 @@ def __getitem__(self, index=0) -> torch.tensor: # charsplit file content token = 0 for line in lines: - if self.random_case_flip: - line = self.random_casechange(line) - if self.split_on_char: chars = list(line) else: @@ -107,9 +105,6 @@ def __getitem__(self, index=0) -> torch.tensor: # charsplit file content token = tokens - 1 for line in lines: - if self.random_case_flip: - line = self.random_casechange(line) - if self.split_on_char: chars = list(line) else: @@ -123,13 +118,13 @@ def __getitem__(self, index=0) -> torch.tensor: return ids - @staticmethod - def random_casechange(line: str) -> str: - no = random.randint(0, 99) - if no == 0: - line = line.lower() - if no == 1: - line = line.upper() + def random_casechange(self, line: str) -> str: + if self.random_case_flip: + no = random.randint(0, 99) + if no == 0: + line = line.lower() + if no == 1: + line = line.upper() return line