
loss.backward() producing nan values with 8-bit Llama-3-70B-Instruct #30526

Open
haroldtimmers opened this issue Apr 28, 2024 · 2 comments

@haroldtimmers
System Info

  • transformers version: 4.40.1
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.31
  • Python version: 3.11.7
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True (NVIDIA A100)
  • Using distributed or parallel set-up in script?: False
  • bitsandbytes version: 0.43.1

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Report the forward-pass op that produces the first nan seen during backward
torch.autograd.set_detect_anomaly(True)

hf_token = ""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct", trust_remote_code=True, token=hf_token)

# Load the model with 8-bit weight quantization via bitsandbytes
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct",
                                             trust_remote_code=True,
                                             token=hf_token,
                                             device_map="auto",
                                             max_memory={0: '80GIB'},
                                             quantization_config=bnb_config,
                                             low_cpu_mem_usage=True)

# Pre-tokenized Llama-3 chat prompt (128000/128006/128007/128009 are special tokens)
inputs = torch.tensor([[128000, 128006,   9125, 128007,    271,  16533,    279,   2768,   3488,
            449,   1193,    264,   3254,  12360,    596,    836,    323,    912,
           5217,   1495,     13, 128009, 128006,    882, 128007,    271,    678,
            459,  12360,    304,    279,   5818,  19574,   1369,   1147,    320,
           2550,     15,    570,  22559,    449,   1193,    264,   3254,  12360,
            596,    836,    323,    912,   5217,   1495,     13, 128009, 128006,
          78191, 128007,    271,   4873,     75,    783,    473,    478]]).long().to(model.device)

logits = model(inputs)['logits']
probs = torch.nn.functional.softmax(logits, dim=-1)
loss = -torch.log(probs[0, -1, 263])  # negative log-probability of token id 263 at the last position
loss.backward()

This produces the following error message and stack trace:

/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue
    await self.process_one()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 531, in process_one
    await dispatch(*args)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 775, in execute_request
    reply_content = await reply_content
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
    res = shell.run_cell(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell
    result = self._run_cell(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell
    result = runner(coro)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1296557/2560056571.py", line 34, in <module>
    logits = model(inputs)['logits']
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1208, in forward
    outputs = self.model(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1018, in forward
    layer_outputs = decoder_layer(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 738, in forward
    hidden_states = self.input_layernorm(hidden_states)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 89, in forward
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 38
     36 print(probs[0, -1, 263])
     37 loss = -torch.log(probs[0, -1, 263])
---> 38 loss.backward()

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

Expected behavior

No error thrown; loss.backward() should compute finite gradients.

@ArthurZucker
Collaborator

Hey! I think this is related to the model itself, which did not really train these tokens properly. I saw a few threads saying the init of these tokens should be updated. @younesbelkada, can you have a look in case it's rather quantization-related?
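
For context, one quick way to probe that hypothesis is to compare the embedding norms of the special tokens in the prompt against the rest of the vocabulary; untrained rows often stand out. This is a hypothetical diagnostic, not something from this thread; the token ids are the Llama-3 special tokens from the reproduction above:

# Hypothetical check (not from the thread): with load_in_8bit the embedding
# table stays un-quantized, so its rows can be inspected directly.
special_ids = [128000, 128006, 128007, 128009]  # Llama-3 chat special tokens
emb = model.get_input_embeddings().weight.detach().float()
print("special-token norms:", emb[special_ids].norm(dim=-1).tolist())
print("mean norm over vocab:", emb.norm(dim=-1).mean().item())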

@younesbelkada
Contributor

Hi @haroldtimmers!
Thanks for the issue! This is expected: you can't directly train a quantized model like that. You need to either attach new trainable adapters such as PEFT modules (e.g. LoRA; see https://huggingface.co/docs/peft/index), or make only the non-8-bit layers (such as the LM head) trainable and freeze the rest of the model. Sketches of both options follow.
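
As an illustration of the first option, here is a minimal sketch assuming the peft package is installed; the LoRA hyperparameters (r, lora_alpha, target_modules) are placeholder choices, not values recommended in this thread:

# Sketch: attach LoRA adapters so backward() only computes gradients for the
# small floating-point adapter weights, while the int8 base weights stay frozen.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # freeze base weights, enable input grads
lora_config = LoraConfig(
    r=16,                        # placeholder rank
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

logits = model(inputs)['logits']
loss = -torch.log_softmax(logits, dim=-1)[0, -1, 263]  # log_softmax is the numerically stabler form
loss.backward()  # gradients land on the LoRA parameters only

And a sketch of the second option, assuming the default bitsandbytes behavior of leaving lm_head un-quantized:

# Sketch: freeze everything except the (non-8-bit) LM head.
for name, param in model.named_parameters():
    param.requires_grad = "lm_head" in name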
