
Incorrect sharding configuration for Starcoder2 model #40813

@greg-kwasniewski1

Description


System Info

Transformers main branch (commit 0f1b128)

  • transformers version: 4.57.0.dev0
  • Platform: Linux-5.15.0-1030-nvidia-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 0.34.4
  • Safetensors version: 0.5.3
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0a0+5228986c39.nv25.06 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: tensor-parallel
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running TP inference on bigcode/starcoder2-7b throws a tensor-shape error because of a base_model_tp_plan misconfiguration.

demo.py:

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigcode/starcoder2-7b"
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")

# Workaround: override the plan entry that ships as 'colwise'
model._tp_plan['model.layers.*.mlp.c_proj'] = 'rowwise'
print(f"TP plan: {model._tp_plan}, class: {type(model._tp_plan)}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)

# print the output
print(outputs)

Run with:

torchrun --nproc_per_node=2 demo.py

The correct base_model_tp_plan should replace:

['model.layers.*.mlp.c_proj'] = 'colwise'

with

['model.layers.*.mlp.c_proj'] = 'rowwise'
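The shape mismatch can be sketched outside of Transformers. In the sketch below, the real starcoder2-7b sizes (hidden_size 4608, intermediate_size 18432, as implied by the traceback) are scaled down, and the variable names are illustrative, not Transformers internals; the shape logic is the same.

```python
import numpy as np

# Toy sketch of why c_proj needs 'rowwise' under tensor parallelism.
# Real sizes (hidden=4608, inter=18432, world=2) are scaled down here.
hidden, inter, world = 8, 32, 2
x = np.ones((3, hidden))

# c_fc is 'colwise': each rank holds a slice of the *output* dim, so the
# activation leaving c_fc is sharded to inter // world columns per rank.
w_fc_local = np.ones((inter // world, hidden))     # local [16, 8]
h_local = x @ w_fc_local.T                         # [3, 16]

# A 'colwise' c_proj expects the full input width (inter), mirroring
# "a and b must have same reduction dim" from the traceback.
w_proj_full = np.ones((hidden, inter))             # [8, 32]
assert h_local.shape[1] != w_proj_full.T.shape[0]  # 16 != 32

# A 'rowwise' c_proj shards the *input* dim instead, matching h_local;
# each rank produces a partial [3, hidden] that an all-reduce would sum.
w_proj_local = np.ones((hidden, inter // world))   # local [8, 16]
y_partial = h_local @ w_proj_local.T               # [3, 8]
print(y_partial.shape)  # (3, 8)
```

At full scale the same mismatch is 9216 vs 18432, matching the `[3, 9216] X [18432, 4608]` shapes in the error.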

Expected behavior

Currently throws:

(...)
[rank0]:   File "/lustre/fs1/portfolios/coreai/users/gkwasniewski/hf-repo/transformers/src/transformers/models/starcoder2/modeling_starcoder2.py", line 65, in forward
[rank0]:     hidden_states = self.c_proj(hidden_states)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1805, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 125, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 850, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_dispatch.py", line 160, in dispatch
[rank0]:     self.sharding_propagator.propagate(op_info)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 266, in propagate
[rank0]:     OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
[rank0]:     return self.cache(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 279, in propagate_op_sharding_non_cached
[rank0]:     out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 126, in _propagate_tensor_meta_non_cached
[rank0]:     fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 27, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1311, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1932, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1414, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2460, in _dispatch_impl
[rank0]:     decomposition_table[func](*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 308, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_decomp/decompositions.py", line 84, in inner
[rank0]:     r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_decomp/decompositions.py", line 1451, in addmm
[rank0]:     out = alpha * torch.mm(mat1, mat2)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 27, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1311, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1932, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1414, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2554, in _dispatch_impl
[rank0]:     r = func(*args, **kwargs)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/wrappers.py", line 308, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 2236, in meta_mm
[rank0]:     torch._check(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1668, in _check
[rank0]:     _check_with(RuntimeError, cond, message)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1650, in _check_with
[rank0]:     raise error_type(message_evaluated)
[rank0]: RuntimeError: a and b must have same reduction dim, but got [3, 9216] X [18432, 4608].
(...)

Expected output:

CausalLMOutputWithPast(loss=None, logits=tensor([[[  0.6951,  -2.9710, -12.8470,  ...,  -4.8511,  -6.0277,  -6.6027],
         [  2.4489,  -0.3970,  -1.9423,  ...,  -3.9063,  -5.0727,  -5.9155],
         [  4.5938,  -0.8972,  -1.5770,  ...,  -4.8748,  -2.2605,  -5.4515]]],
