When training RWKV, it reports "backward error" #31413

Open
2 of 4 tasks
lxianl455 opened this issue Jun 14, 2024 · 6 comments

Comments

@lxianl455

lxianl455 commented Jun 14, 2024

System Info

  • transformers version: 4.41.2
  • Platform: Linux-4.14.105-1-tlinux3-0013-x86_64-with-glibc2.2.5
  • Python version: 3.8.12
  • Huggingface_hub version: 0.23.3
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

Using /root/.cache/torch_extensions/py38_cu118 as PyTorch extensions root...
Loading extension module wkv_20...
/usr/local/python/lib/python3.8/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in RwkvLinearAttentionBackward. Traceback of forward call that caused the error:
File "/data1/rl_server/rl_learner/code//train.py", line 18, in
main()
File "/data1/rl_server/rl_learner/code//train.py", line 14, in main
trainer.run()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/init.py", line 14, in run
self.bench.run()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 236, in run
self._do_train()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 210, in _do_train
self.do_train_step(step_context, _input_datas)
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 172, in do_train_step
outputs = self.net_wrapper(_input_datas, self.local_step)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 338, in forward
each_hero_fc_result_list, all_lstm_state = self._inference(each_hero_data_list, lstm_initial_state, pos_lstm_initial_state)
File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 391, in _inference
lstm_outputs, lstm_state = self.public_lstm(reshape_new_fc_public_results, lstm_initial_state)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 261, in forward
lstm_outputs, rwkv_state = self.lstm(reshape_new_fc_public_results, ) # not passing state for now
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 573, in forward
hidden_states, state = block(
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 402, in forward
attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 323, in forward
rwkv, layer_state = rwkv_linear_attention(
File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 260, in rwkv_linear_attention
return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/data1/rl_server/rl_learner/code//train.py", line 18, in
main()
File "/data1/rl_server/rl_learner/code//train.py", line 14, in main
trainer.run()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/init.py", line 14, in run
self.bench.run()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 236, in run
self._do_train()
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 210, in _do_train
self.do_train_step(step_context, _input_datas)
File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 174, in do_train_step
total_loss.backward()
File "/usr/local/python/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/init.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoTokenizer, RwkvForCausalLM
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.autograd.set_detect_anomaly(True)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Hello, my dog is cute","i like this"], return_tensors="pt",padding=True)
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits
loss.backward()

Expected behavior

The backward pass succeeds.
Figure out why this happens and fix it.

@amyeroberts
Collaborator

cc @ArthurZucker

@RUFFY-369

Hi @lxianl455, please put your model in train mode with model.train() before performing backprop:

model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
.
.
.
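
For reference, the full snippet from the report with this one change applied would be (same checkpoint and inputs as above):

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()  # put the model in training mode before computing gradients

inputs = tokenizer(["Hello, my dog is cute", "i like this"], return_tensors="pt", padding=True)
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()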

Cheers!

@lxianl455
Author

Hi @lxianl455, please put your model in train mode with model.train() before performing backprop:

model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
.
.
.

Cheers!

Actually, I want to combine RWKV blocks with some other modules to predict time-series information. I am not using the whole RWKV model, only its blocks. In this scenario, the Transformers Trainer cannot be used (see the sketch below for the kind of setup I mean). How can I solve this backward error?
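
A rough sketch of what I mean (TimeSeriesRwkv, proj_in and the sizes are made up for illustration, and I use RwkvModel as the backbone here for brevity rather than the individual blocks):

import torch
import torch.nn as nn
from transformers import RwkvConfig, RwkvModel

class TimeSeriesRwkv(nn.Module):
    # Toy wrapper: an RWKV backbone plus a regression head, trained with a plain PyTorch loop.
    def __init__(self, feature_dim, hidden_size=256, num_layers=4):
        super().__init__()
        config = RwkvConfig(hidden_size=hidden_size, num_hidden_layers=num_layers)
        self.proj_in = nn.Linear(feature_dim, hidden_size)
        self.backbone = RwkvModel(config)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, features):  # features: (batch, seq_len, feature_dim)
        hidden = self.proj_in(features)
        # use_cache=False: no state is returned, backward stays on the plain autograd path
        out = self.backbone(inputs_embeds=hidden, use_cache=False)
        return self.head(out.last_hidden_state)  # (batch, seq_len, 1)

model = TimeSeriesRwkv(feature_dim=8).train()
features = torch.randn(2, 16, 8)
targets = torch.randn(2, 16, 1)
loss = nn.functional.mse_loss(model(features), targets)
loss.backward()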

@RUFFY-369

RUFFY-369 commented Jun 15, 2024

@lxianl455 Okay, so if you want to use the RWKV model just for inference or in the default eval() mode and don't want to put it in train mode, then modify your code as follows and the error will go away:

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
.
.
.

Set use_cache to False: with caching enabled, the model stores intermediate results to speed up computation, which can interfere with gradient computation, and use_cache is not meant for training or gradient computation anyway.
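
So your reproduction becomes (same checkpoint and inputs, only use_cache changed):

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)  # cache disabled
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Hello, my dog is cute", "i like this"], return_tensors="pt", padding=True)
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()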

Cheers!

@lxianl455
Author

lxianl455 commented Jun 16, 2024

Actually, what I want is for it to work like an LSTM. During training, an LSTM can take an initial state and can also return the final state afterwards. https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

I think use_cache in the RWKV code is not like use_cache in other models. What it actually does is ask for the LSTM-like cell state and hidden state to be returned. Here is the code (line 300 in modeling_rwkv.py):

rwkv, layer_state = rwkv_linear_attention(
    self.time_decay,
    self.time_first,
    key,
    value,
    state=layer_state,
    return_state=use_cache,
)
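
For comparison, this is roughly the usage pattern I would like from the public API (a sketch only; chunk_a and chunk_b stand for two consecutive tokenized chunks, and whether gradients can flow back through the returned state is exactly where the error above appears):

# state in, state out, like an LSTM's (h_n, c_n)
out_a = model(**chunk_a, use_cache=True)                # state=None -> zero-initialized state
state = out_a.state                                     # ending state after chunk_a
out_b = model(**chunk_b, state=state, use_cache=True)   # continue from the previous ending state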

@ArthurZucker
Collaborator

The RecurrentGemma model implements something more along the lines of an RNN (so close to an LSTM), if you are looking for an equivalent.
