System Info
Transformers main branch (commit 0f1b128)
- `transformers` version: 4.57.0.dev0
- Platform: Linux-5.15.0-1030-nvidia-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.5.3
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0a0+5228986c39.nv25.06 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: tensor-parallel
- Using GPU in script?: yes
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Running TP inference on bigcode/starcoder2-7b throws an error with incorrect tensor shapes due to a `base_model_tp_plan` misconfiguration.

`demo.py`:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigcode/starcoder2-7b"
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")

# manually override the plan entry for c_proj (the shipped plan marks it 'colwise')
model._tp_plan['model.layers.*.mlp.c_proj'] = 'rowwise'
print(f"TP plan: {model._tp_plan}, class: {type(model._tp_plan)}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)

# print the output
print(outputs)
```
Run with:

```
torchrun --nproc_per_node=2 demo.py
```
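For context on why `colwise` breaks here: `c_fc` is column-parallel in the shipped plan (consistent with the 9216-wide activation in the traceback below), so each rank only holds `intermediate_size / world_size` features of the MLP activation, and `c_proj` must be row-parallel (sharded along its input dimension) to consume it. Below is a minimal single-process sketch of the shape arithmetic only, not the actual TP code path; the dimensions (hidden 4608, intermediate 18432, world size 2) are assumed from StarCoder2-7B and match the shapes in the traceback.

```python
# Single-process sketch of the shape mismatch (plain torch, no distributed setup).
# Assumed dims: hidden=4608, intermediate=18432, TP world size=2.
import torch

hidden, intermediate, tp = 4608, 18432, 2

x = torch.randn(1, 3, hidden)                             # activations for 3 tokens
c_fc_local = torch.randn(intermediate // tp, hidden)      # colwise shard of c_fc: output dim split
h_local = x @ c_fc_local.T                                # per-rank MLP activation: [1, 3, 9216]

c_proj_colwise = torch.randn(hidden // tp, intermediate)  # colwise shard of c_proj: [2304, 18432]
c_proj_rowwise = torch.randn(hidden, intermediate // tp)  # rowwise shard of c_proj: [4608, 9216]

# h_local @ c_proj_colwise.T  # fails: reduction dims 9216 vs 18432, the mismatch seen in the traceback below
y_partial = h_local @ c_proj_rowwise.T                    # works: [1, 3, 4608] partial sum (all-reduced across ranks in real TP)
print(y_partial.shape)                                    # torch.Size([1, 3, 4608])
```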
The correct `base_model_tp_plan` should replace `['model.layers.*.mlp.c_proj'] = 'colwise'` with `['model.layers.*.mlp.c_proj'] = 'rowwise'`.
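For illustration only (key names and surrounding entries are assumptions, not the verbatim source), the relevant MLP entries of the plan with the proposed change would look like:

```python
# Illustrative sketch; only the value for c_proj is the proposed change.
base_model_tp_plan = {
    # ...
    "layers.*.mlp.c_fc": "colwise",    # splits the intermediate (output) dimension across ranks
    "layers.*.mlp.c_proj": "rowwise",  # proposed fix: shard the input dimension (currently 'colwise')
}
```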
Expected behavior
Currently throws:

```
(...)
[rank0]: File "/lustre/fs1/portfolios/coreai/users/gkwasniewski/hf-repo/transformers/src/transformers/models/starcoder2/modeling_starcoder2.py", line 65, in forward
[rank0]: hidden_states = self.c_proj(hidden_states)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank0]: return inner()
[rank0]: ^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1805, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 125, in forward
[rank0]: return F.linear(input, self.weight, self.bias)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_dispatch.py", line 160, in dispatch
[rank0]: self.sharding_propagator.propagate(op_info)
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
[rank0]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
[rank0]: return self.cache(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 279, in propagate_op_sharding_non_cached
[rank0]: out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 126, in _propagate_tensor_meta_non_cached
[rank0]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
[rank0]: return self._op(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 27, in wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1311, in __torch_dispatch__
[rank0]: return self.dispatch(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1932, in dispatch
[rank0]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1414, in _cached_dispatch_impl
[rank0]: output = self._dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2460, in _dispatch_impl
[rank0]: decomposition_table[func](*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 308, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_decomp/decompositions.py", line 84, in inner
[rank0]: r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_decomp/decompositions.py", line 1451, in addmm
[rank0]: out = alpha * torch.mm(mat1, mat2)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 27, in wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1311, in __torch_dispatch__
[rank0]: return self.dispatch(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1932, in dispatch
[rank0]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1414, in _cached_dispatch_impl
[rank0]: output = self._dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2554, in _dispatch_impl
[rank0]: r = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
[rank0]: return self._op(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 308, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 2236, in meta_mm
[rank0]: torch._check(
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1668, in _check
[rank0]: _check_with(RuntimeError, cond, message)
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1650, in _check_with
[rank0]: raise error_type(message_evaluated)
[rank0]: RuntimeError: a and b must have same reduction dim, but got [3, 9216] X [18432, 4608].
(...)
```
Expected output:
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.6951, -2.9710, -12.8470, ..., -4.8511, -6.0277, -6.6027],
[ 2.4489, -0.3970, -1.9423, ..., -3.9063, -5.0727, -5.9155],
[ 4.5938, -0.8972, -1.5770, ..., -4.8748, -2.2605, -5.4515]]],