Loss mask for fine-tuning GPT2LMHeadModel model #7135
Comments
It has already been mentioned here: #2001 (see the "Bug: Padded tokens are not excluded from the loss" section).
Hi, GPT-2 has no pad token, so you can either introduce a new pad token or set the eos token as the pad token, and then set the pad token positions in the labels to -100 so that they are excluded from the loss.
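A minimal sketch of the two options described above (not from the original comment; the "[PAD]" token name is only an illustration):

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Option A: reuse the existing eos token as the pad token (no new embeddings needed)
tokenizer.pad_token = tokenizer.eos_token

# Option B: introduce a dedicated pad token and resize the embedding matrix
# so the new token id has an embedding row
# tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# model.resize_token_embeddings(len(tokenizer))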
Thanks, just became aware of that.
For fine-tuning the GPT2 model, it's necessary to manually prepend the bos_token and append the eos_token to the input, as has been established here: #3311. Setting pad_token = eos_token alone is not enough, because the padded eos tokens would still contribute to the loss. I solved the problem by first converting the attention_mask to boolean values, then inverting the boolean attention_mask, and using the inverted mask to set the padded label positions to -100.
Just to save the hassle for some folks:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
print("EOS", tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
print("PAD", tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
string = "Hello World!"
string += tokenizer.eos_token # manually append eos since this is not done by GPT2Tokenizer
# string = tokenizer.bos_token + string # optionally prepend bos (which is actually the same as eos for GPT2Tokenizer)
tokenized = tokenizer(string, padding="max_length", max_length=10, return_tensors="pt")
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
print("INPUT_IDS BEFORE")
print(input_ids)
print("ATTENTION_MASK")
print(attention_mask)
input_ids[~attention_mask.bool()] = -100 # disable loss for padding tokens (i.e., eos tokens meant for padding)
print("INPUT_IDS AFTER")
print(input_ids)

Result:

EOS 50256
PAD 50256
INPUT_IDS BEFORE
tensor([[15496, 2159, 0, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])
ATTENTION_MASK
tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])
INPUT_IDS AFTER
tensor([[15496, 2159, 0, 50256, -100, -100, -100, -100, -100, -100]])
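A short follow-up sketch, not from the thread, assuming a recent transformers version where the model output exposes .loss: in practice the ids should be cloned into labels before masking, so that input_ids still contains valid token ids (-100 cannot be looked up in the embedding layer). The model shifts the labels internally and ignores the -100 positions in the cross-entropy loss.

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

tokenized = tokenizer("Hello World!" + tokenizer.eos_token,
                      padding="max_length", max_length=10, return_tensors="pt")
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]

labels = input_ids.clone()               # keep input_ids intact for the embedding lookup
labels[~attention_mask.bool()] = -100    # padded positions are ignored by the loss
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs.loss)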
If we use padding for short-sentence data when fine-tuning GPT2LMHeadModel, should we change the code here?
transformers/src/transformers/modeling_gpt2.py
Line 744 in 48ff6d5
@patrickvonplaten @thomwolf