
ValueError: You are attempting to pad samples but the tokenizer you are using (GPT2Tokenizer) does not have one. #4122

Closed
zixiliuUSC opened this issue May 2, 2020 · 20 comments
Labels
Ex: LM (Pretraining) Related to language modeling pre-training wontfix


@zixiliuUSC

🐛 Bug

Information

Model I am using (Bert, XLNet ...): GPT2

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

my own modified scripts: (give details below)
python examples/run_language_modeling.py \
    --train_data_file=temp_gpt2/gpt2.train \
    --output_dir=checkpoints/gpt2 \
    --model_type=gpt2 \
    --model_name_or_path=gpt2 \
    --eval_data_file=temp_gpt2/test.txt \
    --line_by_line \
    --do_train \
    --do_eval \
    --evaluate_during_training \
    --per_gpu_train_batch_size=20 \
    --per_gpu_eval_batch_size=20 \
    --gradient_accumulation_steps=1 \
    --learning_rate=8e-5 \
    --weight_decay=0.075 \
    --adam_epsilon=1e-8 \
    --warmup_steps=500 \
    --max_grad_norm=5.0 \
    --num_train_epochs=20 \
    --logging_steps=500 \
    --save_steps=500

The task I am working on is:

  • an official GLUE/SQuAD task: language modeling
  • my own task or dataset: (give details below) CoNLL-2014 GEC

To reproduce

Steps to reproduce the behavior:

  1. Run the script. I get the following error:
bash train_gpt2.sh 
05/02/2020 10:14:25 - INFO - transformers.training_args -   PyTorch: setting up devices
05/02/2020 10:14:25 - WARNING - __main__ -   Process rank: -1, device: cuda, n_gpu: 2, distributed training: False, 16-bits training: False
05/02/2020 10:14:25 - INFO - __main__ -   Training/evaluation parameters TrainingArguments(output_dir='checkpoints/gpt2', overwrite_output_dir=False, do_train=True, do_eval=True, do_predict=False, evaluate_during_training=True, per_gpu_train_batch_size=20, per_gpu_eval_batch_size=20, gradient_accumulation_steps=1, learning_rate=8e-05, weight_decay=0.075, adam_epsilon=1e-08, max_grad_norm=5.0, num_train_epochs=20.0, max_steps=-1, warmup_steps=500, logging_dir=None, logging_first_step=False, logging_steps=500, save_steps=500, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level='O1', local_rank=-1)
05/02/2020 10:14:35 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /home/zixi/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee70d272f80
05/02/2020 10:14:35 - INFO - transformers.configuration_utils -   Model config GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "vocab_size": 50257
}

05/02/2020 10:14:39 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /home/zixi/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.4c1d7fc2ac6ddabeaf0c8bec2ffc7dc112f668f5871a06efcff113d2797ec7d5
05/02/2020 10:14:39 - INFO - transformers.configuration_utils -   Model config GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "vocab_size": 50257
}

05/02/2020 10:14:42 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /home/zixi/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
05/02/2020 10:14:42 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /home/zixi/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
05/02/2020 10:14:43 - INFO - transformers.modeling_utils -   loading weights file https://cdn.huggingface.co/gpt2-pytorch_model.bin from cache at /home/zixi/.cache/torch/transformers/d71fd633e58263bd5e91dd3bde9f658bafd81e11ece622be6a3c2e4d42d8fd89.778cf36f5c4e5d94c8cd9cefcf2a580c8643570eb327f0d4a1f007fab2acbdf1
05/02/2020 10:14:47 - INFO - transformers.modeling_utils -   Weights of GPT2LMHeadModel not initialized from pretrained model: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
05/02/2020 10:14:47 - INFO - transformers.data.datasets.language_modeling -   Creating features from dataset file at temp_gpt2/gpt2_train.txt
05/02/2020 10:16:41 - INFO - transformers.data.datasets.language_modeling -   Creating features from dataset file at temp_gpt2/gpt2_test.txt
05/02/2020 10:16:44 - INFO - transformers.trainer -   ***** Running training *****
05/02/2020 10:16:44 - INFO - transformers.trainer -     Num examples = 1130686
05/02/2020 10:16:44 - INFO - transformers.trainer -     Num Epochs = 20
05/02/2020 10:16:44 - INFO - transformers.trainer -     Instantaneous batch size per GPU = 20
05/02/2020 10:16:44 - INFO - transformers.trainer -     Total train batch size (w. parallel, distributed & accumulation) = 40
05/02/2020 10:16:44 - INFO - transformers.trainer -     Gradient Accumulation steps = 1
05/02/2020 10:16:44 - INFO - transformers.trainer -     Total optimization steps = 565360
Epoch:   0%|          | 0/20 [00:00<?, ?it/s]
Iteration:   0%|          | 0/28268 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "examples/run_language_modeling.py", line 284, in <module>
    main()
  File "examples/run_language_modeling.py", line 254, in main
    trainer.train(model_path=model_path)
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/transformers/trainer.py", line 307, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/tqdm/std.py", line 1107, in __iter__
    for obj in iterable:
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/transformers/data/data_collator.py", line 91, in collate_batch
    batch = self._tensorize_batch(examples)
  File "/home/zixi/anaconda3/envs/ROC/lib/python3.7/site-packages/transformers/data/data_collator.py", line 106, in _tensorize_batch
    "You are attempting to pad samples but the tokenizer you are using"
ValueError: You are attempting to pad samples but the tokenizer you are using (GPT2Tokenizer) does not have one.
Epoch:   0%|                                                                                                                                                         | 0/20 [00:00<?, ?it/s]
Iteration:   0%|                                                           

Expected behavior

Environment info

  • transformers version: master
  • Platform: Ubuntu 18.04 LTS
  • Python version: 3.7
  • PyTorch version (GPU?): 1.5
  • Tensorflow version (GPU?): --
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: No
@patrickvonplaten self-assigned this May 3, 2020
@patrickvonplaten added the Ex: LM (Pretraining) Related to language modeling pre-training label May 4, 2020
@ratthachat
Contributor

I had the same error. I think the problem lies in "line_by_line", since if I remove this option, my code runs fine.

@patrickvonplaten
Contributor

@BramVanroy - sorry to just link you here. Did we decide to add a force-padding option here or not yet?

@BramVanroy
Collaborator

BramVanroy commented May 6, 2020

We had the discussion over here #3388 (comment)

@GCHQResearcher92457 mentions that they "tried to implement your suggestion". Perhaps it's better if you, @patrickvonplaten, could review the changes in the updated PR here 5ff6eb7?

@cahya-wirawan
Contributor

Any progress on this issue? I have the same problem when I use line_by_line with gpt2. Thanks

@liesun1994

Any progress on this issue? I have the same problem when I use line_by_line with gpt2. Thanks

Same problem with v2.9.1

@julien-c
Member

Personally, I think the fix is just that you can't use the line_by_line dataset with gpt2 (because it doesn't have a padding token)

@patrickvonplaten @BramVanroy Should I just raise an error that tells the user to remove the --line_by_line flag from their command?
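To see why only the line_by_line path fails, here is a simplified sketch of what a padding collator has to do. It is a hypothetical illustration (the function name pad_batch and the token ids are made up), not the library's DataCollatorForLanguageModeling code. Fixed-size blocks all have the same length and can be stacked directly, whereas one-example-per-line data has variable lengths, so a pad id becomes mandatory.

```python
from typing import List, Optional


def pad_batch(examples: List[List[int]], pad_token_id: Optional[int]) -> List[List[int]]:
    """Right-pad token-id lists to a common length (hypothetical helper)."""
    lengths = {len(ex) for ex in examples}
    # Equal-length examples (e.g. fixed-size blocks) can be stacked as-is.
    if len(lengths) == 1:
        return [list(ex) for ex in examples]
    # Variable-length examples (e.g. one per line) cannot be batched without a pad id.
    if pad_token_id is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer "
            "you are using does not have a pad token."
        )
    max_len = max(lengths)
    return [list(ex) + [pad_token_id] * (max_len - len(ex)) for ex in examples]


# Fixed-size blocks: no padding is needed, so a missing pad id is harmless.
print(pad_batch([[1, 2, 3], [4, 5, 6]], pad_token_id=None))

# Line-by-line examples of different lengths: padding is unavoidable.
try:
    pad_batch([[1, 2, 3], [4, 5]], pad_token_id=None)
except ValueError as err:
    print("ValueError:", err)
```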

@BramVanroy
Collaborator

@julien-c Perhaps you can have a look at PR 5ff6eb7; they suggest adding a force_padding_token option so that if a model does not have a padding token by default, one is added to the vocabulary manually.
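The idea behind a force-padding option can be sketched with a toy vocabulary: if no pad token is defined, register a new one under the next free id. This is a hypothetical illustration, not the code from the PR. With a real tokenizer and model, the corresponding calls would be tokenizer.add_special_tokens({"pad_token": "[PAD]"}) followed by model.resize_token_embeddings(len(tokenizer)), because the new id needs its own embedding row.

```python
from typing import Dict, Optional, Tuple


def ensure_pad_token(
    vocab: Dict[str, int], pad_token: Optional[str], new_pad: str = "[PAD]"
) -> Tuple[str, int]:
    """Return (pad_token, pad_id), adding `new_pad` to `vocab` if needed.

    Hypothetical helper: `vocab` is a toy stand-in for a tokenizer's
    token-to-id mapping, with ids assumed to be 0..len(vocab)-1.
    """
    if pad_token is not None:
        return pad_token, vocab[pad_token]
    if new_pad not in vocab:
        # Assign the next unused id to the new padding token.
        vocab[new_pad] = len(vocab)
    return new_pad, vocab[new_pad]


vocab = {"hello": 0, "world": 1, "<|endoftext|>": 2}
pad, pad_id = ensure_pad_token(vocab, pad_token=None)
# The model's embedding matrix would now need len(vocab) rows.
print(pad, pad_id, len(vocab))
```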

I have no preference: I like the implementation in the PR but it might not be what you would want or expect. Raising an error is also fine for me.

@VikasRajashekar

@julien-c I am stuck with the same error. If not line by line, how else can I train the GPT-2 model from scratch?

Here are my GPT-2 config and language model:

from transformers import GPT2LMHeadModel, GPT2Config

# Initializing a GPT2 configuration
configuration = GPT2Config(vocab_size=52_000)
model = GPT2LMHeadModel(config=configuration)

The logic for dataset preparation:

from transformers import LineByLineTextDataset

# `tokenizer` is assumed to be defined earlier (its creation is not shown in this post)
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./deu-de_web-public_2019_1M-sentences.txt",
    block_size=128,
)

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)

The training logic:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_gpu_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    prediction_loss_only=True,
)
trainer.train()

Throws me: ValueError: You are attempting to pad samples but the tokenizer you are using (GPT2Tokenizer) does not have one.

@borisdayma
Contributor

You need to use TextDataset and cannot use line by line at the moment.
You can build a large file and use your own special tokens if you build it completely from scratch or just reuse <|endoftext|> from pre-trained GPT-2.
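A rough sketch of that workflow: join the documents with <|endoftext|>, then cut the resulting token stream into equal-size blocks. This roughly mirrors how TextDataset builds its examples (the leftover tail is dropped), but the whitespace split below is only a stand-in for the real BPE tokenizer, and the block size is a toy value. Because every block has the same length, no padding token is needed; the trade-off is that a block can span several unrelated lines.

```python
from typing import List

EOS = "<|endoftext|>"


def build_corpus(lines: List[str]) -> str:
    """Join documents into one text, each followed by GPT-2's end-of-text token."""
    return f" {EOS} ".join(line.strip() for line in lines) + f" {EOS}"


def chunk_into_blocks(tokens: List[str], block_size: int) -> List[List[str]]:
    """Split a token stream into equal-size blocks, dropping the remainder."""
    return [
        tokens[i : i + block_size]
        for i in range(0, len(tokens) - block_size + 1, block_size)
    ]


lines = ["the cat sat", "a dog ran far away", "hi"]
corpus = build_corpus(lines)
tokens = corpus.split()  # stand-in for the real tokenizer
blocks = chunk_into_blocks(tokens, block_size=4)
for block in blocks:
    print(block)
```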

@VikasRajashekar

@borisdayma Thanks for the quick reply!

Where can I find out more about how these models can be trained, and with what kinds of datasets, tokenizers and special tokens?
Also, can this be used for Reformer too?

Please help me so that I can create a simple and clear Colab notebook and share it here so that others can easily use it.

@borisdayma
Contributor

borisdayma commented Aug 12, 2020

A good place to start would be the language_modeling section of the examples page in the docs.

@stale

stale bot commented Oct 11, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Oct 11, 2020
@ffahmed

ffahmed commented Oct 20, 2020

Hi there,
We are facing the same issue.
From the thread above, I am not sure what the fix is. If I don't use line_by_line, it defines its own sequence length and concatenates sequences from multiple unrelated lines. How can I make it treat each line separately as a sequence?

stale bot removed the wontfix label Oct 20, 2020
@patrickvonplaten
Contributor

You should set the PAD token manually equal to the EOS token. See #2630 as well

@ffahmed

ffahmed commented Oct 21, 2020

Thanks @patrickvonplaten, it worked. All I had to do was add the following after tokenizer initialization:

# bug manual fix for GPT2
# https://github.com/huggingface/transformers/issues/3021
if model_args.model_type == 'gpt2':
	tokenizer.pad_token = tokenizer.eos_token

@BramVanroy
Collaborator

BramVanroy commented Oct 21, 2020

I'm not sure of the consequences of this. To be safe you probably also should set the IDs, then. Something like this:

tokenizer.pad_token_id = tokenizer.eos_token_id

EDIT: this is wrong, see below

@patrickvonplaten
Contributor

@BramVanroy - it's actually not possible to set the IDs equal to each other; doing tokenizer.pad_token = tokenizer.eos_token should work and is my recommended way of doing it :-)

@ffahmed

ffahmed commented Oct 21, 2020

@patrickvonplaten Yes, I tried that too, but could not set the IDs.

@stale

stale bot commented Dec 24, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Dec 24, 2020
stale bot closed this as completed Jan 2, 2021
@abhishek0318

abhishek0318 commented Feb 27, 2021

# bug manual fix for GPT2
# https://github.com/huggingface/transformers/issues/3021
if model_args.model_type == 'gpt2':
    tokenizer.pad_token = tokenizer.eos_token

I used this solution and the error went away. However, it introduced a new problem for me: the model couldn't generate <|endoftext|> during inference. The model didn't learn to generate eos_token because it was ignored when computing the loss, since it is the same as pad_token. I had to use some other token as pad_token.
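The loss problem described above can be illustrated with a small sketch. It assumes labels are built by replacing every pad id with -100 (the default ignore_index of PyTorch's CrossEntropyLoss), which is consistent with the behavior reported here. When the pad id equals the EOS id, the genuine end-of-text token is masked along with the padding, so the model never gets a training signal for emitting it. The ids 11, 22, 33 and the extra pad id 50257 are hypothetical; 50256 is GPT-2's <|endoftext|> id, as shown in the config in the log above.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
EOS_ID = 50256       # GPT-2 <|endoftext|> id


def make_labels(input_ids: List[int], pad_id: int) -> List[int]:
    """Copy input ids to labels, ignoring every position equal to pad_id."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


# One example: three real tokens and a genuine EOS, followed by two pad positions.
real = [11, 22, 33, EOS_ID]

# Case 1: the pad token reuses EOS, so the real EOS is masked as well.
padded_same = real + [EOS_ID, EOS_ID]
print(make_labels(padded_same, pad_id=EOS_ID))

# Case 2: a distinct pad id (hypothetical new token) keeps the real EOS in the loss.
PAD_ID = 50257
padded_distinct = real + [PAD_ID, PAD_ID]
print(make_labels(padded_distinct, pad_id=PAD_ID))
```

Case 2 corresponds to using some other token as pad_token, as described above.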

Other than this, I also had to add eos_token to each list in LineByLineTextDataset.examples.
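Appending the EOS id to every line-level example can be sketched as follows. The examples list is a hypothetical stand-in for the token-id lists held in the dataset's examples attribute, and max_len is an assumed length cap; the sketch reserves one slot for EOS so that examples already truncated to the cap do not grow past it, and it skips lines that already end with EOS.

```python
from typing import List

EOS_ID = 50256  # GPT-2 <|endoftext|> id


def append_eos(examples: List[List[int]], eos_id: int, max_len: int) -> List[List[int]]:
    """Return copies of the examples, each ending in eos_id and at most max_len long."""
    out = []
    for ex in examples:
        # Drop an existing trailing EOS so it is not duplicated.
        body = ex[:-1] if ex and ex[-1] == eos_id else ex
        # Reserve one position for EOS within the length cap.
        out.append(list(body[: max_len - 1]) + [eos_id])
    return out


examples = [[11, 22], [33, 44, 55, 66], [77, EOS_ID]]
print(append_eos(examples, EOS_ID, max_len=4))
```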

Note: I am using transformers 3.4.0.
