Patch summary (text before the "diff --git" line is ignored by patch tools):
  In _load_remaining_pretrained_weight, stop popping "offload_folder" and
  "offload_state_dict" from self.kwargs, and replace the params_dict /
  _load_state_dict_into_meta_model(**params_dict) call (including the
  transformers.__version__ < "4.45.0" branch) with: drop every shard key that
  is not in self.loaded_state_dict_keys, convert the remaining values with
  .to(torch_dtype), then call
  model.load_state_dict(state_dict, strict=False, assign=True).

diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index b2dec22f584..b3c6051ba4b 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -873,8 +873,6 @@ def _load_remaining_pretrained_weight(self, model):
         resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
         torch_dtype = self.kwargs.pop("torch_dtype", torch.float32)
         dtype_orig = self.kwargs.pop("dtype_orig", None)
-        offload_folder = self.kwargs.pop("offload_folder", None)
-        offload_state_dict = self.kwargs.pop("offload_state_dict", False)
 
         # restore default dtype
         if dtype_orig is not None:
@@ -884,26 +882,13 @@ def _load_remaining_pretrained_weight(self, model):
             resolved_archive_file = [resolved_archive_file]
         for shard_file in resolved_archive_file:
             state_dict = load_state_dict(shard_file)
-
-            params_dict = {
-                "model": model,
-                "state_dict": state_dict,
-                "start_prefix": "",
-                "expected_keys": self.loaded_state_dict_keys,
-                "device_map": {"": self.device},
-                "offload_folder": offload_folder,
-                "state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None,
-                "state_dict_index": {} if offload_state_dict else None,
-                "dtype": torch_dtype,
-                "keep_in_fp32_modules": [],
-            }
-
-            import transformers
-
-            if transformers.__version__ < "4.45.0":
-                params_dict["loaded_state_dict_keys"] = self.loaded_state_dict_keys
-
-            _load_state_dict_into_meta_model(**params_dict)
+            keys = list(state_dict.keys())
+            for k in keys:
+                if k not in self.loaded_state_dict_keys:
+                    state_dict.pop(k)
+            for k, v in state_dict.items():
+                state_dict[k] = v.to(torch_dtype)
+            model.load_state_dict(state_dict, strict=False, assign=True)
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()