Control flow issue with symbolic_trace when using inputs_embeds in LlamaForCausalLM #31414

Closed
2 of 4 tasks
Hongjie1Chu opened this issue Jun 14, 2024 · 5 comments · Fixed by #31574

@Hongjie1Chu

System Info

  • transformers version: 4.41.2
  • Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@fxmarty, @michaelbenayoun, and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import LlamaForCausalLM, LlamaConfig
import torch
from transformers.utils.fx import symbolic_trace


cfg = LlamaConfig()
cfg.num_hidden_layers = 1
model = LlamaForCausalLM(cfg)
print(model)
# Tracing with input_ids succeeds:
traced = symbolic_trace(model, input_names=['input_ids'])
# Tracing with inputs_embeds raises a TraceError:
traced = symbolic_trace(model, input_names=['inputs_embeds'])

Expected behavior

Why does tracing with input_ids succeed (the line `if query_states.device.type == "cuda" and causal_mask is not None:` executes without error), while switching to inputs_embeds fails at exactly that point?

The error is:


TraceError Traceback (most recent call last)
Cell In[7], line 1
----> 1 traced = symbolic_trace(model,input_names=['inputs_embeds'])

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1483, in symbolic_trace(model, input_names, disable_check, tracer_cls)
1481 # Tracing.
1482 tracer = tracer_cls()
-> 1483 traced_graph = tracer.trace(model, concrete_args=concrete_args)
1484 traced = torch.fx.GraphModule(model, traced_graph)
1486 traced.config = model.config

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1306, in HFTracer.trace(self, root, concrete_args, dummy_inputs, complete_concrete_args_with_inputs_not_in_dummy_inputs)
1304 with self.patch_for_tracing(root):
1305 try:
-> 1306 self.graph = super().trace(root, concrete_args=concrete_args)
1307 finally:
1308 _CURRENT_TRACER = None

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:817, in Tracer.trace(self, root, concrete_args)
810 for module in self._autowrap_search:
811 _autowrap_check(
812 patcher, module.__dict__, self._autowrap_function_ids
813 )
814 self.create_node(
815 "output",
816 "output",
--> 817 (self.create_arg(fn(*args)),),
818 {},
819 type_expr=fn.__annotations__.get("return", None),
820 )
822 self.submodule_paths = None
823 finally:

File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:1168, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1165 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1167 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1168 outputs = self.model(
1169 input_ids=input_ids,
1170 attention_mask=attention_mask,
1171 position_ids=position_ids,
1172 past_key_values=past_key_values,
1173 inputs_embeds=inputs_embeds,
1174 use_cache=use_cache,
1175 output_attentions=output_attentions,
1176 output_hidden_states=output_hidden_states,
1177 return_dict=return_dict,
1178 cache_position=cache_position,
1179 )
1181 hidden_states = outputs[0]
1182 if self.config.pretraining_tp > 1:

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace.<locals>.module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "__globals__", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace.<locals>.module_call_wrapper.<locals>.forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:969, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
958 layer_outputs = self._gradient_checkpointing_func(
959 decoder_layer.__call__,
960 hidden_states,
(...)
966 cache_position,
967 )
968 else:
--> 969 layer_outputs = decoder_layer(
970 hidden_states,
971 attention_mask=causal_mask,
972 position_ids=position_ids,
973 past_key_value=past_key_values,
974 output_attentions=output_attentions,
975 use_cache=use_cache,
976 cache_position=cache_position,
977 )
979 hidden_states = layer_outputs[0]
981 if use_cache:

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace.<locals>.module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "__globals__", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace.<locals>.module_call_wrapper.<locals>.forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:714, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
711 hidden_states = self.input_layernorm(hidden_states)
713 # Self Attention
--> 714 hidden_states, self_attn_weights, present_key_value = self.self_attn(
715 hidden_states=hidden_states,
716 attention_mask=attention_mask,
717 position_ids=position_ids,
718 past_key_value=past_key_value,
719 output_attentions=output_attentions,
720 use_cache=use_cache,
721 cache_position=cache_position,
722 )
723 hidden_states = residual + hidden_states
725 # Fully Connected

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace.<locals>.module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "__globals__", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace.<locals>.module_call_wrapper.<locals>.forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:641, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
636 causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
638 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
639 # Reference: pytorch/pytorch#112577.
--> 641 if query_states.device.type == "cuda" and causal_mask is not None:
642 query_states = query_states.contiguous()
643 key_states = key_states.contiguous()

File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:668, in HFProxy.__bool__(self)
666 if hasattr(self, "_metadata") and self._metadata is not None:
667 return self._metadata
--> 668 return super().__bool__()

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/proxy.py:437, in Proxy.__bool__(self)
434 self.tracer.create_proxy('call_function', assert_fn, (self,), {})
435 return True
--> 437 return self.tracer.to_bool(self)

File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/proxy.py:300, in TracerBase.to_bool(self, obj)
293 @compatibility(is_backward_compatible=True)
294 def to_bool(self, obj: 'Proxy') -> bool:
295 """Called when a proxy object is being converted to a boolean, such as
296 when used in control flow. Normally we don't know what to do because
297 we don't know the value of the proxy, but a custom tracer can attach more
298 information to the graph node using create_node and can choose to return a value.
299 """
--> 300 raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

TraceError: symbolically traced variables cannot be used as inputs to control flow
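
For context, the failure itself is torch.fx's generic restriction surfacing: once a value is a traced Proxy, it cannot be used in a Python `if`. A minimal, self-contained sketch (not the Llama code, just reproducing the same TraceError with plain torch.fx):

import torch
from torch.fx import symbolic_trace as fx_symbolic_trace  # plain torch.fx tracer, not transformers.utils.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        # During tracing, x.sum() > 0 is a Proxy; evaluating it as a bool in
        # an `if` calls Proxy.__bool__ and raises the TraceError shown above.
        if x.sum() > 0:
            return x + 1
        return x - 1

fx_symbolic_trace(Toy())  # TraceError: symbolically traced variables cannot be used as inputs to control flow

Judging from the last frames of the traceback, when tracing with input_ids the HFProxy carries `_metadata`, so the `query_states.device.type == "cuda"` check resolves to a concrete boolean; with inputs_embeds that metadata seems to be missing, the check falls through to Proxy.__bool__, and tracing aborts.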

@zucchini-nlp
Member

I guess this is the same as #31200.

@Hongjie1Chu
Author

Yes, both Llama and Mistral hit this error. I submitted two issues but got no reply, so do you know how to fix it? @zucchini-nlp

@zucchini-nlp
Member

No, you tagged the right people in the issue description. I just wanted to link the related issues to each other.

@songh11

songh11 commented Jun 14, 2024

I had the same problem

@fxmarty
Contributor

fxmarty commented Jul 8, 2024

Hi @Hongjie1Chu @songh11, this should be fixed on the main branch with #31574.
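
A quick way to double-check, as a sketch (assumes installing transformers from source so the patch from #31574 is included):

# pip install git+https://github.com/huggingface/transformers
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.utils.fx import symbolic_trace

cfg = LlamaConfig(num_hidden_layers=1)
model = LlamaForCausalLM(cfg)
traced = symbolic_trace(model, input_names=['inputs_embeds'])  # should no longer raise TraceError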
