GemmaForCausalLM Causal Masking Not Working #30813

Closed
2 of 4 tasks
cmathw opened this issue May 14, 2024 · 4 comments
Comments

@cmathw
Contributor

cmathw commented May 14, 2024

System Info

  • transformers version: 4.35.2
  • Platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
  • Python version: 3.11.7
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): 2.15.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes.
  • Using distributed or parallel set-up in script?: No.

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# Imports
from transformers import AutoTokenizer, AutoModelForCausalLM

text = "Hello World! This is a test string."

# Gemma
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
gemma_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
gemma_tokens = gemma_tokenizer(text, return_tensors="pt")
gemma_output_dict = gemma_model(**gemma_tokens, output_attentions=True, output_hidden_states=True)
first_gemma_attn_pattern = gemma_output_dict['attentions'][0][0, 0, 0]
print(first_gemma_attn_pattern)

# GPT2
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_tokens = gpt2_tokenizer(text, return_tensors="pt")
gpt2_output_dict = gpt2_model(**gpt2_tokens, output_attentions=True, output_hidden_states=True)
first_gpt2_attn_pattern = gpt2_output_dict['attentions'][0][0, 0, 0]
print(first_gpt2_attn_pattern)

first_gemma_attn_pattern outputs:

tensor([0.1393, 0.1916, 0.0398, 0.1050, 0.0786, 0.0850, 0.0610, 0.1118, 0.1076,
        0.0803], grad_fn=<SelectBackward0>)

first_gpt2_attn_pattern outputs:

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward0>)

Expected behavior

I would expect the first row of the attention pattern to be tensor([1., 0., 0., ..., 0.]) for both models, since with causal masking the first token can only attend to itself. This is the case for GPT-2, but not for Gemma, where the causal mask does not appear to be applied; a quick check is sketched below.
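
For reference, a minimal check (not part of the original report, reusing gemma_output_dict from the reproduction script above): with the causal mask applied, every attention matrix should be lower triangular, so zeroing its strict upper triangle should leave it unchanged.

import torch

# Sketch: check that each layer's attention weights are lower triangular,
# i.e. that the causal mask was applied. Each tensor has shape
# (batch, heads, seq_len, seq_len); torch.tril acts on the last two dims.
for layer_idx, attn in enumerate(gemma_output_dict["attentions"]):
    is_causal = torch.allclose(attn, torch.tril(attn))
    print(f"layer {layer_idx}: lower triangular = {is_causal}")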

@marthaflinderslewis

I also have this problem with Llama; a minimal working example (MWE) is below:

from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

model_inputs = tokenizer.encode("Cats chase dogs", return_tensors="pt").to("cuda:0")

output = model.generate(model_inputs, output_attentions=True, max_new_tokens=5, return_dict_in_generate=True)

print(output.attentions[0][0][0][0]) # prints attention weights of first head in first layer

output is:

tensor([[0.3464, 0.1967, 0.0438, 0.1245, 0.0281, 0.2606],
[0.3714, 0.3745, 0.0209, 0.1465, 0.0157, 0.0710],
[0.1448, 0.4541, 0.0274, 0.2500, 0.0149, 0.1087],
[0.2242, 0.3160, 0.0371, 0.2656, 0.0157, 0.1414],
[0.1242, 0.2456, 0.0464, 0.3104, 0.0118, 0.2615],
[0.1509, 0.1745, 0.0558, 0.2522, 0.0131, 0.3535]], device='cuda:0')

I would have expected this output to be lower triangular. The forum thread below asks the same question, and a check along the lines of the one above is sketched after the link.

https://discuss.huggingface.co/t/why-are-llama2-attention-weights-not-lower-triangular/85066
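
A sketch of the same triangularity check for the generate() example above (assuming the output object from that snippet): output.attentions[0] holds one tensor per layer for the prompt forward pass, each of shape (batch, heads, seq_len, seq_len).

import torch

# Sketch: the attention weights returned for the prompt forward pass
# (generation step 0) should be lower triangular in every layer.
prompt_attentions = output.attentions[0]  # tuple with one tensor per layer
for layer_idx, attn in enumerate(prompt_attentions):
    is_causal = torch.allclose(attn, torch.tril(attn))
    print(f"layer {layer_idx}: lower triangular = {is_causal}")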

@ArthurZucker
Collaborator

Hey! This was fixed by #30652!

@ArthurZucker
Collaborator

Use attn_implementation="eager" when you call from_pretrained.
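
A sketch of the suggested workaround, reusing the Gemma setup from the original report (this assumes a transformers version recent enough to accept the attn_implementation argument):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: request the eager attention implementation so that
# output_attentions returns the masked (lower-triangular) weights.
gemma_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", attn_implementation="eager"
)
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
tokens = gemma_tokenizer("Hello World! This is a test string.", return_tensors="pt")
out = gemma_model(**tokens, output_attentions=True)
print(out.attentions[0][0, 0, 0])  # first row should now be tensor([1., 0., 0., ...])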

@cmathw
Contributor Author

cmathw commented May 15, 2024

Awesome thank you!

@cmathw cmathw closed this as completed May 15, 2024