Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task #6644

Merged (7 commits, Aug 31, 2020)
Changes from 3 commits
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -445,11 +445,13 @@
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorWithPadding,
    DataCollatorForNextSentencePrediction,
)
from .data.datasets import (
    GlueDataset,
    TextDataset,
    LineByLineTextDataset,
    TextDatasetForNextSentencePrediction,
    GlueDataTrainingArguments,
    SquadDataset,
    SquadDataTrainingArguments,
184 changes: 184 additions & 0 deletions src/transformers/data/data_collator.py
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

@@ -313,3 +314,186 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
            ) & masked_indices[i]

        return inputs, perm_mask, target_mapping, labels


@dataclass
class DataCollatorForNextSentencePrediction:
    """
    Data collator used for next sentence prediction.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - creates sentence-pair (A/B) examples with next sentence prediction labels
    - preprocesses batches for masked language modeling
    """

    tokenizer: PreTrainedTokenizer
    mlm: bool = True
    block_size: int = 128
    short_seq_probability: float = 0.1
    nsp_probability: float = 0.5
    mlm_probability: float = 0.15

    def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
Collaborator

You'd need to grab the token_type_ids and the labels too, I think.

Contributor Author (@mojave-pku, Aug 25, 2020)

Sorry, I'm a little confused. Are the labels you mention the NSP/MLM labels, or labels for a specific task? None of the data collators in this file grab the token_type_ids and labels; they just take the examples out of the dict and do nothing else. The segment_ids are generated in self.create_examples_from_document. Thank you~

Collaborator

Ah sorry, I was reading this wrong.

Contributor Author

Aha, ok~ :-)


        input_ids = []
        segment_ids = []
        nsp_labels = []

        for i, doc in enumerate(examples):
            input_id, segment_id, label = self.create_examples_from_document(doc, i, examples)
            input_ids.extend(input_id)
            segment_ids.extend(segment_id)
            nsp_labels.extend(label)
        if self.mlm:
            input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
        else:
            input_ids = self._tensorize_batch(input_ids)

        return {
            "input_ids": input_ids,
            "token_type_ids": self._tensorize_batch(segment_ids),
            "masked_lm_labels": mlm_labels if self.mlm else None,
            "next_sentence_label": torch.tensor(nsp_labels),
        }

    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have one."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

    def create_examples_from_document(
        self, document: List[List[int]], doc_index: int, examples: List[List[List[int]]]
    ):
        """Creates examples for a single document."""

        max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)

        # We *usually* want to fill up the entire sequence since we are padding
        # to `block_size` anyways, so short sequences are generally wasted
        # computation. However, we *sometimes*
        # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
        # sequences to minimize the mismatch between pre-training and fine-tuning.
        # The `target_seq_length` is just a rough target however, whereas
        # `block_size` is a hard limit.
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_probability:
            target_seq_length = random.randint(2, max_num_tokens)

        current_chunk = []  # a buffer storing the current working segments
        current_length = 0
        i = 0
        input_ids = []
        segment_ids = []
        labels = []
        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    # `a_end` is how many segments from `current_chunk` go into the `A`
                    # (first) sentence.
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []

                    if len(current_chunk) == 1 or random.random() < self.nsp_probability:
                        is_random_next = True
                        target_b_length = target_seq_length - len(tokens_a)

                        # This should rarely go for more than one iteration for large
                        # corpora. However, just to be careful, we try to make sure that
                        # the random document is not the same as the document
                        # we're processing.
                        for _ in range(10):
                            random_document_index = random.randint(0, len(examples) - 1)
                            if random_document_index != doc_index:
                                break

                        random_document = examples[random_document_index]
                        random_start = random.randint(0, len(random_document) - 1)
                        for j in range(random_start, len(random_document)):
                            tokens_b.extend(random_document[j])
                            if len(tokens_b) >= target_b_length:
                                break
                        # We didn't actually use these segments so we "put them back" so
                        # they don't go to waste.
                        num_unused_segments = len(current_chunk) - a_end
                        i -= num_unused_segments
                    # Actual next
                    else:
                        is_random_next = False
                        for j in range(a_end, len(current_chunk)):
                            tokens_b.extend(current_chunk[j])

                    assert len(tokens_a) >= 1
                    assert len(tokens_b) >= 1

                    tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
                        tokens_a,
                        tokens_b,
                        num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens,
                        truncation_strategy="longest_first",
                    )

                    input_ids.append(torch.tensor(self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)))
                    segment_ids.append(
                        torch.tensor(self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b))
                    )
                    labels.append(torch.tensor(1 if is_random_next else 0))

                current_chunk = []
                current_length = 0

            i += 1

        return input_ids, segment_ids, labels

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
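
A minimal usage sketch of the new collator (not part of this diff; the model name and toy sentences are illustrative assumptions). Each example passed to the collator is one document, i.e. a list of sentences that have already been converted to token ids:

from transformers import BertTokenizer, DataCollatorForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForNextSentencePrediction(tokenizer=tokenizer, block_size=32)

# Two toy "documents", each a List[List[int]] of tokenized sentences.
doc_1 = [tokenizer.encode(s, add_special_tokens=False)
         for s in ["I am very happy.", "Here is the second sentence."]]
doc_2 = [tokenizer.encode(s, add_special_tokens=False)
         for s in ["A new document.", "It also has a second sentence."]]

batch = collator([doc_1, doc_2])
# batch contains input_ids, token_type_ids, masked_lm_labels and next_sentence_label;
# next_sentence_label is 0 for an actual next sentence and 1 for a random one.
print(batch["next_sentence_label"])
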
2 changes: 1 addition & 1 deletion src/transformers/data/datasets/__init__.py
@@ -3,5 +3,5 @@
# module, but to preserve other warnings. So, don't check this module at all.

from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset
from .language_modeling import LineByLineTextDataset, TextDataset, TextDatasetForNextSentencePrediction
from .squad import SquadDataset, SquadDataTrainingArguments
79 changes: 79 additions & 0 deletions src/transformers/data/datasets/language_modeling.py
@@ -99,3 +99,82 @@ def __len__(self):

    def __getitem__(self, i) -> torch.Tensor:
        return torch.tensor(self.examples[i], dtype=torch.long)


class TextDatasetForNextSentencePrediction(Dataset):
    """
    Dataset that groups a plain-text corpus into documents of tokenized sentences
    for the next sentence prediction task. This will be superseded by a
    framework-agnostic approach soon.
    """

    def __init__(
        self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
    ):
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"

        block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)

        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
        )

        self.tokenizer = tokenizer
        self.examples = []

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"

        # Input file format:
        # (1) One sentence per line. These should ideally be actual sentences, not
        # entire paragraphs or arbitrary spans of text. (Because we use the
        # sentence boundaries for the "next sentence prediction" task).
        # (2) Blank lines between documents. Document boundaries are needed so
        # that the "next sentence prediction" task doesn't span between documents.
        #
        # Example:
        # I am very happy.
        # Here is the second sentence.
        #
        # A new document.

        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not overwrite_cache:
                start = time.time()
                with open(cached_features_file, "rb") as handle:
                    self.examples = pickle.load(handle)
                logger.info(
                    f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
                )
            else:
                logger.info(f"Creating features from dataset file at {directory}")

                self.examples = [[]]
                with open(file_path, encoding="utf-8") as f:
                    while True:
                        line = f.readline()
                        if not line:
                            break
                        line = line.strip()

                        # Empty lines are used as document delimiters
                        if not line and len(self.examples[-1]) != 0:
                            self.examples.append([])
                        tokens = tokenizer.tokenize(line)
                        tokens = tokenizer.convert_tokens_to_ids(tokens)
                        if tokens:
                            self.examples[-1].append(tokens)

                start = time.time()
                with open(cached_features_file, "wb") as handle:
                    pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]
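
For context (not part of this diff), the dataset and the collator are designed to be used together. The sketch below assumes a hypothetical corpus file "corpus.txt" with one sentence per line and blank lines between documents, as described in the comment block above:

from transformers import (
    BertTokenizer,
    DataCollatorForNextSentencePrediction,
    TextDatasetForNextSentencePrediction,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

dataset = TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer, file_path="corpus.txt", block_size=128,
)
data_collator = DataCollatorForNextSentencePrediction(tokenizer=tokenizer, block_size=128)

# __getitem__ returns one document as a list of tokenized sentences; the collator
# builds the sentence pairs, token type ids, MLM masks and NSP labels per batch.
batch = data_collator([dataset[i] for i in range(min(len(dataset), 8))])

From there the pair can be handed to Trainer as train_dataset and data_collator, mirroring how the existing language-modeling datasets in this module are used.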