From 960b08b31e37870ccc8bc65527b50000a88edd8f Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Thu, 9 Nov 2023 18:45:45 +0100
Subject: [PATCH] fix retie

---
 src/accelerate/utils/modeling.py | 15 +++++++++++----
 tests/test_quantization.py       |  2 +-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 058b4c855b1..fe8358d0756 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -535,15 +535,22 @@ def retie_parameters(model, tied_params):
     """
     for tied_group in tied_params:
         param_to_tie = None
-        # First iteration of the loop will set param_to_tie, next ones will tie it to the others
+        # two loops : the first one to set param_to_tie , the second one to change the values of tied_group
         for param_name in tied_group:
             module = model
             splits = param_name.split(".")
             for split in splits[:-1]:
                 module = getattr(module, split)
-            if param_to_tie is None:
-                param_to_tie = getattr(module, splits[-1])
-            else:
+            param = getattr(module, splits[-1])
+            if param_to_tie is None and param.device != torch.device("meta"):
+                param_to_tie = param
+                break
+        if param_to_tie is not None:
+            for param_name in tied_group:
+                module = model
+                splits = param_name.split(".")
+                for split in splits[:-1]:
+                    module = getattr(module, split)
                 setattr(module, splits[-1], param_to_tie)
 
 
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index 86d8b6831b3..92ecbc4789c 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -444,7 +444,7 @@ def test_int8_serialization_offload(self):
             model_8bit_from_saved = load_and_quantize_model(
                 model_8bit_from_saved,
                 bnb_quantization_config,
-                weights_location=tmpdirname + "/pytorch_model.bin",
+                weights_location=tmpdirname,
                 device_map=device_map,
                 no_split_module_classes=["BloomBlock"],
                 offload_folder=tmpdirname + "/tmp",