diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 2147a45d7503..aeca68b934a2 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -312,16 +312,16 @@ class ConversionEntry:
 GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4
 
 
-def _materialize_copy(tensor, dtype=None):
+def _materialize_copy(tensor, device=None, dtype=None):
     tensor = tensor[...]
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    if dtype is not None or device is not None:
+        tensor = tensor.to(device=device, dtype=dtype)
     return tensor
 
 
-def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
+def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
     def _job():
-        return _materialize_copy(tensor, dtype)
+        return _materialize_copy(tensor, device, dtype)
 
     return thread_pool.submit(_job)
 
@@ -447,7 +447,10 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
-    device_map = device_map or {}  # {exact_target_key: device}
+    device_map = device_map or {"": "cpu"}  # {exact_target_key: device}
+    device_map_regex = re.compile(
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
+    )
     dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
     meta_model_state_dict = model.state_dict()
@@ -534,7 +537,9 @@ def convert_and_load_state_dict_in_model(
                 )
 
             if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
-                future = spawn_materialize(thread_pool, tensor, _dtype)
+                device_match = device_map_regex.match(first_target_key)
+                param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
             entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
 
     # 2. Actually convert the ckpt
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 80656de2fe90..0d936319e926 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4186,9 +4186,6 @@ def _load_pretrained_model(
         expanded_device_map = expand_device_map(device_map, expected_keys)
         caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
 
-    if device_map is None:
-        device_map = {"": "cpu"}
-    keys = sorted(device_map.keys(), key=len, reverse=True)
     tp_plan = getattr(model, "_tp_plan", None)
     error_msgs = []
 
@@ -4211,33 +4208,18 @@ def _load_pretrained_model(
         missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
     else:
         all_pointer = set()
+        # Checkpoints are safetensors
         if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
-            pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
-            if sharded_metadata is None:
-                k_v_iterator = dict.fromkeys(
-                    safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0].rsplit("/", 1)[1]
-                ).items()
-            else:
-                k_v_iterator = sharded_metadata["weight_map"].items()
-
             merged_state_dict = {}
-            for k, v in k_v_iterator:
-                match = pattern.match(k)
-                if match and match.group(1) != "":
-                    device = device_map[match.group(1)]
-                else:
-                    device = device_map.get("", "cpu")
-                if isinstance(device, torch.device):
-                    device = device.index  # safetensors only
-                if device == "disk":
-                    device = "cpu"  # we read to cpu to then write to disk
-                file_pointer = safe_open(
-                    os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
-                )
+            for file in checkpoint_files:
+                file_pointer = safe_open(file, framework="pt", device="cpu")
                 all_pointer.add(file_pointer)
-                merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+                for k in file_pointer.keys():
+                    merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+        # User passed an explicit state_dict
        elif state_dict is not None:
             merged_state_dict = state_dict
+        # Checkpoints are .bin
         elif checkpoint_files is not None:
             merged_state_dict = {}
             for ckpt_file in checkpoint_files:
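
For readers skimming the patch: the two helpers in core_model_loading.py now take an optional `device` and do the cast plus the host-to-device copy in a single `.to(...)` call on a worker thread. Below is a runnable toy version of the new helpers; the pool size, input tensor, and target dtype are illustrative stand-ins, not values from the PR.

```python
from concurrent.futures import Future, ThreadPoolExecutor

import torch


def _materialize_copy(tensor, device=None, dtype=None):
    tensor = tensor[...]  # read a lazy slice (or copy a tensor) into a real tensor
    if dtype is not None or device is not None:
        tensor = tensor.to(device=device, dtype=dtype)  # one fused move + cast
    return tensor


def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
    def _job():
        return _materialize_copy(tensor, device, dtype)

    return thread_pool.submit(_job)


# Illustrative usage: materialization happens off the main thread.
with ThreadPoolExecutor(max_workers=4) as pool:
    fut = spawn_materialize(pool, torch.ones(2, 2), device="cpu", dtype=torch.float16)
    print(fut.result().dtype)  # torch.float16
```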
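The per-parameter device is resolved by matching each target key against a regex built from the `device_map` keys, sorted so the most-dotted (most specific) module prefix wins. A standalone sketch of that resolution logic; the `device_map` contents and parameter names below are hypothetical:

```python
import re

# Hypothetical device_map: "" is the catch-all, deeper prefixes override it.
device_map = {"": "cpu", "model.layers.0": 0, "model.layers.0.mlp": "cuda:1"}

# Most dots first, so "model.layers.0.mlp" is tried before "model.layers.0".
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
)


def resolve_device(param_name: str):
    match = device_map_regex.match(param_name)
    return device_map[match.group()] if match else device_map.get("", "cpu")


print(resolve_device("model.layers.0.mlp.up_proj.weight"))      # "cuda:1"
print(resolve_device("model.layers.0.self_attn.q_proj.weight"))  # 0
print(resolve_device("lm_head.weight"))                           # "cpu"
```

Note that the `""` key compiles to an empty group that matches any name at position 0, which is what makes `device_map or {"": "cpu"}` a safe default in the patched function.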
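On the modeling_utils.py side, every safetensors shard is now opened on CPU and only lazy `get_slice` handles are stored; bytes are read from disk later, when a worker indexes the slice. A minimal sketch of that pattern, using the real `safe_open`/`get_slice` safetensors API; the shard files and tensor names are hypothetical and created on the fly so the sketch runs:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Hypothetical two-shard checkpoint written locally for the demo.
save_file({"model.embed_tokens.weight": torch.ones(4, 2)}, "shard-00001.safetensors")
save_file({"lm_head.weight": torch.zeros(2, 4)}, "shard-00002.safetensors")
checkpoint_files = ["shard-00001.safetensors", "shard-00002.safetensors"]

merged_state_dict, all_pointer = {}, set()
for file in checkpoint_files:
    file_pointer = safe_open(file, framework="pt", device="cpu")
    all_pointer.add(file_pointer)  # keep handles alive until loading finishes
    for k in file_pointer.keys():
        merged_state_dict[k] = file_pointer.get_slice(k)  # lazy handle, nothing read yet

# Indexing is what finally touches the file; `[...]` reads the whole tensor,
# which is exactly what _materialize_copy does on a worker thread.
print(merged_state_dict["lm_head.weight"][...].shape)  # torch.Size([2, 4])
```

This also explains why the old per-key device logic could be deleted: with slices always opened on CPU, placement is decided once, in `convert_and_load_state_dict_in_model`, at materialization time.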