NaN in GPT NeoX model (generation) #17452

Closed
thies1006 opened this issue May 27, 2022 · 9 comments

@thies1006

System Info

- `transformers` version: 4.20.0.dev0
- Platform: Linux-4.15.0-140-generic-x86_64-with-glibc2.17
- Python version: 3.8.13
- Huggingface_hub version: 0.7.0
- PyTorch version (GPU?): 1.11.0+cu113 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@LysandreJik (NeoX)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Script to run:

from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

weights_path = "EleutherAI/gpt-neox-20b"

# Instantiate the model on the meta device so the 20B parameters are not
# materialized in memory before the checkpoint is dispatched.
config = AutoConfig.from_pretrained(weights_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

tokenizer = AutoTokenizer.from_pretrained(weights_path)

# Shard the model across the available devices, keeping each GPTNeoXLayer whole.
device_map = infer_auto_device_map(model, no_split_module_classes=["GPTNeoXLayer"])
load_checkpoint_and_dispatch(
    model,
    weights_path,
    device_map=device_map,
    offload_folder=None,
    offload_state_dict=True,
)

prompt = "Huggingface is"
input_tokenized = tokenizer(prompt, return_tensors="pt")
output = model.generate(input_tokenized["input_ids"].to(0), do_sample=True)
output_text = tokenizer.decode(output[0].tolist())

The script crashes with the following traceback:

Traceback (most recent call last):
  File "run.py", line 24, in <module>
    output = model.generate(input_tokenized["input_ids"].to(0), do_sample=True)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/generation_utils.py", line 1316, in generate
    return self.sample(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/generation_utils.py", line 1934, in sample
    outputs = self(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/hooks.py", line 148, in new_forward
    output = old_forward(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 596, in forward
    outputs = self.gpt_neox(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 488, in forward
    outputs = layer(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/hooks.py", line 148, in new_forward
    output = old_forward(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 296, in forward
    attention_layer_outputs = self.attention(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 148, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 208, in _attn
    raise RuntimeError()
RuntimeError

The problem was also mentioned here:
#15642 (comment)
#15642 (comment)
The problem seems to be that torch.einsum returns inf (fp16 overflow), which leads to NaN when calculating the softmax.
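
For illustration, a minimal sketch of the failure mode with made-up values (in the model, the oversized values come from the query/key product):

import torch

# float16 tops out around 65504: larger magnitudes overflow to inf ...
overflowed = torch.tensor([70000.0]).to(torch.float16)
print(overflowed)  # tensor([inf], dtype=torch.float16)

# ... and a softmax over a row containing inf is NaN everywhere, because
# inf - inf = nan inside the normalization.
scores = torch.tensor([[float("inf"), 1.0]])
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan]])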

Expected behavior

Code should run.
@thies1006 thies1006 added the bug label May 27, 2022
@thies1006 thies1006 changed the title NaN in GPT NeoX model NaN in GPT NeoX model (generation) May 27, 2022
@LysandreJik
Member

LysandreJik commented May 31, 2022

cc @zphang do you have an idea of what might be happening there?

Also happening here: huggingface/accelerate#404

@LysandreJik
Member

Also cc @sgugger as it's leveraging the auto map.

@sgugger
Collaborator

sgugger commented May 31, 2022

I don't think it comes from the auto map; some weights are NaN and so an error is raised here. (I did say that RuntimeErrors with no messages were not ideal, @zphang ;-) )

@zomux

zomux commented Jun 1, 2022

This seems to be a float16 overflow problem after applying torch.einsum, fixable with:

  attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
+ finfo = torch.finfo(attn_scores.dtype)
+ attn_scores = attn_scores.clamp(finfo.min, finfo.max)

Although after fixing this NaN problem, generation is still not working correctly.
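
Outside the model, the effect of the proposed clamp can be sketched like this (illustrative values; attn_scores stands in for the einsum output):

import torch

def clamp_to_finite(attn_scores):
    # Replace +/-inf left by an overflow with the largest finite values the
    # dtype can represent, so the subsequent softmax stays finite.
    finfo = torch.finfo(attn_scores.dtype)
    return attn_scores.clamp(finfo.min, finfo.max)

scores = torch.tensor([[float("inf"), 1.0]])            # float32 for portability
print(torch.softmax(scores, dim=-1))                    # tensor([[nan, nan]])
print(torch.softmax(clamp_to_finite(scores), dim=-1))   # tensor([[1., 0.]])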

@thies1006
Author

Thanks @zomux, this works. I tried something similar (casting to FP32 and back), which gave the same error:

Traceback (most recent call last):
  File "tt.py", line 24, in <module>
    output = model.generate(input_tokenized["input_ids"].to(0), do_sample=True)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/generation_utils.py", line 1317, in generate
    return self.sample(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/generation_utils.py", line 1937, in sample
    outputs = self(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1112, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/hooks.py", line 150, in new_forward
    output = old_forward(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 602, in forward
    outputs = self.gpt_neox(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1112, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 493, in forward
    outputs = layer(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1112, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/hooks.py", line 150, in new_forward
    output = old_forward(*args, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 299, in forward
    attention_layer_outputs = self.attention(
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1112, in _call_impl
    return forward_call(*input, **kwargs)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 150, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/secondary/thies/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 218, in _attn
    attn_output = torch.matmul(attn_weights, value)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [64, 5] but got: [64, 1]

I figured out that when setting use_cache=False, generation runs without errors.

Change:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L642

- return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past, "use_cache": False}

@zphang
Contributor

zphang commented Jun 5, 2022

I don't think it comes from the auto map, some weights are Nan and so an error is raised here. (I did say that RuntimeError with no messages were not ideal @zphang ;-) )

Ah, I thought I'd removed all the RuntimeErrors; I've submitted a PR here: #17563

As for the NaN-ing, I've not figured that out either. I found (in very ad-hoc testing) that the einsum appears to be more stable than the original approach, but it looks like it hasn't fully solved the issue.

@thies1006
Author

Hello @zphang,
the problem is that by the time the scaling factor is applied, the overflow has already happened.
Therefore I think self.norm_factor should go into the matrix multiply (scale first, then do the matrix multiply):

- attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
+ attn_scores = torch.einsum("bik,bjk->bij", query / self.norm_factor, key)

It seems to work fine:

Huggingface is a fast-growing Chinese company that is the leader in deep learning and natural language processing.

The company raised $75million from Tencent and Baidu.

Deep Voice is a Chinese startup that enables users to control various home appliances using simple commands issued with their voice alone.
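
A rough numeric illustration of why the order matters (made-up magnitudes; the broadcasted product below mirrors the einsum over a head size of 64, written so it also runs in float16 on CPU):

import torch

norm_factor = 8.0  # sqrt(head_size) for head_size = 64
query = torch.full((1, 4, 64), 30.0, dtype=torch.float16)
key = torch.full((1, 4, 64), 40.0, dtype=torch.float16)

# Dividing after the product is too late: 30 * 40 * 64 = 76800 already
# exceeds the float16 maximum (~65504), so the intermediate is inf, and
# inf / norm_factor is still inf.
late = (query.unsqueeze(2) * key.unsqueeze(1)).sum(-1) / norm_factor
print(torch.isinf(late).any())  # tensor(True)

# Dividing the query first keeps every intermediate in range:
# (30 / 8) * 40 * 64 = 9600.
early = ((query / norm_factor).unsqueeze(2) * key.unsqueeze(1)).sum(-1)
print(early.max())  # ~9600, comfortably below the float16 maximum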

@hannan72

hannan72 commented Jun 9, 2022

@zphang Is there a release version in which this issue is resolved?
I get RuntimeError: probability tensor contains either inf, nan or element < 0 when deploying the GPT-NeoX model in half precision.
I can deploy it without problems in full precision.

@github-actions

github-actions bot commented Jul 3, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
