From bb69451e43fa64ab2412a3773da96ba61ec1b690 Mon Sep 17 00:00:00 2001 From: mosheber Date: Mon, 26 Feb 2024 15:51:39 +0200 Subject: [PATCH] added configurable torch_dtype in model builder --- llava/model/builder.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/llava/model/builder.py b/llava/model/builder.py index e3d50829f..d03d580a5 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -25,6 +25,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): kwargs = {"device_map": device_map, **kwargs} + kwargs["device_map"] = kwargs["device_map"] if kwargs["device_map"] is not None else "auto" if device != "cuda": kwargs['device_map'] = {"": device} @@ -39,8 +40,10 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ) - else: - kwargs['torch_dtype'] = torch.float16 + + kwargs['torch_dtype'] = kwargs.get('torch_dtype',torch.float16) + + torch_dtype_to_use = kwargs['torch_dtype'] if use_flash_attn: kwargs['attn_implementation'] = 'flash_attention_2' @@ -99,7 +102,7 @@ def load_from_hf(repo_id, filename, subfolder=None): model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') - mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} + mm_projector_weights = {k: v.to(torch_dtype_to_use) for k, v in mm_projector_weights.items()} model.load_state_dict(mm_projector_weights, strict=False) else: if 'mpt' in model_name.lower(): @@ -130,8 +133,8 @@ def load_from_hf(repo_id, filename, subfolder=None): model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() - print('Convert to FP16...') - model.to(torch.float16) + print('Convert to specified dtype...') + model.to(torch_dtype_to_use) else: use_fast = False if 'mpt' in model_name.lower(): @@ -156,12 +159,12 @@ def load_from_hf(repo_id, filename, subfolder=None): if not vision_tower.is_loaded: vision_tower.load_model(device_map=device_map) if device_map != 'auto': - vision_tower.to(device=device_map, dtype=torch.float16) + vision_tower.to(device=device_map, dtype=torch_dtype_to_use) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): context_len = model.config.max_sequence_length else: context_len = 2048 - + return tokenizer, model, image_processor, context_len