Implemented lazy line-by-line text dataset loading for language modeling, including a dataset and a collator.
GCHQResearcher92457 committed Apr 27, 2020
1 parent 4e817ff commit 5ff6eb7
Showing 5 changed files with 124 additions and 13 deletions.
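For orientation, here is a minimal sketch of how the pieces added in this commit are meant to fit together (class and argument names are taken from the diff below; the tokenizer choice, file path, and batch size are illustrative):

from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLazyLanguageModeling,
    LazyLineByLineTextDataset,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # what --force_pad_token does

# Lines are read from disk on demand; tokenization happens per batch in the collator.
dataset = LazyLineByLineTextDataset(file_path="train.txt")
collator = DataCollatorForLazyLanguageModeling(
    tokenizer=tokenizer, mlm=False, mlm_probability=0.15, block_size=512
)

loader = DataLoader(dataset, batch_size=8, collate_fn=collator.collate_batch)
batch = next(iter(loader))  # {"input_ids": ..., "labels": ...}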
40 changes: 36 additions & 4 deletions examples/run_language_modeling.py
@@ -33,7 +33,9 @@
AutoModelWithLMHead,
AutoTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForLazyLanguageModeling,
HfArgumentParser,
LazyLineByLineTextDataset,
LineByLineTextDataset,
PreTrainedTokenizer,
TextDataset,
@@ -75,6 +77,12 @@ class ModelArguments:
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
force_pad_token: bool = field(
default=False,
metadata={
"help": "Whether to force the addition of a padding token to tokenizer that does not already have one."
},
)


@dataclass
@@ -94,6 +102,10 @@ class DataTrainingArguments:
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
)
lazy_loading: bool = field(
default=False,
metadata={"help": "Whether data file should be loaded lazily rather than loading all into memory up-front."},
)

mlm: bool = field(
default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
@@ -117,7 +129,9 @@ class DataTrainingArguments:

def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
file_path = args.eval_data_file if evaluate else args.train_data_file
if args.line_by_line:
if args.lazy_loading:
return LazyLineByLineTextDataset(file_path)
elif args.line_by_line:
return LineByLineTextDataset(
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, local_rank=local_rank
)
@@ -193,6 +207,16 @@ def main():
"You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
"and load it from here, using --tokenizer_name"
)
if tokenizer.pad_token_id is None:
if model_args.force_pad_token:
# See PR 3388. Some tokenizers don't have pad tokens, which causes errors at the encoding step in the collate_fn.
# We give the option here to force the addition of a pad token. The attention mask is used to ignore this token
# when feeding it to the model.
tokenizer.add_special_tokens({"pad_token": "<pad>"})
else:
logger.warning(
    "Attempting to train a model whose tokenizer has no padding token. This may result in errors in the encoding step. Set the --force_pad_token flag to fix this."
)

if model_args.model_name_or_path:
model = AutoModelWithLMHead.from_pretrained(
@@ -230,9 +254,17 @@ def main():
if training_args.do_eval
else None
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
)
if data_args.lazy_loading:
data_collator = DataCollatorForLazyLanguageModeling(
tokenizer=tokenizer,
mlm=data_args.mlm,
mlm_probability=data_args.mlm_probability,
block_size=data_args.block_size,
)
else:
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
)

# Initialize our Trainer
trainer = Trainer(
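A note on the --force_pad_token change above: the lazy collator pads each batch up to block_size, and some tokenizers (GPT-2's, for example) ship without a pad token, so encoding fails unless one is added. A small sketch of the fix, using library calls as they existed around this version of transformers (the model name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.pad_token_id)  # None: GPT-2 has no pad token by default
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # equivalent to passing --force_pad_token

encoded = tokenizer.batch_encode_plus(
    ["a short line", "a somewhat longer line of text"],
    max_length=16,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors="pt",
)
print(encoded["input_ids"].shape, encoded["attention_mask"].shape)  # both (2, 16)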
21 changes: 13 additions & 8 deletions src/transformers/__init__.py
@@ -31,7 +31,6 @@
start_memory_tracing,
stop_memory_tracing,
)

# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
@@ -71,7 +70,6 @@
xnli_processors,
xnli_tasks_num_labels,
)

# Files and general utilities
from .file_utils import (
CONFIG_NAME,
@@ -89,10 +87,8 @@
is_torch_available,
)
from .hf_argparser import HfArgumentParser

# Model Cards
from .modelcard import ModelCard

# TF 2.0 <=> PyTorch conversion utilities
from .modeling_tf_pytorch_utils import (
convert_tf_weight_name_to_pt_weight_name,
@@ -103,7 +99,6 @@
load_tf2_model_in_pytorch_model,
load_tf2_weights_in_pytorch_model,
)

# Pipelines
from .pipelines import (
CsvPipelineDataFormat,
@@ -122,7 +117,6 @@
TranslationPipeline,
pipeline,
)

# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
@@ -325,8 +319,19 @@

# Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
from .data.data_collator import (
DefaultDataCollator,
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForLazyLanguageModeling,
)
from .data.datasets import (
GlueDataset,
TextDataset,
LineByLineTextDataset,
LazyLineByLineTextDataset,
GlueDataTrainingArguments,
)

# TensorFlow
if is_tf_available():
36 changes: 36 additions & 0 deletions src/transformers/data/data_collator.py
@@ -142,3 +142,39 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels


@dataclass
class DataCollatorForLazyLanguageModeling(DataCollatorForLanguageModeling):

block_size: int = 512

def collate_batch(self, examples: List[str]) -> Dict[str, torch.Tensor]:
batch, attention_mask = self._tensorize_batch(examples)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
label_key = "masked_lm_labels"
else:
inputs, labels = batch, batch
label_key = "labels"
# Padded positions are ignored both by the attention mask and by the loss (-100 labels).
labels = labels.masked_fill(attention_mask == 0, -100)
return {"input_ids": inputs, "attention_mask": attention_mask, label_key: labels}

def _tensorize_batch(self, examples: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)

tensor_examples = self.tokenizer.batch_encode_plus(
[ex for ex in examples if ex],
max_length=self.block_size,
return_tensors="pt",
pad_to_max_length=True,
return_attention_mask=True,
)

input_ids, attention_mask = tensor_examples["input_ids"], tensor_examples["attention_mask"]
return input_ids, attention_mask
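To make the label handling in collate_batch above concrete, here is a small sketch of what masked_fill does to padded positions (the token ids are made up; -100 is the default ignore_index of PyTorch's cross-entropy loss, so those positions contribute nothing to it):

import torch

input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])       # 0 stands in for the pad token id
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 0 marks padded positions

labels = input_ids.masked_fill(attention_mask == 0, -100)
print(labels)
# tensor([[   5,    6,    7, -100],
#         [   8,    9, -100, -100]])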
2 changes: 1 addition & 1 deletion src/transformers/data/datasets/__init__.py
@@ -3,4 +3,4 @@
# module, but to preserve other warnings. So, don't check this module at all.

from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset
from .language_modeling import LazyLineByLineTextDataset, LineByLineTextDataset, TextDataset
38 changes: 38 additions & 0 deletions src/transformers/data/datasets/language_modeling.py
@@ -1,3 +1,4 @@
import linecache
import logging
import os
import pickle
@@ -98,3 +99,40 @@ def __len__(self):

def __getitem__(self, i) -> torch.Tensor:
return torch.tensor(self.examples[i], dtype=torch.long)


class LazyLineByLineTextDataset(Dataset):
"""
Credit: @bramvanroy for this linecache implementation.
This will be superseded by a framework-agnostic approach
soon.
"""

def __init__(self, file_path):
self.file_path = file_path
self.num_entries = self._get_n_lines(self.file_path)

@staticmethod
def _get_n_lines(fin, size=65536):
# borrowed from https://stackoverflow.com/a/9631635/1150683
def blocks(files):
while True:
b = files.read(size)
if not b:
break
yield b

with open(fin, encoding="utf-8") as fhin:
n_lines = sum(bl.count("\n") for bl in blocks(fhin))
return n_lines

def __getitem__(self, idx):
"""
:param idx (int): the index of the line to get
:return (str or None): The line as a string (newline removed) or None if there is an exception.
"""
# linecache starts counting from one, not zero, +1 the given index
return linecache.getline(self.file_path, idx + 1).rstrip()

def __len__(self):
return self.num_entries
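
A brief usage sketch of the dataset above (the file path is illustrative, and importing from transformers assumes this commit is installed): the file is streamed through once in __init__ to count lines, and individual lines are only read from disk when indexed.

from transformers import LazyLineByLineTextDataset

dataset = LazyLineByLineTextDataset(file_path="train.txt")  # one training example per line
print(len(dataset))  # number of lines, counted without holding the file in memory
print(dataset[0])    # first line, fetched via linecache, trailing whitespace stripped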
