
Control flow issue with symbolic_trace when using inputs_embeds in MistralForCausalLM #31200

Closed
2 of 4 tasks
Hongjie1Chu opened this issue Jun 3, 2024 · 5 comments · Fixed by #31574

Comments

@Hongjie1Chu

Hongjie1Chu commented Jun 3, 2024

System Info

  • transformers version: 4.41.2
  • Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @younesbelkada @Narsil

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# from model_mistral import MistralForCausalLM
from transformers import MistralForCausalLM, MistralConfig
import torch
from transformers.utils.fx import symbolic_trace


cfg = MistralConfig.from_pretrained('mistralai/Mistral-7B-v0.1')
cfg.num_hidden_layers = 2
model = MistralForCausalLM(cfg)
print(model)
batch_size = 1
sequence_length = 10
hidden_dim = cfg.hidden_size

# create dummy_input_ids (token ids must be integers in [0, vocab_size))
# note: symbolic_trace only receives the dict keys as input_names; the tensors themselves are not used
dummy_input_ids = {
    'input_ids': torch.randint(0, cfg.vocab_size, (batch_size, sequence_length))
}
traced = symbolic_trace(model, input_names=list(dummy_input_ids.keys()))  # works

# create dummy_inputs_embeds
dummy_input_embeds = {
    'inputs_embeds': torch.rand(batch_size, sequence_length, hidden_dim)
}


traced = symbolic_trace(model, input_names=list(dummy_input_embeds.keys()))  # raises TraceError on transformers 4.41.2

When I trace with dummy_input_ids there is no error, but when I trace with dummy_input_embeds I get the following error:

Traceback (most recent call last):
  File "/root/chj/geesibling-pytorch_5.23/examples/pytorch/test.py", line 26, in <module>
    traced = symbolic_trace(model,input_names=list(dummy_input_embeds.keys()))
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1483, in symbolic_trace
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1306, in trace
    self.graph = super().trace(root, concrete_args=concrete_args)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1139, in forward
    outputs = self.model(
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 795, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1170, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 479, in call_module
    ret_val = forward(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 788, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1024, in forward
    layer_outputs = decoder_layer(
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 795, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1170, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 479, in call_module
    ret_val = forward(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 788, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 738, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 795, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1170, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 479, in call_module
    ret_val = forward(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 788, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 650, in forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 795, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 1170, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 479, in call_module
    ret_val = forward(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 788, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 119, in forward
    if seq_len > self.max_seq_len_cached:
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/transformers/utils/fx.py", line 668, in __bool__
    return super().__bool__()
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 437, in __bool__
    return self.tracer.to_bool(self)
  File "/root/anaconda3/envs/gees_pytorch_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 300, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
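The last frames show where tracing stops: in MistralRotaryEmbedding.forward, seq_len is compared against self.max_seq_len_cached inside a Python if, and the proxy's __bool__ falls through to the default torch.fx implementation, which raises. The minimal sketch below reproduces the same TraceError with plain torch.fx. ToyRotaryCache and trace_error_message are hypothetical names, and the module is a simplified stand-in for the rotary-embedding cache check, not the actual Mistral code.

```python
import torch
import torch.fx
from torch import nn


class ToyRotaryCache(nn.Module):
    """Hypothetical stand-in for the `if seq_len > self.max_seq_len_cached` check."""

    def __init__(self, max_seq_len_cached: int = 16):
        super().__init__()
        self.max_seq_len_cached = max_seq_len_cached

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]  # during symbolic tracing this is a Proxy, not an int
        if seq_len > self.max_seq_len_cached:  # bool(Proxy) -> TraceError
            return x * 2
        return x


def trace_error_message(module: nn.Module) -> str:
    """Return the TraceError message, or an empty string if tracing succeeds."""
    try:
        torch.fx.symbolic_trace(module)
    except torch.fx.proxy.TraceError as exc:
        return str(exc)
    return ""


print(trace_error_message(ToyRotaryCache()))
```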

Expected behavior

Why does tracing with input_ids work, with no error raised at if seq_len > self.max_seq_len_cached:, while switching to inputs_embeds makes tracing fail at that same line?
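For contrast, torch.fx handles the same comparison without complaint when the length is an ordinary Python int at trace time. The sketch below (hypothetical module name; it relies on the concrete_args parameter of torch.fx.symbolic_trace, which PyTorch documents as experimental) fixes seq_len to a concrete value, so the if is evaluated once during tracing and only the taken branch is recorded in the graph. It illustrates the general mechanism only; it does not show how transformers' HFTracer obtains shape information in the input_ids case, or why that information is missing for inputs_embeds in 4.41.2.

```python
import torch
import torch.fx
from torch import nn


class ToyRotaryCacheExplicitLen(nn.Module):
    """Hypothetical variant that receives the sequence length as an argument."""

    def __init__(self, max_seq_len_cached: int = 16):
        super().__init__()
        self.max_seq_len_cached = max_seq_len_cached

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        # With concrete_args, seq_len is a plain Python int during tracing,
        # so this comparison is an ordinary bool and the branch is resolved
        # once, at trace time (the graph is specialized to that value).
        if seq_len > self.max_seq_len_cached:
            return x * 2
        return x


module = ToyRotaryCacheExplicitLen()
graph_module = torch.fx.symbolic_trace(module, concrete_args={"seq_len": 10})
print(graph_module.code)
```

Specializing on seq_len=20 instead would record the x * 2 branch, so a graph traced this way is only valid for inputs that match the traced value.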

If I must trace using inputs_embeds, how should I proceed? I am working on a pipeline-parallel job and need to split the model into the following parts:

model1:

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        ),
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        ),
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
  )
)

model2:

MistralForCausalLM(
  (model): MistralModel(
    (layers): ModuleList(
      (0): MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        ),
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        ),
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    ),
    (norm): MistralRMSNorm()
  ),
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

Since model2 has no embed_tokens layer, it can only be traced with inputs_embeds. How should I modify it?
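A rough sketch of the two-stage split described above, using small hypothetical modules (ToyStage1, ToyStage2) instead of the real Mistral classes, and nn.LayerNorm as a stand-in for MistralRMSNorm. It shows that the second stage has no embedding layer, so its only input is the hidden-state tensor produced by the first stage, which is the role inputs_embeds plays for model2. These toy blocks contain no shape-dependent control flow, so plain torch.fx traces them; the real MistralDecoderLayer still goes through the rotary-embedding check shown in the traceback, which a maintainer says should be fixed on the main branch by #31574.

```python
import torch
import torch.fx
from torch import nn


class ToyStage1(nn.Module):
    """Hypothetical first stage: token embedding + one block (like model1)."""

    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.block = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU())

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.block(self.embed_tokens(input_ids))


class ToyStage2(nn.Module):
    """Hypothetical second stage: one block + final norm + LM head (like model2).

    It has no embedding layer, so its input is the hidden states produced by
    stage 1, i.e. what Mistral calls `inputs_embeds`.
    """

    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU())
        self.norm = nn.LayerNorm(hidden)  # stand-in for MistralRMSNorm
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.norm(self.block(inputs_embeds)))


torch.manual_seed(0)
vocab_size, hidden = 100, 32
stage1, stage2 = ToyStage1(vocab_size, hidden), ToyStage2(vocab_size, hidden)

# Each stage is traced on its own; stage 2 only ever sees hidden states.
traced1 = torch.fx.symbolic_trace(stage1)
traced2 = torch.fx.symbolic_trace(stage2)

input_ids = torch.randint(0, vocab_size, (1, 10))
hidden_states = traced1(input_ids)
logits = traced2(hidden_states)
print(hidden_states.shape, logits.shape)
```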

@ArthurZucker
Collaborator

cc @fxmarty and @michaelbenayoun

@Hongjie1Chu
Author

cc @fxmarty and @michaelbenayoun

@Hongjie1Chu
Author

Can you help me, @fxmarty and @michaelbenayoun?


github-actions bot commented Jul 8, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@fxmarty
Contributor

fxmarty commented Jul 8, 2024

Hi @Hongjie1Chu, this should be fixed on the main branch with #31574.
