Merged
70 commits
fd8e607
inital commit
ArthurZucker Nov 18, 2025
7990c49
up
ArthurZucker Nov 18, 2025
1a9f77a
update unexpected later on
ArthurZucker Nov 18, 2025
c82b5c8
Merge branch 'main' of github.com:huggingface/transformers into vlm-u…
ArthurZucker Nov 18, 2025
30e405a
fix
ArthurZucker Nov 18, 2025
e9fcb66
update
ArthurZucker Nov 18, 2025
4204535
simplify our lives
ArthurZucker Nov 19, 2025
1da30a6
isolate a bit more
ArthurZucker Nov 19, 2025
5c71300
fixup
ArthurZucker Nov 19, 2025
6c33dc8
small nits
ArthurZucker Nov 19, 2025
e53e1c6
style
ArthurZucker Nov 19, 2025
5e2e0c4
nit
ArthurZucker Nov 19, 2025
eb8493c
fix common cases
ArthurZucker Nov 19, 2025
526001e
Merge branch 'main' of github.com:huggingface/transformers into vlm-u…
ArthurZucker Nov 19, 2025
74c524d
fix post merge
ArthurZucker Nov 19, 2025
7c04b0f
bnb needs missing keys
ArthurZucker Nov 19, 2025
935e77f
small fix
ArthurZucker Nov 19, 2025
6c23f3e
bettrer documentation
ArthurZucker Nov 19, 2025
b5adc5b
no veradict + base class
ArthurZucker Nov 19, 2025
2746e0f
rake review comments
ArthurZucker Nov 19, 2025
b7591da
take all comments
ArthurZucker Nov 19, 2025
138d415
fix super init
ArthurZucker Nov 19, 2025
cb63300
update doc to be more real
ArthurZucker Nov 20, 2025
12a74d9
up
MekkCyber Nov 19, 2025
1b6f64d
fix some tests
MekkCyber Nov 19, 2025
fe089ea
weight convertor
MekkCyber Nov 19, 2025
f1e8731
up
MekkCyber Nov 20, 2025
9607be8
mostly correct
MekkCyber Nov 20, 2025
4cc5912
oups
MekkCyber Nov 20, 2025
da2d82e
skip non linears
MekkCyber Nov 20, 2025
4e7903f
only some tests to go
MekkCyber Nov 20, 2025
6dfc5b1
need quantization
MekkCyber Nov 20, 2025
4842c3a
fix tests
MekkCyber Nov 20, 2025
18e677d
Merge branch 'main' into fix-torchao-v2
MekkCyber Nov 20, 2025
81daea7
rm comment
MekkCyber Nov 20, 2025
a59b4d3
revert
MekkCyber Nov 20, 2025
c0f44e4
revert 2
MekkCyber Nov 20, 2025
ac87da6
style
MekkCyber Nov 20, 2025
1900da8
up
MekkCyber Nov 20, 2025
a1f575e
up
MekkCyber Nov 20, 2025
6261f34
remove unsafe loading path
MekkCyber Nov 20, 2025
d58a69b
fix
MekkCyber Nov 20, 2025
33a3034
fix
MekkCyber Nov 20, 2025
29fe298
fix
MekkCyber Nov 21, 2025
b673170
up
MekkCyber Nov 21, 2025
81f0bc4
rm Dtensor import
MekkCyber Nov 21, 2025
955b7b5
rm replicate import
MekkCyber Nov 21, 2025
8a8ab52
fix imports
MekkCyber Nov 21, 2025
cbcf907
up
MekkCyber Nov 21, 2025
564fbd3
minor modifications
SunMarc Nov 21, 2025
10ed7b9
add quantizaton_operation
MekkCyber Nov 21, 2025
dfc2fc8
delattr
SunMarc Nov 21, 2025
df5a5fc
Merge remote-tracking branch 'upstream/fix-torchao-v2' into fix-torch…
SunMarc Nov 21, 2025
f90d29b
fix
SunMarc Nov 21, 2025
bff9528
fix get_param_buffer
SunMarc Nov 21, 2025
07cc7fa
better to just set module initialized
SunMarc Nov 21, 2025
e713f43
rm tie_weights
MekkCyber Nov 21, 2025
e7aae58
guard imports
MekkCyber Nov 21, 2025
edcd1a0
up
MekkCyber Nov 21, 2025
cabd4e2
rm offloading for now
MekkCyber Nov 24, 2025
7043cb8
add license
MekkCyber Nov 24, 2025
396e06f
don't guard torch
MekkCyber Nov 24, 2025
d215565
comment
MekkCyber Nov 24, 2025
56f695f
fix
MekkCyber Nov 24, 2025
0dd7a35
rm torch.grad
MekkCyber Nov 24, 2025
1f904ed
revert
MekkCyber Nov 24, 2025
2bcc1a7
Merge branch 'main' into fix-torchao-v2
MekkCyber Nov 25, 2025
c63dfb9
fix
MekkCyber Nov 25, 2025
755a2ba
add guard
MekkCyber Nov 25, 2025
8ebf64f
add second guard
MekkCyber Nov 25, 2025
47 changes: 39 additions & 8 deletions src/transformers/core_model_loading.py
@@ -115,6 +115,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
**kwargs,
Comment on lines 115 to 121

Member: Most of these args should be optional kwargs, so that we can clean up the other convert functions with **kwargs and only pass the args that are actually used. But that's fine, we should do that later on.

Collaborator: Agreed.

) -> dict[str, list[torch.Tensor]]:
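The reviewers' suggestion above can be sketched roughly as follows. This is a hypothetical illustration of the proposed calling convention, not code from this PR: each op declares only the arguments it actually needs and absorbs the rest with `**kwargs`, so the loader can always pass the same full set.

```python
from typing import Any

import torch


class ChunkOp:
    """Hypothetical op: only needs the collected tensors and the target keys."""

    def convert(
        self,
        value: dict[str, list[torch.Tensor]],
        *,
        target_keys: list[str],
        **kwargs: Any,  # model, config, missing_keys, ... are accepted but ignored here
    ) -> dict[str, torch.Tensor]:
        tensors = next(iter(value.values()))
        chunks = torch.chunk(tensors[0], len(target_keys), dim=0)
        return dict(zip(target_keys, chunks))


# The caller can then invoke every op uniformly with the full argument set:
op = ChunkOp()
out = op.convert(
    {"model.layers.0.qkv.weight": [torch.randn(6, 4)]},
    source_keys=["qkv"],
    target_keys=["q", "k", "v"],
    model=None,
    config=None,
    missing_keys=set(),
)
print({k: tuple(v.shape) for k, v in out.items()})  # {'q': (2, 4), 'k': (2, 4), 'v': (2, 4)}
```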
@@ -138,6 +140,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
) -> dict[str, list[torch.Tensor]]:
tensors = next(iter(value.values()))
@@ -163,6 +167,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
) -> dict[str, torch.Tensor]:
if len(target_keys) != 1:
@@ -191,6 +197,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
) -> dict[str, torch.Tensor]:
merged: dict[str, torch.Tensor] = {}
@@ -220,6 +228,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
) -> dict[str, list[torch.Tensor]]:
if len(value) != len(self.sizes):
@@ -258,6 +268,8 @@ def convert(
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
) -> dict[str, list[torch.Tensor]]:
self.config = config
@@ -298,21 +310,28 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
class WeightRenaming(WeightTransform):
# Special case of WeightTransform that only renames keys without any conversion.

def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
def convert(
self,
layer_name: str,
model=None,
config=None,
hf_quantizer=None,
missing_keys: Optional[MutableSet[str]] = None,
):
misc = {}
for pattern, futures in self.collected_tensors.items():
self.collected_tensors[pattern] = [future.result() for future in futures]

collected_tensors = self.collected_tensors
if quantizer is not None and self.quantization_operation is not None:
if hf_quantizer is not None and self.quantization_operation is not None:
with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
collected_tensors = self.quantization_operation.convert(
self.collected_tensors,
source_keys=self.source_keys,
target_keys=self.target_keys,
full_layer_name=layer_name,
model=model,
config=config,
quant_config=quantizer.quantization_config,
missing_keys=missing_keys,
)

@@ -332,7 +351,14 @@ def __post_init__(self):
if not self.operations:
raise ValueError("WeightConverter requires at least one operation.")

def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
def convert(
self,
layer_name: str,
model=None,
config=None,
hf_quantizer=None,
missing_keys: Optional[MutableSet[str]] = None,
):
misc = {}
for pattern, futures in self.collected_tensors.items():
self.collected_tensors[pattern] = [future.result() for future in futures]
@@ -345,17 +371,19 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op
source_keys=self.source_keys,
target_keys=self.target_keys,
full_layer_name=layer_name,
model=model,
config=config,
missing_keys=missing_keys,
)
if quantizer is not None and self.quantization_operation is not None:
if hf_quantizer is not None and self.quantization_operation is not None:
with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
collected_tensors = self.quantization_operation.convert(
collected_tensors,
source_keys=self.source_keys,
target_keys=self.target_keys,
full_layer_name=layer_name,
config=config,
quant_config=quantizer.quantization_config,
model=model,
missing_keys=missing_keys,
)
return collected_tensors, misc
@@ -626,7 +654,6 @@ def convert_and_load_state_dict_in_model(
```

"""

prefix = model.base_model_prefix
tp_plan = tp_plan or {}
device_map = device_map or {"": "cpu"}
@@ -750,7 +777,11 @@ def convert_and_load_state_dict_in_model(
pbar.refresh()
try:
realized_value, misc = mapping.convert(
first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
first_param_name,
model=model,
config=model.config,
hf_quantizer=hf_quantizer,
missing_keys=missing_keys,
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
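For orientation, here is a condensed, partly hypothetical sketch of how the loader drives the updated `convert` signature; the real `convert_and_load_state_dict_in_model` also handles device maps, tp plans, thread pools and progress reporting.

```python
import torch


def apply_mappings(model, mappings, hf_quantizer=None):
    """Simplified driver loop mirroring the call site in the diff above."""
    missing_keys: set[str] = set()
    loaded: dict[str, torch.Tensor] = {}
    for first_param_name, mapping in mappings.items():
        realized_value, misc = mapping.convert(
            first_param_name,
            model=model,                 # new: ops can inspect or mark modules
            config=model.config,
            hf_quantizer=hf_quantizer,   # renamed from `quantizer`
            missing_keys=missing_keys,   # new: ops can discard keys they resolve
        )
        for target_name, param in realized_value.items():
            param = param[0] if isinstance(param, list) else param
            loaded[target_name] = param  # placeholder for the real set-on-module step
    return loaded, missing_keys
```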
2 changes: 1 addition & 1 deletion src/transformers/integrations/accelerate.py
@@ -241,7 +241,7 @@ def all_tensors():
if name in tied_keys:
continue
if hf_quantizer is not None:
dtype_size = hf_quantizer.param_element_size(model, name)
dtype_size = hf_quantizer.param_element_size(model, name, param)
else:
dtype_size = param.element_size()
size = param.numel() * dtype_size
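As a hedged illustration of what this size accounting amounts to: the standalone helper below is made up for the sketch, only the `param_element_size(model, name, param)` call and the `numel() * dtype_size` product come from the diff.

```python
import torch


def estimated_param_bytes(name: str, param: torch.Tensor, hf_quantizer=None, model=None) -> int:
    """Rough per-parameter memory estimate, mirroring the accelerate integration above."""
    if hf_quantizer is not None:
        # The quantizer now also receives the tensor itself, so it can account for
        # per-tensor packing (e.g. 4-bit weights stored two values per byte).
        dtype_size = hf_quantizer.param_element_size(model, name, param)
    else:
        dtype_size = param.element_size()
    return param.numel() * dtype_size


# e.g. a float16 weight: 4096 * 4096 elements * 2 bytes = 32 MiB
weight = torch.empty(4096, 4096, dtype=torch.float16)
print(estimated_param_bytes("model.layers.0.q_proj.weight", weight))
```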
14 changes: 12 additions & 2 deletions src/transformers/integrations/bitsandbytes.py
@@ -36,7 +36,11 @@ def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(
self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
self,
input_dict: dict[str, list[torch.Tensor]],
model: Optional[torch.nn.Module] = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
We need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is kept on the hf_quantizer for now, as we can't save it in the op since we create an op per tensor.
@@ -59,6 +63,7 @@ def convert(
# remove missing keys that were created when initializing Params4bit
for key in new_value.quant_state.as_dict(packed=True).keys():
missing_keys.discard(f"{full_name}.{key}")
module._is_hf_initialized = True
return {target_key: new_value}
else:
module_name = target_key.rsplit(".", 1)[0]
@@ -77,6 +82,7 @@ def convert(
device=value.device,
module=module,
)
module._is_hf_initialized = True
del self.hf_quantizer.param_quant_stats[module_name]
return {target_key: new_value}
return {}
@@ -87,7 +93,11 @@ def convert(
self.hf_quantizer = hf_quantizer

def convert(
self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
self,
input_dict: dict[str, list[torch.Tensor]],
model: Optional[torch.nn.Module] = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
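The bitsandbytes path above accumulates the checkpoint's quant-state pieces per module and only builds the quantized parameter once every piece has arrived, then marks the module with `_is_hf_initialized = True` so the loader won't re-initialize it. A rough, self-contained sketch of that accumulation pattern follows; the expected key names and the `build_quantized_weight` helper are made up for illustration, while the real code keeps the staging dict on the hf_quantizer and reconstructs a bnb `Params4bit`.

```python
import torch

# Illustrative key set; the real checkpoint stores several quant-state tensors per weight.
EXPECTED_PIECES = {"weight", "weight.absmax", "weight.quant_map"}

param_quant_stats: dict[str, dict[str, torch.Tensor]] = {}


def build_quantized_weight(stats: dict[str, torch.Tensor]) -> torch.Tensor:
    # Placeholder: the real integration reconstructs a bitsandbytes Params4bit here.
    return stats["weight"]


def collect_piece(module_name: str, key: str, value: torch.Tensor, missing_keys: set[str]):
    stats = param_quant_stats.setdefault(module_name, {})
    stats[key] = value
    missing_keys.discard(f"{module_name}.{key}")  # this key is no longer missing
    if EXPECTED_PIECES.issubset(stats.keys()):
        quantized = build_quantized_weight(stats)
        del param_quant_stats[module_name]  # free the staging dict, as the diff does
        return quantized
    return None  # still waiting for more pieces of this module's quant state
```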
47 changes: 26 additions & 21 deletions src/transformers/integrations/tensor_parallel.py
@@ -20,9 +20,13 @@
from functools import partial, reduce
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from ..utils.import_utils import is_torch_available


if is_torch_available():
import torch
import torch.distributed as dist
from torch import nn

from ..distributed import DistributedConfig
from ..utils import is_torch_greater_or_equal, logging
@@ -31,12 +35,12 @@

logger = logging.get_logger(__name__)

# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()

if is_torch_available():
# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()

if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


def initialize_tensor_parallelism(
@@ -169,19 +173,20 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
return None


str_to_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
if is_torch_available():
str_to_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}


def get_packed_weights(param, empty_param, device_mesh, rank, dim):
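As a small, hedged usage sketch of the new guard pattern plus the `str_to_dtype` table: the `resolve_dtype` helper is illustrative, only the guarded import and the mapping itself come from the diff.

```python
from transformers.utils.import_utils import is_torch_available

if is_torch_available():
    import torch

    # Subset of the mapping above: safetensors dtype strings -> torch dtypes.
    str_to_dtype = {
        "BOOL": torch.bool,
        "U8": torch.uint8,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "F32": torch.float32,
    }


def resolve_dtype(safetensors_dtype: str) -> "torch.dtype":
    # Fail with a clear message when torch is missing, instead of failing at import time.
    if not is_torch_available():
        raise ImportError("torch is required to resolve checkpoint dtypes")
    return str_to_dtype[safetensors_dtype]


if is_torch_available():
    print(resolve_dtype("BF16"))  # torch.bfloat16
```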