Loss mask for fine-tuning GPT2LMHeadModel model #7135

Closed
zhujl1991 opened this issue Sep 15, 2020 · 5 comments


zhujl1991 commented Sep 15, 2020

If we pad short sentences in the fine-tuning data when fine-tuning GPT2LMHeadModel, should we change the code here

loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
to exclude padding tokens from the loss?

@patrickvonplaten @thomwolf

@zhujl1991 (Author)

This has already been mentioned in #2001 (see the "Bug: Padded tokens are not excluded from the loss" section).
Any plan to fix this?

@patil-suraj (Contributor)

Hi, GPT-2 has no pad token, so you can either introduce a new pad token or set the eos token as the pad token:
tokenizer.pad_token_id = tokenizer.eos_token_id

and then set the pad tokens in the labels to -100, which is the default ignore_index for CrossEntropyLoss:
labels[labels == self.tokenizer.pad_token_id] = -100
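
A minimal sketch of this suggestion (the example strings are illustrative, not from the thread); note the caveat about eos tokens raised further down in this thread:

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Pad a small batch and build labels with padding positions set to -100
batch = tokenizer(["Short sentence.", "A somewhat longer training sentence."],
                  padding=True, return_tensors="pt")
labels = batch["input_ids"].clone()
labels[labels == tokenizer.pad_token_id] = -100  # -100 is ignored by CrossEntropyLoss

outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=labels)
print(outputs.loss)  # loss computed without the padding positions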

@zhujl1991 (Author)

Thanks. I just became aware of the ignore_index parameter: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
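
For reference, a small sketch of what ignore_index does (shapes and values are illustrative): positions labeled -100 are dropped from both the sum and the mean.

import torch

loss_fct = torch.nn.CrossEntropyLoss()       # ignore_index=-100 by default
logits = torch.randn(4, 10)                  # 4 token positions, vocabulary of 10
labels = torch.tensor([3, 7, -100, -100])    # last two positions are "padding"

full = loss_fct(logits, labels)              # mean over the 2 non-ignored positions
manual = loss_fct(logits[:2], labels[:2])    # same value, ignored rows dropped by hand
print(torch.allclose(full, manual))          # True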


ttanida commented Jul 2, 2022

For fine-tuning the GPT-2 model, it's necessary to manually prepend the bos_token and append the eos_token to the input, as has been established in #3311.

Setting pad_token = eos_token and running labels[labels == pad_token_id] = -100 would therefore be a problem in my opinion, since we would not only ignore the padding tokens but also the eos_tokens at the end of sentences when computing the loss.

I solved the problem by first converting the attention_mask to boolean values and then inverting it. With labels[inv_bool_attention_mask] = -100, padding tokens are ignored but eos_tokens are not.
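
A minimal sketch of that masking step (the token ids are taken from the example below, truncated for brevity; variable names are illustrative):

import torch

input_ids = torch.tensor([[15496, 2159, 0, 50256, 50256, 50256]])  # "Hello World!" + eos + 2 pad tokens
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])                # the real eos is attended to, the pads are not

labels = input_ids.clone()
labels[~attention_mask.bool()] = -100  # ignore padding in the loss, keep the real eos
print(labels)  # tensor([[15496,  2159,     0, 50256,  -100,  -100]])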

@itsnamgyu

Just to save some folks the hassle:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer.pad_token = tokenizer.eos_token

print("EOS", tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
print("PAD", tokenizer.convert_tokens_to_ids(tokenizer.pad_token))

string = "Hello World!"
string += tokenizer.eos_token  # manually append eos since this is not done by GPT2Tokenizer
# string = tokenizer.bos_token + string  # optionally prepend bos (which is actually the same as eos for GPT2Tokenizer)

tokenized = tokenizer(string, padding="max_length", max_length=10, return_tensors="pt")
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]

print("INPUT_IDS BEFORE")
print(input_ids)
print("ATTENTION_MASK")
print(attention_mask)

input_ids[~attention_mask.bool()] = -100  # disable loss for padding tokens (i.e., eos tokens meant for padding)

print("INPUT_IDS AFTER")
print(input_ids)

Result:

EOS 50256
PAD 50256
INPUT_IDS BEFORE
tensor([[15496,  2159,     0, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])
ATTENTION_MASK
tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])
INPUT_IDS AFTER
tensor([[15496,  2159,     0, 50256,  -100,  -100,  -100,  -100,  -100,  -100]])
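
A hedged follow-up sketch (not part of the original comment): in practice you would usually keep input_ids intact and apply the mask to a separate labels tensor, since -100 is not a valid input token id; the model then computes the padding-aware loss itself. Continuing from the code above:

# Re-tokenize so input_ids are unmodified, then mask a copy used as labels
tokenized = tokenizer(string, padding="max_length", max_length=10, return_tensors="pt")
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]

labels = input_ids.clone()
labels[~attention_mask.bool()] = -100  # loss is skipped on padding positions only

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs.loss)  # cross-entropy over the non-padding positions (shifted internally)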
