Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions neural_compressor/torch/algorithms/weight_only/save_load.py
Original file line number Diff line number Diff line change
@@ -873,8 +873,6 @@ def _load_remaining_pretrained_weight(self, model):
resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
torch_dtype = self.kwargs.pop("torch_dtype", torch.float32)
dtype_orig = self.kwargs.pop("dtype_orig", None)
offload_folder = self.kwargs.pop("offload_folder", None)
offload_state_dict = self.kwargs.pop("offload_state_dict", False)

# restore default dtype
if dtype_orig is not None:
@@ -884,26 +882,13 @@ def _load_remaining_pretrained_weight(self, model):
resolved_archive_file = [resolved_archive_file]
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)

params_dict = {
"model": model,
"state_dict": state_dict,
"start_prefix": "",
"expected_keys": self.loaded_state_dict_keys,
"device_map": {"": self.device},
"offload_folder": offload_folder,
"state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None,
"state_dict_index": {} if offload_state_dict else None,
"dtype": torch_dtype,
"keep_in_fp32_modules": [],
}

import transformers

if transformers.__version__ < "4.45.0":
params_dict["loaded_state_dict_keys"] = self.loaded_state_dict_keys

_load_state_dict_into_meta_model(**params_dict)
keys = list(state_dict.keys())
for k in keys:
if k not in self.loaded_state_dict_keys:
state_dict.pop(k)
for k, v in state_dict.items():
state_dict[k] = v.to(torch_dtype)
model.load_state_dict(state_dict, strict=False, assign=True)

# make sure token embedding weights are still tied if needed
model.tie_weights()
Expand Down
Loading