From 1487a5a6b8f953660d9d944ee6d311774e0e3d53 Mon Sep 17 00:00:00 2001 From: Junlong Li <45759388+lockon-n@users.noreply.github.com> Date: Sat, 9 Mar 2024 02:11:51 +0800 Subject: [PATCH] Update llava_arch.py Fix loading the pretrained mm_projector weights when DeepSpeed ZeRO-3 is enabled: wrap mm_projector.load_state_dict in deepspeed.zero.GatheredParameters (modifier_rank=0) so the projector parameters are gathered before the weights are loaded; the non-ZeRO-3 path is unchanged. --- llava/model/llava_arch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index 7b299d3c4..cc7910f2a 100644 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -25,6 +25,7 @@ from llava.mm_utils import get_anyres_image_grid_shape +from transformers.integrations import is_deepspeed_zero3_enabled class LlavaMetaModel: @@ -94,7 +95,12 @@ def initialize_vision_modules(self, model_args, fsdp=None): def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} - self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + if is_deepspeed_zero3_enabled(): + import deepspeed + with deepspeed.zero.GatheredParameters(self.mm_projector.parameters(), modifier_rank=0): + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + else: + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) def unpad_image(tensor, original_size):