diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index 1b9f82a6e..8840d7683 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -66,58 +66,25 @@ def __getitem__(self, index=0) -> torch.tensor: lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc) if self.random_case_flip: lines = map(self.random_casechange, lines) - lines = list(lines) + lines = list(map(list if self.split_on_char else str.split, lines)) log.info(f"read text file with {len(lines)} lines") + if self.shuffle: random.shuffle(lines) log.info(f"shuffled") - tokens = 0 - for line in lines: - - if self.split_on_char: - chars = list(line) - else: - chars = line.split() - - tokens += len(chars) - - # Add chars to the dictionary - if self.expand_vocab: + if self.expand_vocab: + for chars in lines: for char in chars: self.dictionary.add_item(char) - ids = torch.zeros(tokens, dtype=torch.long) - if self.forward: - # charsplit file content - token = 0 - for line in lines: - if self.split_on_char: - chars = list(line) - else: - chars = line.split() - - for char in chars: - if token >= tokens: - break - ids[token] = self.dictionary.get_idx_for_item(char) - token += 1 - else: - # charsplit file content - token = tokens - 1 - for line in lines: - if self.split_on_char: - chars = list(line) - else: - chars = line.split() - - for char in chars: - if token >= tokens: - break - ids[token] = self.dictionary.get_idx_for_item(char) - token -= 1 - + ids = torch.tensor( + [self.dictionary.get_idx_for_item(char) for chars in lines for char in chars], + dtype=torch.long + ) + if not self.forward: + ids = ids.flip(0) return ids @staticmethod