You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
# Reproduction: FX-trace a tiny LLaMA model with input_ids vs. inputs_embeds.
import torch

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.utils.fx import symbolic_trace

# Shrink the model to a single decoder layer so construction/tracing is fast.
config = LlamaConfig()
config.num_hidden_layers = 1

model = LlamaForCausalLM(config)
print(model)

# Tracing with input_ids succeeds...
traced = symbolic_trace(model, input_names=['input_ids'])
# ...while tracing with inputs_embeds raises a TraceError (see below).
traced = symbolic_trace(model, input_names=['inputs_embeds'])
Expected behavior
I would like to know why tracing works normally with `input_ids` — no error is raised when executing `if query_states.device.type == "cuda" and causal_mask is not None:` — but switching to `inputs_embeds` causes a failure at exactly that line.
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:641, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position) 636 causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] 638 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, 639 # Reference: pytorch/pytorch#112577.
--> 641 if query_states.device.type == "cuda" and causal_mask is not None: 642 query_states = query_states.contiguous() 643 key_states = key_states.contiguous()
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:668, in HFProxy.bool(self) 666 if hasattr(self, "_metadata") and self._metadata is not None: 667 return self._metadata
--> 668 return super().bool()
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/proxy.py:300, in TracerBase.to_bool(self, obj) 293@compatibility(is_backward_compatible=True) 294 def to_bool(self, obj: 'Proxy') -> bool: 295 """Called when a proxy object is being converted to a boolean, such as 296 when used in control flow. Normally we don't know what to do because 297 we don't know the value of the proxy, but a custom tracer can attach more 298 information to the graph node using create_node and can choose to return a value. 299 """
--> 300 raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
TraceError: symbolically traced variables cannot be used as inputs to control flow
The text was updated successfully, but these errors were encountered:
System Info
transformers
version: 4.41.2

Who can help?
@fxmarty and @michaelbenayoun @younesbelkada
Information
Tasks
examples
folder (such as GLUE/SQuAD, ...)

Reproduction
Expected behavior
I would like to know why tracing works normally with `input_ids` — no error is raised when executing `if query_states.device.type == "cuda" and causal_mask is not None:` — but switching to `inputs_embeds` causes a failure at exactly that line.
the error is:
TraceError Traceback (most recent call last)
Cell In[7], line 1
----> 1 traced = symbolic_trace(model,input_names=['inputs_embeds'])
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1483, in symbolic_trace(model, input_names, disable_check, tracer_cls)
1481 # Tracing.
1482 tracer = tracer_cls()
-> 1483 traced_graph = tracer.trace(model, concrete_args=concrete_args)
1484 traced = torch.fx.GraphModule(model, traced_graph)
1486 traced.config = model.config
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1306, in HFTracer.trace(self, root, concrete_args, dummy_inputs, complete_concrete_args_with_inputs_not_in_dummy_inputs)
1304 with self.patch_for_tracing(root):
1305 try:
-> 1306 self.graph = super().trace(root, concrete_args=concrete_args)
1307 finally:
1308 _CURRENT_TRACER = None
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:817, in Tracer.trace(self, root, concrete_args)
810 for module in self._autowrap_search:
811 _autowrap_check(
812 patcher, module.dict, self._autowrap_function_ids
813 )
814 self.create_node(
815 "output",
816 "output",
--> 817 (self.create_arg(fn(*args)),),
818 {},
819 type_expr=fn.annotations.get("return", None),
820 )
822 self.submodule_paths = None
823 finally:
File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:1168, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1165 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1167 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1168 outputs = self.model(
1169 input_ids=input_ids,
1170 attention_mask=attention_mask,
1171 position_ids=position_ids,
1172 past_key_values=past_key_values,
1173 inputs_embeds=inputs_embeds,
1174 use_cache=use_cache,
1175 output_attentions=output_attentions,
1176 output_hidden_states=output_hidden_states,
1177 return_dict=return_dict,
1178 cache_position=cache_position,
1179 )
1181 hidden_states = outputs[0]
1182 if self.config.pretraining_tp > 1:
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace..module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "globals", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace..module_call_wrapper..forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:969, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
958 layer_outputs = self._gradient_checkpointing_func(
959 decoder_layer.call,
960 hidden_states,
(...)
966 cache_position,
967 )
968 else:
--> 969 layer_outputs = decoder_layer(
970 hidden_states,
971 attention_mask=causal_mask,
972 position_ids=position_ids,
973 past_key_value=past_key_values,
974 output_attentions=output_attentions,
975 use_cache=use_cache,
976 cache_position=cache_position,
977 )
979 hidden_states = layer_outputs[0]
981 if use_cache:
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace..module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "globals", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace..module_call_wrapper..forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:714, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
711 hidden_states = self.input_layernorm(hidden_states)
713 # Self Attention
--> 714 hidden_states, self_attn_weights, present_key_value = self.self_attn(
715 hidden_states=hidden_states,
716 attention_mask=attention_mask,
717 position_ids=position_ids,
718 past_key_value=past_key_value,
719 output_attentions=output_attentions,
720 use_cache=use_cache,
721 cache_position=cache_position,
722 )
723 hidden_states = residual + hidden_states
725 # Fully Connected
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:795, in Tracer.trace..module_call_wrapper(mod, *args, **kwargs)
788 return _orig_module_call(mod, *args, **kwargs)
790 _autowrap_check(
791 patcher,
792 getattr(getattr(mod, "forward", mod), "globals", {}),
793 self._autowrap_function_ids,
794 )
--> 795 return self.call_module(mod, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:1170, in HFTracer.call_module(self, m, forward, args, kwargs)
1168 return forward(*args, **kwargs)
1169 self.orig_forward = forward
-> 1170 return super().call_module(m, forward, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:479, in Tracer.call_module(self, m, forward, args, kwargs)
477 self.module_stack[_scope.module_path] = _scope.module_type
478 if not self.is_leaf_module(m, module_qualified_name):
--> 479 ret_val = forward(*args, **kwargs)
480 else:
481 ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/_symbolic_trace.py:788, in Tracer.trace..module_call_wrapper..forward(*args, **kwargs)
787 def forward(*args, **kwargs):
--> 788 return _orig_module_call(mod, *args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py:641, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
636 causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
638 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
639 # Reference: pytorch/pytorch#112577.
--> 641 if query_states.device.type == "cuda" and causal_mask is not None:
642 query_states = query_states.contiguous()
643 key_states = key_states.contiguous()
File ~/Library/Python/3.9/lib/python/site-packages/transformers/utils/fx.py:668, in HFProxy.bool(self)
666 if hasattr(self, "_metadata") and self._metadata is not None:
667 return self._metadata
--> 668 return super().bool()
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/proxy.py:437, in Proxy.bool(self)
434 self.tracer.create_proxy('call_function', assert_fn, (self,), {})
435 return True
--> 437 return self.tracer.to_bool(self)
File ~/Library/Python/3.9/lib/python/site-packages/torch/fx/proxy.py:300, in TracerBase.to_bool(self, obj)
293 @compatibility(is_backward_compatible=True)
294 def to_bool(self, obj: 'Proxy') -> bool:
295 """Called when a proxy object is being converted to a boolean, such as
296 when used in control flow. Normally we don't know what to do because
297 we don't know the value of the proxy, but a custom tracer can attach more
298 information to the graph node using create_node and can choose to return a value.
299 """
--> 300 raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
TraceError: symbolically traced variables cannot be used as inputs to control flow
The text was updated successfully, but these errors were encountered: