diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 901677a604ddc..cc6d2368536e0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2576,6 +2576,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model, dtype=torch_dtype, low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, **kwargs, ) kwargs["max_memory"] = max_memory