
[BUG] Illegal memory access CUDA error when using long sequences #2062

Closed
tomeras91 opened this issue Jun 28, 2022 · 9 comments
Assignees: cmikeh2
Labels: bug, inference

@tomeras91

Describe the bug
Running a forward pass on a DeepSpeedTransformerInference layer, with a sequence length of ~1000 tokens, results in an illegal memory access CUDA error.

To Reproduce
Here is a minimal reproducible example that shows the bug:

from deepspeed.ops.transformer import DeepSpeedInferenceConfig, DeepSpeedTransformerInference
import torch

torch.cuda.set_device(0)

hidden_size = 256
heads = 8
num_layers = 12
fp16 = True
layernorm_epsilon = 1e-5
deepspeed_config = DeepSpeedInferenceConfig(hidden_size=hidden_size,
                                            intermediate_size=hidden_size * 4,
                                            heads=heads,
                                            num_hidden_layers=num_layers,
                                            layer_norm_eps=layernorm_epsilon,
                                            # encoder_decoder=False,
                                            fp16=fp16,
                                            pre_layer_norm=True,
                                            stochastic_mode=False,
                                            scale_attention=True,
                                            triangular_masking=True,
                                            local_attention=False,
                                            window_size=256,
                                            )
transformer = DeepSpeedTransformerInference(config=deepspeed_config)
transformer.half()
# Fill every parameter with a small constant so the layer has deterministic weights
new_state_dict = {k: 0.01 * torch.ones(*v.shape, dtype=v.dtype, device=v.device)
                  for k, v in transformer.state_dict().items()}
transformer.load_state_dict(new_state_dict)
transformer.cuda()
device = list(transformer.parameters())[0].device

batch_size = 1
seq_len = 1000  # a long sequence; this is what triggers the error
inputs = torch.ones((batch_size, seq_len, hidden_size), dtype=torch.float16, device=device)
input_mask = torch.ones(*inputs.shape[:2], dtype=torch.bool, device=device)

output, _ = transformer(
    input=inputs,
    input_mask=input_mask)

print(f"outupt: \n {output}")

Running the code resulted in the following exception:

RuntimeError: CUDA error: an illegal memory access was encountered
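
For what it's worth, a small sweep over sequence lengths can help narrow down the first failing length. This is only a diagnostic sketch: it reuses transformer, hidden_size, and device from the snippet above, and it has to stop at the first failure because the CUDA context is unusable after an illegal memory access.

for seq_len in range(100, 1101, 100):
    transformer.layer_past = None  # drop cached key/values between calls (the module stores them in self.layer_past)
    x = torch.ones((1, seq_len, hidden_size), dtype=torch.float16, device=device)
    mask = torch.ones(1, seq_len, dtype=torch.bool, device=device)
    try:
        transformer(input=x, input_mask=mask)
        torch.cuda.synchronize()  # force any asynchronous CUDA error to surface here
        print(f"seq_len={seq_len}: ok")
    except RuntimeError as err:
        print(f"seq_len={seq_len}: failed ({err})")
        break  # the CUDA context is poisoned beyond this point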

Expected behavior
I expected to get a correct output, without the exception.

ds_report output

[2022-06-28 10:35:33,425] [WARNING] [partition_parameters.py:60:<module>] unable to find torch.distributed._all_gather_base. will fall back to torch.distributed.all_gather which will result in suboptimal performance. please consider upgrading your pytorch installation.
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  using untested triton version (1.1.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.8.0a0+1606899
torch cuda version ............... 11.1
torch hip version ................ None
nvcc version ..................... 11.1
deepspeed install path ........... ['/opt/conda/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.6.5, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.8, cuda 11.1

System info (please complete the following information):

  • OS: Ubuntu 20.04
  • GPU count and types: a single A100 GPU
  • Python version: 3.8.5

Launcher context
Launching directly using Python interpreter.

Additional context
Maybe the bug is related to line 20 in csrc/transformer/inference/includes/custom_cuda_layers.h? It reads:

#define MAX_OUT_TOKES 1024
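
One way to rule that cap in or out (just a hedged experiment, not a proposed fix) would be to bump the define in a source checkout, rebuild the inference op, and rerun the snippet above. Something along these lines, assuming a standard source install and the DS_BUILD_TRANSFORMER_INFERENCE prebuild flag:

git clone https://github.com/microsoft/DeepSpeed.git
cd DeepSpeed
# Raise the suspected token cap before building the inference kernels
sed -i 's/#define MAX_OUT_TOKES 1024/#define MAX_OUT_TOKES 4096/' \
    csrc/transformer/inference/includes/custom_cuda_layers.h
DS_BUILD_TRANSFORMER_INFERENCE=1 pip install .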
tomeras91 added the bug label on Jun 28, 2022
@mrwyattii
Contributor

@tomeras91 I can confirm that I'm able to reproduce this error. I don't think it has anything to do with MAX_OUT_TOKES. @RezaYazdaniAminabadi could you take a look at this?

@RezaYazdaniAminabadi
Contributor

Hi @tomeras91

Thanks for reporting this issue. I will look into this.
@mrwyattii, thanks for reproducing this. Yes, I think the issue is probably somewhere else.
Thanks,
Reza

cmikeh2 self-assigned this on Jul 30, 2022
@trianxy

trianxy commented Aug 19, 2022

Below is a possibly related bug. I added some sample code to reproduce this error for a GPT2 model on an NVIDIA A10G. Let me know @RezaYazdaniAminabadi @cmikeh2 if you think I should file a new issue instead.

Describe the bug

After initialising a GPT2 model from Huggingface with DeepSpeed, I can run inference on short sequences. But with long sequences of e.g. 700 tokens, I get multiple warnings of the form !!!! kernel execution error. (m: 700, n: 700, k: 64, error: 14), which culminate in a RuntimeError: CUDA error: an illegal memory access was encountered. I thus cannot run inference with this model.

Important: I tested older versions and found that I do not encounter this problem with deepspeed versions <= 0.6.6.
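
If pinning to an older release is acceptable, that observation suggests a temporary workaround along these lines (only checked against the setup described below):

pip install "deepspeed==0.6.6"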

To Reproduce
Steps to reproduce the behavior:

  1. Install packages
pip install --upgrade pip
pip uninstall -y torch deepspeed transformers
pip install --upgrade torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade deepspeed==0.7.0 transformers==4.21.1
  2. Run
import os
import deepspeed
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = deepspeed.init_inference(model, replace_method='auto', replace_with_kernel_inject=True)

short_input = tokenizer("Hello, my dog is cute." * 1, return_tensors="pt").to("cuda")
long_input = tokenizer("Hello, my dog is cute." * 100, return_tensors="pt").to("cuda")
outputs = model(**short_input)  # this works fine
outputs = model(**long_input)  # this throws below error
  3. See error
!!!! kernel execution error. (batch: 12, m: 700, n: 700, k: 64, error: 14) 
!!!! kernel execution error. (batch: 12, m: 64, n: 700, k: 700, error: 14) 
!!!! kernel execution error. (m: 768, n: 700, k: 768, error: 14) 
!!!! kernel execution error. (m: 3072, n: 700, k: 768, error: 14) 
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_21875/4257510811.py in <cell line: 1>()
----> 1 outputs = model(**long_input)
      2 token_id = torch.argmax(outputs.logits.squeeze()[-1]).item()
      3 print(tokenizer.decode(token_id), outputs.logits.mean())

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/deepspeed/inference/engine.py in forward(self, *inputs, **kwargs)
    528                     outputs = self._graph_replay(*inputs, **kwargs)
    529             else:
--> 530                 outputs = self.module(*inputs, **kwargs)
    531             #outputs = self.module(*inputs, **kwargs)
    532         return outputs

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1146             input = bw_hook.setup_input_hook(input)
   1147 
-> 1148         result = forward_call(*input, **kwargs)
   1149         if _global_forward_hooks or self._forward_hooks:
   1150             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py in forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1056         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1057 
-> 1058         transformer_outputs = self.transformer(
   1059             input_ids,
   1060             past_key_values=past_key_values,

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py in forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)
    899                 )
    900             else:
--> 901                 outputs = block(
    902                     hidden_states,
    903                     layer_past=layer_past,

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py in forward(self, input, input_mask, attention_mask, head_mask, layer_past, get_key_value, get_present, encoder_output, enc_dec_attn_mask, encoder_hidden_states, encoder_attention_mask, use_cache, alibi, output_attentions)
    840             presents = (key, value)
    841             self.layer_past = presents if layer_past is None else None
--> 842             output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
    843 
    844             if not self.config.pre_layer_norm:

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py in forward(self, input, residual, residual_norm, bias)
    710 
    711     def forward(self, input, residual, residual_norm, bias):
--> 712         return DeepSpeedMLPFunction.apply(input,
    713                                           residual,
    714                                           residual_norm,

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py in forward(ctx, input, residual, residual_norm, bias, inter_w, inter_b, attn_nw, attn_nb, config, mp_group, output_b, output_w, q_scales, q_groups, merge_count, mlp_gemm_func, fused_gemm_gelu, vector_matmul_func, bias_residual_func)
    631                                              config.pre_layer_norm,
    632                                              config.mlp_after_attn)
--> 633                 output = vector_matmul_func(intermediate, output_w, False)
    634 
    635         inference_cuda_module.residual_add(

RuntimeError: CUDA error: an illegal memory access was encountered

Expected behavior
I did not expect an error to be thrown; rather, I expected the variable outputs to be filled.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch']
torch version .................... 1.12.1+cu116
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.1
deepspeed install path ........... ['/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.7.0, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.6

System info (please complete the following information):

  • OS: Debian (also Amazon Linux 2)
  • GPU count and types: 1 GPU A10G (AWS g5.xlarge notebook instance)
  • Python version: 3.8.12 (also with 3.9.12)

Launcher context
Inside a Python notebook

@mallorbc

mallorbc commented Sep 6, 2022

I have also encountered this error. Small inputs such as the tutorial's "DeepSpeed is" prompt give normal results, but significantly longer inputs lead to an illegal memory access error.

I would try version 0.6.6 or earlier as suggested by @trianxy, but I want to use long sequences with GPT-J and GPT-Neo 2.7B, and those had issues until recently, as can be seen in #2233.

My build is the same as the one in that issue, just with DeepSpeed built from source shortly after the PR that fixed the issue.

@trianxy

trianxy commented Sep 14, 2022

FYI @mallorbc , @tomeras91 , @RezaYazdaniAminabadi :

My related issue which I detailed above is fixed in this PR. More precisely, my issue does not appear when I install the commit 4abd455

Thanks for that fix!
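
For anyone else hitting this, one way to install that commit from source (assuming the microsoft/DeepSpeed repository) is:

git clone https://github.com/microsoft/DeepSpeed.git
cd DeepSpeed
git checkout 4abd455
pip install .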

@mallorbc

> FYI @mallorbc , @tomeras91 , @RezaYazdaniAminabadi :
>
> My related issue which I detailed above is fixed in this PR. More precisely, my issue does not appear when I install the commit 4abd455
>
> Thanks for that fix!

If I recall correctly, I also tried building from that PR and had issues with poor outputs for GPT-Neo and GPT-J. I believe one of the branches I built fixed the memory error but still gave garbage output for long inputs; perhaps this is the one.

I may be remembering wrong, though. I will try this again later and see whether it fixes anything, but I think I already tried this.

Thanks!

@trianxy

trianxy commented Sep 16, 2022

Alright @mallorbc - let me know if you need any support with testing.

It's true that I have seen inconsistent behavior when trying different GPT architectures with different inputs, so it may be that not all cases have been fixed by the mentioned PR. I did not run many test cases with different architectures and inputs.

@EdouardVilain-Git

EdouardVilain-Git commented Oct 7, 2022

Hi everyone,
I am encountering the same issues with a RoBERTa-type model on which I ran 8-bit MoQ during training.
When instantiating the inference engine using:
engine = deepspeed.init_inference(deepspeed_trainer.model, dtype=torch.int8, quantization_setting=(False, 64), replace_with_kernel_inject=True)
I get the same error as the one mentioned by @tomeras91 and @trianxy. Do you know when this issue might be fixed?
This also raises a question I had: is it possible to train with DeepSpeed (typically for QAT) and run inference with plain torch?
I would like to use the quantized weights of my model trained with DeepSpeed to create a quantized torch instance for inference. It is still unclear to me if/how I could do so.

Thank you for your help!

cmikeh2 closed this as completed on Nov 4, 2022
@carlose2108

Hi @trianxy!
Were you able to solve this issue?
