diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index 263e4784d92..a5be8cdc519 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -840,6 +840,12 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
             or device_map == torch.device("cpu")
         ) and model.config.model_type == "chatglm":
             model = model.float()
+        if (
+            not torch.cuda.is_available()
+            or device_map == "cpu"
+            or device_map == torch.device("cpu")
+        ) and model.config.model_type == "mpt":
+            model.config.architectures = ["MptForCausalLM"]
         model.eval()

         model_type = model.config.model_type.replace("_", "-")
@@ -1077,6 +1083,7 @@ def calib_func(model):
                 recipes=quantization_config.recipes,
                 example_inputs=example_inputs,
             )
+
             model = quantization.fit(
                 model,
                 conf,