diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 4c8d3649f9693..e7c41a3e7f199 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -178,10 +178,10 @@ def train_dataloader(self): return self.train_loader def val_dataloader(self): - return self.get_dataloader("dev", self.hparams.eval_batch_size) + return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False) def test_dataloader(self): - return self.get_dataloader("test", self.hparams.eval_batch_size) + return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False) def _feature_file(self, mode): return os.path.join( diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index db047fe956221..3a50c3732182c 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -6,8 +6,9 @@ Please tag @sshleifer with any issues/unexpected behaviors, or send a PR! For `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md). -### Data -XSUM Data: +## Datasets + +#### XSUM: ```bash cd examples/seq2seq wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz @@ -17,23 +18,33 @@ export XSUM_DIR=${PWD}/xsum this should make a directory called `xsum/` with files like `test.source`. To use your own data, copy that files format. Each article to be summarized is on its own line. -CNN/DailyMail data +#### CNN/DailyMail ```bash cd examples/seq2seq -wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz -tar -xzvf cnn_dm.tgz +wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm_v2.tgz +tar -xzvf cnn_dm_v2.tgz # empty lines removed +mv cnn_cln cnn_dm export CNN_DIR=${PWD}/cnn_dm -this should make a directory called `cnn_dm/` with files like `test.source`. ``` +this should make a directory called `cnn_dm/` with 6 files. -WMT16 English-Romanian Translation Data: +#### WMT16 English-Romanian Translation Data: download with this command: ```bash wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz tar -xzvf wmt_en_ro.tar.gz export ENRO_DIR=${PWD}/wmt_en_ro -this should make a directory called `wmt_en_ro/` with files like `test.source`. ``` +this should make a directory called `wmt_en_ro/` with 6 files. + +#### WMT English-German: +```bash +wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_de.tgz +tar -xzvf wmt_en_de.tar.gz +export DATA_DIR=${PWD}/wmt_en_de +``` + +#### Private Data If you are using your own data, it must be formatted as one directory with 6 files: ``` @@ -75,7 +86,8 @@ Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prep Future work/help wanted: A new dataset to support multilingual tasks. -### Command Line Options +### Finetuning Scripts +All finetuning bash scripts call finetune.py (or distillation.py) with reasonable command line arguments. They usually require extra command line arguments to work. To see all the possible command line options, run: @@ -110,6 +122,8 @@ The following command should work on a 16GB GPU: --model_name_or_path facebook/bart-large ``` +There is a starter finetuning script for pegasus at `finetune_pegasus_xsum.sh`. + ### Translation Finetuning First, follow the wmt_en_ro download instructions. diff --git a/examples/seq2seq/finetune_pegasus_xsum.sh b/examples/seq2seq/finetune_pegasus_xsum.sh index bdd4d6f9ad3e6..ec7ff98557c18 100755 --- a/examples/seq2seq/finetune_pegasus_xsum.sh +++ b/examples/seq2seq/finetune_pegasus_xsum.sh @@ -10,5 +10,5 @@ python finetune.py \ --n_val 1000 \ --val_check_interval 0.25 \ --max_source_length 512 --max_target_length 56 \ - --freeze_embeds --max_target_length 56 --label_smoothing 0.1 \ + --freeze_embeds --label_smoothing 0.1 --adafactor --task summarization_xsum \ "$@" diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index ba5150d1d5464..f36ca7101be37 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -67,7 +67,7 @@ def generate_summaries_or_translations( fout.write(hypothesis + "\n") fout.flush() fout.close() - runtime = time.time() - start_time + runtime = int(time.time() - start_time) # seconds n_obs = len(examples) return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4)) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 3747c0ac7faf0..e7c795b7c5d57 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -13,9 +13,10 @@ from torch.utils.data import DataLoader import lightning_base -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.hf_api import HfApi from transformers.modeling_bart import shift_tokens_right -from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu +from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow from .distillation import distill_main, evaluate_checkpoint from .finetune import SummarizationModule, main @@ -116,6 +117,25 @@ def setUpClass(cls): logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks return cls + @slow + @require_torch_and_cuda + def test_hub_configs(self): + """I put require_torch_and_cuda cause I only want this to run with self-scheduled.""" + + model_list = HfApi().model_list() + org = "sshleifer" + model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] + allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"] + failures = [] + for m in model_ids: + if m in allowed_to_be_broken: + continue + try: + AutoConfig.from_pretrained(m) + except Exception: + failures.append(m) + assert not failures, f"The following models could not be loaded through AutoConfig: {failures}" + @require_multigpu def test_multigpu(self): updates = dict( diff --git a/model_cards/rdenadai/BR_BERTo/README.md b/model_cards/rdenadai/BR_BERTo/README.md index 59609b9fb71a7..594df42bf5eb2 100644 --- a/model_cards/rdenadai/BR_BERTo/README.md +++ b/model_cards/rdenadai/BR_BERTo/README.md @@ -14,13 +14,17 @@ Portuguese (Brazil) model for text inference. ## Params -Trained on a corpus of 5_258_624 sentences, with 132_807_374 non unique tokens (992_418 unique tokens). +Trained on a corpus of 6_993_330 sentences. -- Vocab size: 220_000 -- RobertaForMaskedLM size : 32 -- Num train epochs: 2 -- Time to train: ~23hs (on GCP with a Nvidia T4) +- Vocab size: 150_000 +- RobertaForMaskedLM size : 512 +- Num train epochs: 3 +- Time to train: ~10days (on GCP with a Nvidia T4) I follow the great tutorial from HuggingFace team: [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train) + +More infor here: + +[BR_BERTo](https://github.com/rdenadai/BR-BERTo) diff --git a/model_cards/zanelim/singbert-large-sg/README.md b/model_cards/zanelim/singbert-large-sg/README.md index a38f50db80e5f..e3be8882d1371 100644 --- a/model_cards/zanelim/singbert-large-sg/README.md +++ b/model_cards/zanelim/singbert-large-sg/README.md @@ -13,17 +13,17 @@ datasets: - reddit singapore, malaysia - hardwarezone widget: -- text: "die [MASK] must try" - text: "kopi c siew [MASK]" +- text: "die [MASK] must try" --- # Model name -SingBert - Bert for Singlish (SG) and Manglish (MY). +SingBert Large - Bert for Singlish (SG) and Manglish (MY). ## Model description -Similar to [SingBert](https://huggingface.co/zanelim/singbert) but initialized from [BERT large uncased (whole word masking)](https://github.com/google-research/bert#pre-trained-models), with pre-training finetuned on +Similar to [SingBert](https://huggingface.co/zanelim/singbert) but the large version, which was initialized from [BERT large uncased (whole word masking)](https://github.com/google-research/bert#pre-trained-models), with pre-training finetuned on [singlish](https://en.wikipedia.org/wiki/Singlish) and [manglish](https://en.wikipedia.org/wiki/Manglish) data. ## Intended uses & limitations diff --git a/model_cards/zanelim/singbert-lite-sg/README.md b/model_cards/zanelim/singbert-lite-sg/README.md new file mode 100644 index 0000000000000..13819e064b6fa --- /dev/null +++ b/model_cards/zanelim/singbert-lite-sg/README.md @@ -0,0 +1,168 @@ +--- +language: en +tags: +- singapore +- sg +- singlish +- malaysia +- ms +- manglish +- albert-base-v2 +license: mit +datasets: +- reddit singapore, malaysia +- hardwarezone +widget: +- text: "dont play [MASK] leh" +- text: "die [MASK] must try" +--- + +# Model name + +SingBert Lite - Bert for Singlish (SG) and Manglish (MY). + +## Model description + +Similar to [SingBert](https://huggingface.co/zanelim/singbert) but the lite-version, which was initialized from [Albert base v2](https://github.com/google-research/albert#albert), with pre-training finetuned on +[singlish](https://en.wikipedia.org/wiki/Singlish) and [manglish](https://en.wikipedia.org/wiki/Manglish) data. + +## Intended uses & limitations + +#### How to use + +```python +>>> from transformers import pipeline +>>> nlp = pipeline('fill-mask', model='zanelim/singbert-lite-sg') +>>> nlp("die [MASK] must try") + +[{'sequence': '[CLS] die die must try[SEP]', + 'score': 0.7731555700302124, + 'token': 1327, + 'token_str': '▁die'}, + {'sequence': '[CLS] die also must try[SEP]', + 'score': 0.04763784259557724, + 'token': 67, + 'token_str': '▁also'}, + {'sequence': '[CLS] die still must try[SEP]', + 'score': 0.01859409362077713, + 'token': 174, + 'token_str': '▁still'}, + {'sequence': '[CLS] die u must try[SEP]', + 'score': 0.015824034810066223, + 'token': 287, + 'token_str': '▁u'}, + {'sequence': '[CLS] die is must try[SEP]', + 'score': 0.011271446943283081, + 'token': 25, + 'token_str': '▁is'}] + +>>> nlp("dont play [MASK] leh") + +[{'sequence': '[CLS] dont play play leh[SEP]', + 'score': 0.4365769624710083, + 'token': 418, + 'token_str': '▁play'}, + {'sequence': '[CLS] dont play punk leh[SEP]', + 'score': 0.06880936771631241, + 'token': 6769, + 'token_str': '▁punk'}, + {'sequence': '[CLS] dont play game leh[SEP]', + 'score': 0.051739856600761414, + 'token': 250, + 'token_str': '▁game'}, + {'sequence': '[CLS] dont play games leh[SEP]', + 'score': 0.045703962445259094, + 'token': 466, + 'token_str': '▁games'}, + {'sequence': '[CLS] dont play around leh[SEP]', + 'score': 0.013458190485835075, + 'token': 140, + 'token_str': '▁around'}] + +>>> nlp("catch no [MASK]") + +[{'sequence': '[CLS] catch no ball[SEP]', + 'score': 0.6197211146354675, + 'token': 1592, + 'token_str': '▁ball'}, + {'sequence': '[CLS] catch no balls[SEP]', + 'score': 0.08441998809576035, + 'token': 7152, + 'token_str': '▁balls'}, + {'sequence': '[CLS] catch no joke[SEP]', + 'score': 0.0676785409450531, + 'token': 8186, + 'token_str': '▁joke'}, + {'sequence': '[CLS] catch no?[SEP]', + 'score': 0.040638409554958344, + 'token': 60, + 'token_str': '?'}, + {'sequence': '[CLS] catch no one[SEP]', + 'score': 0.03546864539384842, + 'token': 53, + 'token_str': '▁one'}] + +>>> nlp("confirm plus [MASK]") + +[{'sequence': '[CLS] confirm plus chop[SEP]', + 'score': 0.9608421921730042, + 'token': 17144, + 'token_str': '▁chop'}, + {'sequence': '[CLS] confirm plus guarantee[SEP]', + 'score': 0.011784233152866364, + 'token': 9120, + 'token_str': '▁guarantee'}, + {'sequence': '[CLS] confirm plus confirm[SEP]', + 'score': 0.010571340098977089, + 'token': 10265, + 'token_str': '▁confirm'}, + {'sequence': '[CLS] confirm plus egg[SEP]', + 'score': 0.0033525123726576567, + 'token': 6387, + 'token_str': '▁egg'}, + {'sequence': '[CLS] confirm plus bet[SEP]', + 'score': 0.0008760977652855217, + 'token': 5676, + 'token_str': '▁bet'}] + +``` + +Here is how to use this model to get the features of a given text in PyTorch: +```python +from transformers import AlbertTokenizer, AlbertModel +tokenizer = AlbertTokenizer.from_pretrained('zanelim/singbert-lite-sg') +model = AlbertModel.from_pretrained("zanelim/singbert-lite-sg") +text = "Replace me by any text you'd like." +encoded_input = tokenizer(text, return_tensors='pt') +output = model(**encoded_input) +``` + +and in TensorFlow: +```python +from transformers import AlbertTokenizer, TFAlbertModel +tokenizer = AlbertTokenizer.from_pretrained("zanelim/singbert-lite-sg") +model = TFAlbertModel.from_pretrained("zanelim/singbert-lite-sg") +text = "Replace me by any text you'd like." +encoded_input = tokenizer(text, return_tensors='tf') +output = model(encoded_input) +``` + +#### Limitations and bias +This model was finetuned on colloquial Singlish and Manglish corpus, hence it is best applied on downstream tasks involving the main +constituent languages- english, mandarin, malay. Also, as the training data is mainly from forums, beware of existing inherent bias. + +## Training data +Colloquial singlish and manglish (both are a mixture of English, Mandarin, Tamil, Malay, and other local dialects like Hokkien, Cantonese or Teochew) +corpus. The corpus is collected from subreddits- `r/singapore` and `r/malaysia`, and forums such as `hardwarezone`. + +## Training procedure + +Initialized with [albert base v2](https://github.com/google-research/albert#albert) vocab and checkpoints (pre-trained weights). + +Pre-training was further finetuned on training data with the following hyperparameters +* train_batch_size: 4096 +* max_seq_length: 128 +* num_train_steps: 125000 +* num_warmup_steps: 5000 +* learning_rate: 0.00176 +* hardware: TPU v3-8 diff --git a/model_cards/zanelim/singbert/README.md b/model_cards/zanelim/singbert/README.md index 641f8facc9595..bd5a0f96f20e3 100644 --- a/model_cards/zanelim/singbert/README.md +++ b/model_cards/zanelim/singbert/README.md @@ -13,8 +13,8 @@ datasets: - reddit singapore, malaysia - hardwarezone widget: -- text: "die [MASK] must try" - text: "kopi c siew [MASK]" +- text: "die [MASK] must try" --- # Model name diff --git a/notebooks/03-pipelines.ipynb b/notebooks/03-pipelines.ipynb index 53c22634ec6fd..2a346c7ec7c83 100644 --- a/notebooks/03-pipelines.ipynb +++ b/notebooks/03-pipelines.ipynb @@ -2358,7 +2358,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -3402,4 +3402,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9558fb457e66e..502da78555768 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -200,6 +200,7 @@ from .data.data_collator import ( DataCollator, DataCollatorForLanguageModeling, + DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, DataCollatorWithPadding, default_data_collator, @@ -211,6 +212,7 @@ SquadDataset, SquadDataTrainingArguments, TextDataset, + TextDatasetForNextSentencePrediction, ) from .generation_utils import top_k_top_p_filtering from .modeling_albert import ( diff --git a/src/transformers/configuration_pegasus.py b/src/transformers/configuration_pegasus.py index 4c3564fd1062c..694759b7edb97 100644 --- a/src/transformers/configuration_pegasus.py +++ b/src/transformers/configuration_pegasus.py @@ -47,46 +47,23 @@ activation_function="relu", ) # Config values that vary between checkpoints: for testing and conversion -max_gen_length = { - # See appendix C of paper - "xsum": 64, - "cnn_dailymail": 128, - "newsroom": 128, - "wikihow": 256, - "multi_news": 256, - "reddit_tifu": 128, - "big_patent": 256, - "arxiv": 256, - "pubmed": 256, - "gigaword": 32, - "aeslc": 32, - "billsum": 256, - "large": 256, # @sshleifer chose arbitrarily +task_specific_params = { + # These are task specific params for pegasus-large and normal params for finetuned checkpoints + "summarization_xsum": {"length_penalty": 0.8, "max_length": 64, "max_position_embeddings": 512}, + "summarization_cnn_dailymail": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 1024}, + "summarization_newsroom": {"length_penalty": 0.8, "max_length": 128, "max_position_embeddings": 512}, + "summarization_wikihow": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 512}, + "summarization_multi_news": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_reddit_tifu": {"length_penalty": 0.6, "max_length": 128, "max_position_embeddings": 512}, + "summarization_big_patent": {"length_penalty": 0.7, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_arxiv": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_pubmed": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, + "summarization_gigaword": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 128}, + "summarization_aeslc": {"length_penalty": 0.6, "max_length": 32, "max_position_embeddings": 512}, + "summarization_billsum": {"length_penalty": 0.6, "max_length": 256, "max_position_embeddings": 1024}, + # this last entry is useless -- just for consistency + "summarization_large": {"length_penalty": 0.8, "max_length": 256, "max_position_embeddings": 1024}, } -max_model_length = { - "xsum": 512, - "cnn_dailymail": 1024, - "newsroom": 512, - "wikihow": 512, - "multi_news": 1024, - "reddit_tifu": 512, - "big_patent": 1024, - "arxiv": 1024, - "pubmed": 1024, - "gigaword": 128, - "aeslc": 512, - "billsum": 1024, - "large": 1024, -} -expected_alpha = { - "multinews": 0.9, - "wikihow": 0.6, - "reddit_tifu": 0.6, - "big_patent": 0.7, - "gigaword": 0.6, - "aeslc": 0.6, - "billsum": 0.6, -} # otherwise 0.8 @add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC) diff --git a/src/transformers/convert_pegasus_tf_to_pytorch.py b/src/transformers/convert_pegasus_tf_to_pytorch.py index e3b8614d4ef41..edf0498f37308 100644 --- a/src/transformers/convert_pegasus_tf_to_pytorch.py +++ b/src/transformers/convert_pegasus_tf_to_pytorch.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import os from pathlib import Path from typing import Dict @@ -22,7 +23,7 @@ from tqdm import tqdm from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer -from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length +from transformers.configuration_pegasus import DEFAULTS, task_specific_params PATTERNS = [ @@ -101,23 +102,25 @@ def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict: return tf_weights -def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir): +def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): # save tokenizer first dataset = Path(ckpt_path).parent.name - desired_max_model_length = max_model_length[dataset] + desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"] tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length) assert tok.model_max_length == desired_max_model_length tok.save_pretrained(save_dir) # convert model tf_weights = get_tf_weights_as_numpy(ckpt_path) - cfg_updates = dict( - max_length=max_gen_length[dataset], - length_penalty=expected_alpha.get(dataset, 0.8), - max_position_embeddings=desired_max_model_length, - ) + cfg_updates = task_specific_params[f"summarization_{dataset}"] + if dataset == "large": + cfg_updates["task_specific_params"] = task_specific_params torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates) torch_model.save_pretrained(save_dir) + sd = torch_model.state_dict() + sd.pop("model.decoder.embed_positions.weight") + sd.pop("model.encoder.embed_positions.weight") + torch.save(sd, Path(save_dir) / "pytorch_model.bin") if __name__ == "__main__": @@ -127,5 +130,6 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir): parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.") args = parser.parse_args() if args.save_dir is None: - args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}" + dataset = Path(args.tf_ckpt_path).parent.name + args.save_dir = os.path.join("pegasus", dataset) convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index b14d06d4fb132..ceb36ed74f6e1 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union @@ -327,3 +328,200 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, ) & masked_indices[i] return inputs, perm_mask, target_mapping, labels + + +@dataclass +class DataCollatorForNextSentencePrediction: + """ + Data collator used for language modeling. + - collates batches of tensors, honoring their tokenizer's pad_token + - preprocesses batches for masked language modeling + """ + + tokenizer: PreTrainedTokenizer + mlm: bool = True + block_size: int = 512 + short_seq_probability: float = 0.1 + nsp_probability: float = 0.5 + mlm_probability: float = 0.15 + + def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + if isinstance(examples[0], (dict, BatchEncoding)): + examples = [e["input_ids"] for e in examples] + + input_ids = [] + segment_ids = [] + attention_masks = [] + nsp_labels = [] + + for i, doc in enumerate(examples): + input_id, segment_id, attention_mask, label = self.create_examples_from_document(doc, i, examples) + input_ids.extend(input_id) + segment_ids.extend(segment_id) + attention_masks.extend(attention_mask) + nsp_labels.extend(label) + if self.mlm: + input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids)) + else: + input_ids = self._tensorize_batch(input_ids) + + return { + "input_ids": input_ids, + "attention_mask": self._tensorize_batch(attention_masks), + "token_type_ids": self._tensorize_batch(segment_ids), + "masked_lm_labels": mlm_labels if self.mlm else None, + "next_sentence_label": torch.tensor(nsp_labels), + } + + def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: + length_of_first = examples[0].size(0) + are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) + if are_tensors_same_length: + return torch.stack(examples, dim=0) + else: + if self.tokenizer._pad_token is None: + raise ValueError( + "You are attempting to pad samples but the tokenizer you are using" + f" ({self.tokenizer.__class__.__name__}) does not have one." + ) + return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + + def create_examples_from_document( + self, document: List[List[int]], doc_index: int, examples: List[List[List[int]]] + ): + """Creates examples for a single document.""" + + max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True) + + # We *usually* want to fill up the entire sequence since we are padding + # to `block_size` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `block_size` is a hard limit. + target_seq_length = max_num_tokens + if random.random() < self.short_seq_probability: + target_seq_length = random.randint(2, max_num_tokens) + + current_chunk = [] # a buffer stored current working segments + current_length = 0 + i = 0 + input_ids = [] + segment_ids = [] + attention_masks = [] + labels = [] + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = random.randint(1, len(current_chunk) - 1) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + + if len(current_chunk) == 1 or random.random() < self.nsp_probability: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + # This should rarely go for more than one iteration for large + # corpora. However, just to be careful, we try to make sure that + # the random document is not the same as the document + # we're processing. + for _ in range(10): + random_document_index = random.randint(0, len(examples) - 1) + if random_document_index != doc_index: + break + + random_document = examples[random_document_index] + random_start = random.randint(0, len(random_document) - 1) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences( + tokens_a, + tokens_b, + num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens, + truncation_strategy="longest_first", + ) + + input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) + attention_mask = [1] * len(input_id) + segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) + assert len(input_id) <= self.block_size + + # pad + while len(input_id) < self.block_size: + input_id.append(0) + attention_mask.append(0) + segment_id.append(0) + + input_ids.append(torch.tensor(input_id)) + segment_ids.append(torch.tensor(segment_id)) + attention_masks.append(torch.tensor(attention_mask)) + labels.append(torch.tensor(1 if is_random_next else 0)) + + current_chunk = [] + current_length = 0 + + i += 1 + + return input_ids, segment_ids, attention_masks, labels + + def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + """ + + if self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." + ) + + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + probability_matrix = torch.full(labels.shape, self.mlm_probability) + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + if self.tokenizer._pad_token is not None: + padding_mask = labels.eq(self.tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels diff --git a/src/transformers/data/datasets/__init__.py b/src/transformers/data/datasets/__init__.py index ca2ab15e43fbe..f4e2aac5e968c 100644 --- a/src/transformers/data/datasets/__init__.py +++ b/src/transformers/data/datasets/__init__.py @@ -3,5 +3,5 @@ # module, but to preserve other warnings. So, don't check this module at all. from .glue import GlueDataset, GlueDataTrainingArguments -from .language_modeling import LineByLineTextDataset, TextDataset +from .language_modeling import LineByLineTextDataset, TextDataset, TextDatasetForNextSentencePrediction from .squad import SquadDataset, SquadDataTrainingArguments diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 71a59500317e8..1a377a60b155b 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -109,3 +109,91 @@ def __len__(self): def __getitem__(self, i) -> torch.Tensor: return torch.tensor(self.examples[i], dtype=torch.long) + + +class TextDatasetForNextSentencePrediction(Dataset): + """ + This will be superseded by a framework-agnostic approach + soon. + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + file_path: str, + block_size: int, + overwrite_cache=False, + ): + assert os.path.isfile(file_path), f"Input file path {file_path} not found" + + block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True) + + directory, filename = os.path.split(file_path) + cached_features_file = os.path.join( + directory, + "cached_nsp_{}_{}_{}".format( + tokenizer.__class__.__name__, + str(block_size), + filename, + ), + ) + + self.tokenizer = tokenizer + self.examples = [] + + # Make sure only the first process in distributed training processes the dataset, + # and the others will use the cache. + lock_path = cached_features_file + ".lock" + + # Input file format: + # (1) One sentence per line. These should ideally be actual sentences, not + # entire paragraphs or arbitrary spans of text. (Because we use the + # sentence boundaries for the "next sentence prediction" task). + # (2) Blank lines between documents. Document boundaries are needed so + # that the "next sentence prediction" task doesn't span between documents. + # + # Example: + # I am very happy. + # Here is the second sentence. + # + # A new document. + + with FileLock(lock_path): + if os.path.exists(cached_features_file) and not overwrite_cache: + start = time.time() + with open(cached_features_file, "rb") as handle: + self.examples = pickle.load(handle) + logger.info( + f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start + ) + else: + logger.info(f"Creating features from dataset file at {directory}") + + self.examples = [[]] + with open(file_path, encoding="utf-8") as f: + while True: + line = f.readline() + if not line: + break + line = line.strip() + + # Empty lines are used as document delimiters + if not line and len(self.examples[-1]) != 0: + self.examples.append([]) + tokens = tokenizer.tokenize(line) + tokens = tokenizer.convert_tokens_to_ids(tokens) + if tokens: + self.examples[-1].append(tokens) + + start = time.time() + with open(cached_features_file, "wb") as handle: + pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) + logger.info( + "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start + ) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + return self.examples[i] diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ca08a383eb5ef..152e6cae6804d 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -757,7 +757,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict total=total, initial=resume_size, desc="Downloading", - disable=bool(logging.get_verbosity() > logging.NOTSET), + disable=bool(logging.get_verbosity() == logging.NOTSET), ) for chunk in response.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index b4ee37f55e13e..138b1a2f48a63 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -225,11 +225,7 @@ class EncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SelfAttention( - self.embed_dim, - config.encoder_attention_heads, - dropout=config.attention_dropout, - ) + self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout) self.normalize_before = config.normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -377,7 +373,8 @@ class DecoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SelfAttention( + + self.self_attn = Attention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -388,7 +385,7 @@ def __init__(self, config: BartConfig): self.normalize_before = config.normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) - self.encoder_attn = SelfAttention( + self.encoder_attn = Attention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -586,7 +583,7 @@ def forward( if use_cache: next_decoder_cache.append(layer_past.copy()) - if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart + if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART) x = self.layer_norm(x) if output_attentions: all_self_attns += (layer_self_attn,) @@ -616,7 +613,7 @@ def _reorder_buffer(attn_cache, new_order): return attn_cache -class SelfAttention(nn.Module): +class Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( diff --git a/src/transformers/modeling_tf_flaubert.py b/src/transformers/modeling_tf_flaubert.py index 9a0cc9c26c5a2..792d5d3c7348a 100644 --- a/src/transformers/modeling_tf_flaubert.py +++ b/src/transformers/modeling_tf_flaubert.py @@ -296,7 +296,7 @@ def call( else: tensor_normalized = self.layer_norm1[i](tensor) attn_outputs = self.attentions[i]( - tensor_normalized, attn_mask, None, cache, head_mask[i], training=training + tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training ) attn = attn_outputs[0] if output_attentions: diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 12da9a32dd517..297e3e791a8ed 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -346,7 +346,7 @@ class Adafactor(Optimizer): If True, learning rate is scaled by root mean square relative_step (:obj:`bool`, `optional`, defaults to :obj:`True`): If True, time-dependent learning rate is computed instead of external learning rate - warmup_init (:obj:`bool`, `optional`, defaults to False): + warmup_init (:obj:`bool`, `optional`, defaults to :obj:`False`): Time-dependent learning rate computation depends on whether warm-up initialization is being used This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 0c9966a48ed7d..571fabe69016f 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -372,6 +372,5 @@ def prepare_seq2seq_batch( **kwargs, ) model_inputs["labels"] = labels_and_decoder_mask["input_ids"] - model_inputs["decoder_attention_mask"] = labels_and_decoder_mask["attention_mask"] self.prefix_tokens = [] return model_inputs diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 656a5dd8ff1e9..74b00e7d1d53f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -644,7 +644,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D if model_path is not None: # set global_step to global_step of last saved checkpoint from model path try: - self.global_step = int(model_path.split("-")[-1].split("/")[0]) + self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) steps_trained_in_current_epoch = self.global_step % ( len(train_dataloader) // self.args.gradient_accumulation_steps @@ -658,8 +658,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D self.global_step = 0 logger.info(" Starting fine-tuning.") - tr_loss = 0.0 - logging_loss = 0.0 + tr_loss = torch.tensor(0.0).to(self.args.device) + logging_loss_scalar = 0.0 model.zero_grad() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm) @@ -720,14 +720,15 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D self.global_step == 1 and self.args.logging_first_step ): logs: Dict[str, float] = {} - logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps + tr_loss_scalar = tr_loss.item() + logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps # backward compatibility for pytorch schedulers logs["learning_rate"] = ( self.lr_scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else self.lr_scheduler.get_lr()[0] ) - logging_loss = tr_loss + logging_loss_scalar = tr_loss_scalar self.log(logs) @@ -773,8 +774,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D break epoch_pbar.close() train_pbar.update(1) - if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: - break if self.args.tpu_metrics_debug or self.args.debug: if is_torch_tpu_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) @@ -784,6 +783,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) + if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: + break train_pbar.close() if self.tb_writer: @@ -793,7 +794,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") - return TrainOutput(self.global_step, tr_loss / self.global_step) + return TrainOutput(self.global_step, tr_loss.item() / self.global_step) def hyperparameter_search( self, @@ -973,7 +974,7 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s return inputs - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float: + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. @@ -989,7 +990,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, argument :obj:`labels`. Check your model's documentation for all accepted arguments. Return: - :obj:`float`: The training loss on this batch. + :obj:`torch.Tensor`: The tensor with training loss on this batch. """ if hasattr(self, "_training_step"): warnings.warn( @@ -1027,7 +1028,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, else: loss.backward() - return loss.item() + return loss def is_local_master(self) -> bool: """ @@ -1276,6 +1277,10 @@ def prediction_loop( preds = xm.mesh_reduce("eval_preds", preds, torch.cat) if label_ids is not None: label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) + if eval_losses is not None: + eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() + if samples_count is not None: + samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist()) # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 0be08d78d18c0..d592d43eacbf2 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -54,7 +54,7 @@ def _configure_library_root_logger() -> None: # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() library_root_logger.addHandler(_default_handler) - library_root_logger.setLevel(logging.INFO) + library_root_logger.setLevel(logging.WARN) library_root_logger.propagate = False diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py index 41b3b371b944e..2ec65e573807d 100644 --- a/tests/test_data_collator.py +++ b/tests/test_data_collator.py @@ -9,11 +9,13 @@ from transformers import ( DataCollatorForLanguageModeling, + DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, GlueDataset, GlueDataTrainingArguments, LineByLineTextDataset, TextDataset, + TextDatasetForNextSentencePrediction, default_data_collator, ) @@ -150,3 +152,19 @@ def test_plm(self): with self.assertRaises(ValueError): # Expect error due to odd sequence length data_collator(example) + + def test_nsp(self): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + data_collator = DataCollatorForNextSentencePrediction(tokenizer) + + dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512) + examples = [dataset[i] for i in range(len(dataset))] + batch = data_collator(examples) + self.assertIsInstance(batch, dict) + + # Since there are randomly generated false samples, the total number of samples is not fixed. + total_samples = batch["input_ids"].shape[0] + self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,))) diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 6fb387daa7614..68cb5d6e0462f 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -1,9 +1,10 @@ import unittest from transformers import AutoConfig, AutoTokenizer, is_torch_available -from transformers.configuration_pegasus import max_gen_length, max_model_length +from transformers.configuration_pegasus import task_specific_params from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils.logging import ERROR, set_verbosity from .test_modeling_bart import PGE_ARTICLE from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest @@ -14,6 +15,8 @@ XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """ +set_verbosity(ERROR) + @require_torch class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): @@ -50,31 +53,25 @@ def test_pegasus_xsum_summary(self): class PegasusConfigTests(unittest.TestCase): - def test_all_config_max_lengths(self): + @slow + def test_task_specific_params(self): + """Test that task_specific params['summarization_xsum'] == config['pegasus_xsum'] """ failures = [] pegasus_prefix = "google/pegasus" - for dataset, max_len in max_gen_length.items(): + n_prefix_chars = len("summarization_") + for task, desired_settings in task_specific_params.items(): + dataset = task[n_prefix_chars:] mname = f"{pegasus_prefix}-{dataset}" cfg = AutoConfig.from_pretrained(mname) - - if cfg.max_length != max_len: - failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}") - - if cfg.max_position_embeddings < max_model_length[dataset]: - # otherwise you get IndexError for e.g. position 513 - # see https://github.com/huggingface/transformers/issues/6599 - failures.append( - f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}" - ) - + for k, v in desired_settings.items(): + actual_value = getattr(cfg, k) + if actual_value != v: + failures.append(f"config for {mname} had {k}: {actual_value}, expected {v}") tokenizer = AutoTokenizer.from_pretrained(mname) - if max_model_length[dataset] != tokenizer.model_max_length: - failures.append( - f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}" - ) + n_pos_embeds = desired_settings["max_position_embeddings"] + if n_pos_embeds != tokenizer.model_max_length: + failures.append(f"tokenizer.model_max_length {tokenizer.model_max_length} expected {n_pos_embeds}") - if failures == []: - return # error all_fails = "\n".join(failures) - raise AssertionError(f"The following configs have unexpected settings: {all_fails}") + assert not failures, f"The following configs have unexpected settings: {all_fails}" diff --git a/tests/test_tokenization_bart.py b/tests/test_tokenization_bart.py index 59fe1786dab7b..bbd448b24ac13 100644 --- a/tests/test_tokenization_bart.py +++ b/tests/test_tokenization_bart.py @@ -69,12 +69,12 @@ def default_tokenizer_fast(self): @require_torch def test_prepare_seq2seq_batch(self): - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] tgt_text = [ "Summary of the text.", "Another summary.", ] - expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2] + expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: batch = tokenizer.prepare_seq2seq_batch( @@ -82,8 +82,8 @@ def test_prepare_seq2seq_batch(self): ) self.assertIsInstance(batch, BatchEncoding) - self.assertEqual((2, 10), batch.input_ids.shape) - self.assertEqual((2, 10), batch.attention_mask.shape) + self.assertEqual((2, 9), batch.input_ids.shape) + self.assertEqual((2, 9), batch.attention_mask.shape) result = batch.input_ids.tolist()[0] self.assertListEqual(expected_src_tokens, result) # Test that special tokens are reset @@ -91,7 +91,7 @@ def test_prepare_seq2seq_batch(self): # Test Prepare Seq @require_torch def test_seq2seq_batch_empty_target_text(self): - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt") # check if input_ids are returned and no labels @@ -102,7 +102,7 @@ def test_seq2seq_batch_empty_target_text(self): @require_torch def test_seq2seq_batch_max_target_length(self): - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] tgt_text = [ "Summary of the text.", "Another summary.", @@ -131,7 +131,7 @@ def test_seq2seq_batch_not_longer_than_maxlen(self): @require_torch def test_special_tokens(self): - src_text = ["A long paragraph for summrization."] + src_text = ["A long paragraph for summarization."] tgt_text = [ "Summary of the text.", ] diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index ba466c45d56f3..0eb7b0dd701cc 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -14,7 +14,7 @@ RobertaTokenizer, is_torch_available, ) -from transformers.testing_utils import get_tests_dir, require_torch +from transformers.testing_utils import get_tests_dir from transformers.tokenization_distilbert import DistilBertTokenizerFast from transformers.tokenization_openai import OpenAIGPTTokenizerFast from transformers.tokenization_roberta import RobertaTokenizerFast diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 16bf536b25bd1..05424ab834da7 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -120,12 +120,12 @@ def test_eos_treatment(self): def test_prepare_seq2seq_batch(self): tokenizer = self.t5_base_tokenizer - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] tgt_text = [ "Summary of the text.", "Another summary.", ] - expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id] + expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id] batch = tokenizer.prepare_seq2seq_batch( src_text, tgt_texts=tgt_text, @@ -135,15 +135,15 @@ def test_prepare_seq2seq_batch(self): result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) - self.assertEqual((2, 10), batch.input_ids.shape) - self.assertEqual((2, 10), batch.attention_mask.shape) + self.assertEqual((2, 9), batch.input_ids.shape) + self.assertEqual((2, 9), batch.attention_mask.shape) # Test that special tokens are reset self.assertEqual(tokenizer.prefix_tokens, []) def test_empty_target_text(self): tokenizer = self.t5_base_tokenizer - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK) # check if input_ids are returned and no decoder_input_ids self.assertIn("input_ids", batch) @@ -153,7 +153,7 @@ def test_empty_target_text(self): def test_max_target_length(self): tokenizer = self.t5_base_tokenizer - src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."] + src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."] tgt_text = [ "Summary of the text.", "Another summary.", @@ -162,14 +162,12 @@ def test_max_target_length(self): src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK ) self.assertEqual(32, batch["labels"].shape[1]) - self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) # test None max_target_length batch = tokenizer.prepare_seq2seq_batch( src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK ) self.assertEqual(32, batch["labels"].shape[1]) - self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) def test_outputs_not_longer_than_maxlen(self): tokenizer = self.t5_base_tokenizer @@ -182,9 +180,9 @@ def test_outputs_not_longer_than_maxlen(self): def test_eos_in_input(self): tokenizer = self.t5_base_tokenizer - src_text = ["A long paragraph for summrization. "] + src_text = ["A long paragraph for summarization. "] tgt_text = ["Summary of the text. "] - expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1] + expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1] expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1] batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)