diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index f8db2406017c..1a13ade9d453 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -610,7 +610,7 @@ def convert_and_load_state_dict_in_model(
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
     device_map_regex = re.compile(
-        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )
     dtype_plan = dtype_plan or {}
     weight_mapping = weight_mapping or []