From 5092fd8b8f08a46abadba940609235cbb5c030d9 Mon Sep 17 00:00:00 2001 From: Martin Nguyen Date: Wed, 6 Mar 2024 19:02:16 +0000 Subject: [PATCH 1/6] Add Mixtral --- llava/model/__init__.py | 1 + llava/model/builder.py | 7 + llava/model/language_model/llava_mixtral.py | 158 ++++++++++++++++++++ llava/train/train.py | 2 +- scripts/v1_5/pretrain.sh | 16 +- 5 files changed, 175 insertions(+), 9 deletions(-) create mode 100644 llava/model/language_model/llava_mixtral.py diff --git a/llava/model/__init__.py b/llava/model/__init__.py index dbd91789f..2df4b9887 100644 --- a/llava/model/__init__.py +++ b/llava/model/__init__.py @@ -2,5 +2,6 @@ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig + from .language_model.llava_mixtral import LlavaMixtralForCausalLM, LlavaMixtralConfig except: pass diff --git a/llava/model/builder.py b/llava/model/builder.py index e3d50829f..5facba1ab 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -112,6 +112,13 @@ def load_from_hf(repo_id, filename, subfolder=None): low_cpu_mem_usage=True, **kwargs ) + elif 'mix' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = LlavaMixtralForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM.from_pretrained( diff --git a/llava/model/language_model/llava_mixtral.py b/llava/model/language_model/llava_mixtral.py new file mode 100644 index 000000000..629e58642 --- /dev/null +++ b/llava/model/language_model/llava_mixtral.py @@ -0,0 +1,158 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers import AutoConfig, AutoModelForCausalLM, \ + MixtralConfig, MixtralModel, MixtralForCausalLM \ + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM + + +class LlavaMixtralConfig(MixtralConfig): + model_type = "llava_mixtral" + + +class LlavaMixtralModel(LlavaMetaModel, MixtralModel): + config_class = LlavaMixtralConfig + + def __init__(self, config: MixtralConfig): + super(LlavaMixtralModel, self).__init__(config) + + +class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): + config_class = LlavaMixtralConfig + + def __init__(self, config): + super(MixtralForCausalLM, self).__init__(config) + self.model = LlavaMixtralModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, 
past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("llava_mixtral", LlavaMixtralConfig) +AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) diff --git a/llava/train/train.py b/llava/train/train.py index 477c668b6..4de1cc21f 100644 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -824,7 +824,7 @@ def train(attn_implementation=None): **bnb_model_from_pretrained_args ) else: - model = LlavaLlamaForCausalLM.from_pretrained( + model = LlavaMixtralForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=attn_implementation, diff --git a/scripts/v1_5/pretrain.sh b/scripts/v1_5/pretrain.sh index 9316eaa30..b6ec23303 100644 --- a/scripts/v1_5/pretrain.sh +++ b/scripts/v1_5/pretrain.sh @@ -1,8 +1,8 @@ #!/bin/bash deepspeed llava/train/train_mem.py \ - --deepspeed ./scripts/zero2.json \ - --model_name_or_path lmsys/vicuna-13b-v1.5 \ + --deepspeed ./scripts/zero3_offload.json \ + --model_name_or_path models/ura-hcmut/MixSUra \ --version plain \ --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ --image_folder ./playground/data/LLaVA-Pretrain/images \ @@ -13,23 +13,23 @@ deepspeed llava/train/train_mem.py \ --mm_use_im_start_end False \ --mm_use_im_patch_token False \ --bf16 True \ - --output_dir ./checkpoints/llava-v1.5-13b-pretrain \ + --output_dir ./checkpoints/MixSUraV-pt\ --num_train_epochs 1 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 4 \ - --gradient_accumulation_steps 1 \ + --gradient_accumulation_steps 2 \ --evaluation_strategy "no" \ --save_strategy "steps" \ - --save_steps 24000 \ - --save_total_limit 1 \ + --save_steps 200 \ + --save_total_limit 2 \ --learning_rate 1e-3 \ --weight_decay 0. 
\ --warmup_ratio 0.03 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ - --model_max_length 2048 \ + --model_max_length 32768 \ --gradient_checkpointing True \ --dataloader_num_workers 4 \ --lazy_preprocess True \ - --report_to wandb + --report_to neptune From df93321b5d5ba2ee246f0d5b06cace262fb01cd8 Mon Sep 17 00:00:00 2001 From: Martin Nguyen Date: Fri, 8 Mar 2024 09:26:21 +0000 Subject: [PATCH 2/6] Add Gemma --- llava/conversation.py | 11 ++ llava/model/__init__.py | 1 + llava/model/language_model/llava_gemma.py | 158 ++++++++++++++++++++ llava/model/language_model/llava_mixtral.py | 2 +- llava/train/train.py | 26 +++- scripts/v1_5/finetune_task_lora.sh | 76 ++++++++++ scripts/v1_5/pretrain.sh | 66 ++++++++ 7 files changed, 338 insertions(+), 2 deletions(-) create mode 100644 llava/model/language_model/llava_gemma.py diff --git a/llava/conversation.py b/llava/conversation.py index 00c56867d..44c8e6b9e 100644 --- a/llava/conversation.py +++ b/llava/conversation.py @@ -369,6 +369,16 @@ def dict(self): sep="<|im_end|>", ) +conv_gemma_instruct = Conversation( + system="", + roles=("user\n", "model\n"), + version="gemma", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="\n" +) + default_conversation = conv_vicuna_v1 conv_templates = { "default": conv_vicuna_v0, @@ -377,6 +387,7 @@ def dict(self): "vicuna_v1": conv_vicuna_v1, "llama_2": conv_llama_2, "mistral_instruct": conv_mistral_instruct, + "gemma_instruct": conv_gemma_instruct, "chatml_direct": conv_chatml_direct, "mistral_direct": conv_chatml_direct, diff --git a/llava/model/__init__.py b/llava/model/__init__.py index 2df4b9887..4c9bca648 100644 --- a/llava/model/__init__.py +++ b/llava/model/__init__.py @@ -3,5 +3,6 @@ from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig from .language_model.llava_mixtral import LlavaMixtralForCausalLM, LlavaMixtralConfig + from .language_model.llava_gemma import LlavaGemmaForCausalLM, LlavaGemmaConfig except: pass diff --git a/llava/model/language_model/llava_gemma.py b/llava/model/language_model/llava_gemma.py new file mode 100644 index 000000000..68ba44500 --- /dev/null +++ b/llava/model/language_model/llava_gemma.py @@ -0,0 +1,158 @@ +# Copyright 2024 Duc Q. Nguyen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers import AutoConfig, AutoModelForCausalLM, \ + GemmaConfig, GemmaModel, GemmaForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM + + +class LlavaGemmaConfig(GemmaConfig): + model_type = "llava_gemma" + + +class LlavaGemmaModel(LlavaMetaModel, GemmaModel): + config_class = LlavaGemmaConfig + + def __init__(self, config: GemmaConfig): + super(LlavaGemmaModel, self).__init__(config) + + +class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): + config_class = LlavaGemmaConfig + + def __init__(self, config): + super(GemmaForCausalLM, self).__init__(config) + self.model = LlavaGemmaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, 
**kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("llava_gemma", LlavaGemmaConfig) +AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) diff --git a/llava/model/language_model/llava_mixtral.py b/llava/model/language_model/llava_mixtral.py index 629e58642..48397c331 100644 --- a/llava/model/language_model/llava_mixtral.py +++ b/llava/model/language_model/llava_mixtral.py @@ -1,4 +1,4 @@ -# Copyright 2023 Haotian Liu +# Copyright 2024 Duc Q. Nguyen # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/llava/train/train.py b/llava/train/train.py index 4de1cc21f..155b7db80 100644 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -823,7 +823,15 @@ def train(attn_implementation=None): cache_dir=training_args.cache_dir, **bnb_model_from_pretrained_args ) - else: + elif 'mis' in model_args.model_name_or_path.lower(): + model = LlavaMistralForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'mix' in model_args.model_name_or_path.lower(): model = LlavaMixtralForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, @@ -831,6 +839,22 @@ def train(attn_implementation=None): torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **bnb_model_from_pretrained_args ) + elif 'gem' in model_args.model_name_or_path.lower(): + model = LlavaGemmaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) else: model = transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, diff --git a/scripts/v1_5/finetune_task_lora.sh b/scripts/v1_5/finetune_task_lora.sh index f11303f29..95a007193 100644 --- a/scripts/v1_5/finetune_task_lora.sh +++ b/scripts/v1_5/finetune_task_lora.sh @@ -35,3 +35,79 @@ deepspeed llava/train/train_mem.py \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to wandb + + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero3_offload.json \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --model_name_or_path models/ura-hcmut/MixSUra \ + --version mistral_instruct \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter ./checkpoints/MixSUraV-pt/checkpoint-400/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/MixSUraV-sft \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + 
--evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 200 \ + --save_total_limit 2 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to neptune + + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero3_offload.json \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --model_name_or_path models/ura-hcmut/GemSUra-7B \ + --version gemma_instruct \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter ./checkpoints/GemSUraV-7B-pt/checkpoint-400/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/GemSUraV-7B-sft \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 200 \ + --save_total_limit 2 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 8192 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to neptune \ No newline at end of file diff --git a/scripts/v1_5/pretrain.sh b/scripts/v1_5/pretrain.sh index b6ec23303..ad4c74881 100644 --- a/scripts/v1_5/pretrain.sh +++ b/scripts/v1_5/pretrain.sh @@ -1,5 +1,36 @@ #!/bin/bash +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path lmsys/vicuna-13b-v1.5 \ + --version plain \ + --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ + --image_folder ./playground/data/LLaVA-Pretrain/images \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 1e-3 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb + + deepspeed llava/train/train_mem.py \ --deepspeed ./scripts/zero3_offload.json \ --model_name_or_path models/ura-hcmut/MixSUra \ @@ -33,3 +64,38 @@ deepspeed llava/train/train_mem.py \ --dataloader_num_workers 4 \ --lazy_preprocess True \ --report_to neptune + + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero3_offload.json \ + --model_name_or_path models/ura-hcmut/GemSUra-7B \ + --version plain \ + --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ + --image_folder ./playground/data/LLaVA-Pretrain/images \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --mm_projector_type mlp2x_gelu \ + --tune_mm_mlp_adapter True \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/GemSUraV-7B-pt\ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 200 \ + --save_total_limit 2 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 8192 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to neptune From 1c4b878c52a07b4fed2fab1097350c8478d247e0 Mon Sep 17 00:00:00 2001 From: Martin Nguyen Date: Sat, 9 Mar 2024 19:05:19 +0000 Subject: [PATCH 3/6] Fix for Gemma's preprocessing --- llava/conversation.py | 13 ++++++- llava/train/train.py | 86 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/llava/conversation.py b/llava/conversation.py index 44c8e6b9e..96b74fef1 100644 --- a/llava/conversation.py +++ b/llava/conversation.py @@ -13,6 +13,7 @@ class SeparatorStyle(Enum): MPT = auto() PLAIN = auto() LLAMA_2 = auto() + GEMMA = auto() @dataclasses.dataclass @@ -70,6 +71,16 @@ def get_prompt(self): ret += role + message + self.sep else: ret += role + elif self.sep_style == SeparatorStyle.GEMMA: + ret = "" + for i, (role, message) in enumerate(messages): + assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..." 
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
         elif self.sep_style == SeparatorStyle.LLAMA_2:
             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
@@ -375,7 +386,7 @@ def dict(self):
     version="gemma",
     messages=(),
     offset=0,
-    sep_style=SeparatorStyle.MPT,
+    sep_style=SeparatorStyle.GEMMA,
     sep="\n"
 )
 
diff --git a/llava/train/train.py b/llava/train/train.py
index 155b7db80..16dd6bad9 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -584,7 +584,91 @@ def preprocess_mpt(
         labels=targets,
     )
 
+def preprocess_gemma(
+    sources: List[List[Dict[str, str]]],
+    tokenizer: transformers.PreTrainedTokenizer,
+    has_image: bool = False
+    ) -> Dict:
+    conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy()
+    roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+    # Apply prompt templates
+    conversations: List[str] = []
+    for i, source in enumerate(sources):
+        if roles[source[0]["from"]] != conv.roles[0]:
+            # Skip the first one if it is not from human
+            source: List[Dict[str, str]] = source[1:]
+
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role: str = roles[sentence["from"]]
+            assert role == conv.roles[j % 2], f"{i}"
+            conv.append_message(role, sentence["value"])
+        conversations.append(conv.get_prompt())
+
+    # Tokenize conversations
+    if has_image:
+        input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+    else:
+        input_ids: torch.Tensor = tokenizer(
+            conversations,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ).input_ids
+
+    targets: torch.Tensor = input_ids.clone()
+    assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
+
+    # Mask target
+    sep: str = conv.sep + conv.roles[1]
+    for conversation, target in zip(conversations, targets):
+        total_len: int = int(target.ne(tokenizer.pad_token_id).sum())
+
+        rounds: List[str] = conversation.split(conv.sep)
+        re_rounds = []
+        for conv_idx in range(0, len(rounds), 2):
+            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))
+
+        cur_len = 1 # Ignore <bos>
+        target[:cur_len] = IGNORE_INDEX
+        for i, rou in enumerate(re_rounds):
+            if rou == "":
+                break
+
+            parts = rou.split(sep)
+            if len(parts) != 2:
+                break
+            parts[0] += sep # Re-append sep because split on this
+            # Now "".join(parts)==rou
+
+            if has_image:
+                round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore <bos>
+                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore <bos>
+            else:
+                round_len = len(tokenizer(rou).input_ids) - 1 # Ignore <bos>
+                instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore <bos>
+
+            round_len += 2 # sep: \n takes 2 tokens
+            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+            cur_len += round_len
+
+        target[cur_len:] = IGNORE_INDEX
+
+        if cur_len < tokenizer.model_max_length:
+            if cur_len != total_len:
+                target[:] = IGNORE_INDEX
+                print(
+                    f"warning: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + def preprocess_plain( sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, @@ -627,6 +711,8 @@ def preprocess( return preprocess_v1(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version == "mpt": return preprocess_mpt(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "gemma": + return preprocess_gemma(sources, tokenizer, has_image=has_image) # add end signal and concatenate together conversations = [] for source in sources: From f9aaa1093a5c03d8e6774dcb80c9cd906a976825 Mon Sep 17 00:00:00 2001 From: Martin Nguyen Date: Sun, 10 Mar 2024 16:36:35 +0000 Subject: [PATCH 4/6] Fix llava_gemma loading --- llava/model/builder.py | 88 +++++++++++++++++++++-- llava/model/language_model/llava_gemma.py | 6 +- 2 files changed, 86 insertions(+), 8 deletions(-) diff --git a/llava/model/builder.py b/llava/model/builder.py index 5facba1ab..7b0b5f3fc 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -50,11 +50,53 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l if 'lora' in model_name.lower() and model_base is None: warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') if 'lora' in model_name.lower() and model_base is not None: - from llava.model.language_model.llava_llama import LlavaConfig - lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) print('Loading LLaVA from base model...') - model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + if 'mpt' in model_name.lower(): + from llava.model.language_model.llava_mpt import LlavaMptConfig + lora_cfg_pretrained = LlavaMptConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) + model = LlavaMptForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=lora_cfg_pretrained, + **kwargs + ) + elif 'mistral' in model_name.lower(): + from llava.model.language_model.llava_mistral import LlavaMistralConfig + lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = LlavaMistralForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=lora_cfg_pretrained, + **kwargs + ) + elif 'mix' in model_name.lower(): + from llava.model.language_model.llava_mixtral import LlavaMixtralConfig + lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = LlavaMixtralForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=lora_cfg_pretrained, + **kwargs + ) + elif 'gem' in model_name.lower(): + from llava.model.language_model.llava_gemma import LlavaGemmaConfig + lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = LlavaGemmaForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=lora_cfg_pretrained, + **kwargs + ) + else: + from llava.model.language_model.llava_llama import LlavaConfig + 
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) @@ -93,6 +135,33 @@ def load_from_hf(repo_id, filename, subfolder=None): tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif 'mistral' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlavaMistralForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=cfg_pretrained, + **kwargs + ) + elif 'mix' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlavaMixtralForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=cfg_pretrained, + **kwargs + ) + elif 'gem' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlavaGemmaForCausalLM.from_pretrained( + model_base, + low_cpu_mem_usage=True, + config=cfg_pretrained, + **kwargs + ) else: tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) @@ -106,19 +175,26 @@ def load_from_hf(repo_id, filename, subfolder=None): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) elif 'mistral' in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaMistralForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) elif 'mix' in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaMixtralForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + elif 'gem' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = LlavaGemmaForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = LlavaLlamaForCausalLM.from_pretrained( diff --git a/llava/model/language_model/llava_gemma.py b/llava/model/language_model/llava_gemma.py index 68ba44500..ece9c3e7e 100644 --- a/llava/model/language_model/llava_gemma.py +++ b/llava/model/language_model/llava_gemma.py @@ -47,7 +47,7 @@ def __init__(self, config): self.model = LlavaGemmaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + # Initialize weights and apply final processing self.post_init() @@ -68,6 +68,7 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: 
Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: @@ -98,7 +99,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict + return_dict=return_dict, + cache_position=cache_position ) @torch.no_grad() From 0e8f3a66cf7e1a07d8826e28b73e848b44818acb Mon Sep 17 00:00:00 2001 From: Martin Nguyen Date: Mon, 11 Mar 2024 08:34:40 +0000 Subject: [PATCH 5/6] Fix bug for deployment --- llava/model/builder.py | 2 +- llava/serve/cli.py | 8 ++++--- llava/serve/gradio_web_server.py | 30 ++++++++++++++++---------- llava/serve/model_worker.py | 13 ++++++++---- scripts/upload_model.py | 36 ++++++++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 scripts/upload_model.py diff --git a/llava/model/builder.py b/llava/model/builder.py index 7b0b5f3fc..f0ccba3f0 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -45,7 +45,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l if use_flash_attn: kwargs['attn_implementation'] = 'flash_attention_2' - if 'llava' in model_name.lower(): + if 'llava' in model_name.lower() or 'surav' in model_name.lower(): # Load LLaVA model if 'lora' in model_name.lower() and model_base is None: warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') diff --git a/llava/serve/cli.py b/llava/serve/cli.py index 564eaa5ab..5c2f9e328 100644 --- a/llava/serve/cli.py +++ b/llava/serve/cli.py @@ -30,11 +30,13 @@ def main(args): model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) - + if "llama-2" in model_name.lower(): conv_mode = "llava_llama_2" - elif "mistral" in model_name.lower(): + elif "mistral" in model_name.lower() or "mix" in model_name.lower(): conv_mode = "mistral_instruct" + elif "gem" in model_name.lower(): + conv_mode = "gemma_instruct" elif "v1.6-34b" in model_name.lower(): conv_mode = "chatml_direct" elif "v1" in model_name.lower(): @@ -92,7 +94,7 @@ def main(args): input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] - streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) with torch.inference_mode(): output_ids = model.generate( diff --git a/llava/serve/gradio_web_server.py b/llava/serve/gradio_web_server.py index c07efc122..48fa61ad9 100644 --- a/llava/serve/gradio_web_server.py +++ b/llava/serve/gradio_web_server.py @@ -151,7 +151,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request): return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 -def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): +def http_bot(state, model_selector, temperature, top_p, top_k, max_new_tokens, request: gr.Request): logger.info(f"http_bot. 
ip: {request.client.host}") start_tstamp = time.time() model_name = model_selector @@ -166,13 +166,15 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: if "llava" in model_name.lower(): if 'llama-2' in model_name.lower(): template_name = "llava_llama_2" - elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + elif "mistral" in model_name.lower() or "mix" in model_name.lower(): if 'orca' in model_name.lower(): template_name = "mistral_orca" elif 'hermes' in model_name.lower(): template_name = "chatml_direct" else: template_name = "mistral_instruct" + elif "gem" in model_name.lower(): + template_name = "gemma_instruct" elif 'llava-v1.6-34b' in model_name.lower(): template_name = "chatml_direct" elif "v1" in model_name.lower(): @@ -233,8 +235,10 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p), + "top_k": float(top_k), + "repetition_penalty": 1.0, "max_new_tokens": min(int(max_new_tokens), 1536), - "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT, SeparatorStyle.GEMMA] else state.sep2, "images": f'List of {len(state.get_images())} images: {all_image_hash}', } logger.info(f"==== request ====\n{pload}") @@ -285,9 +289,12 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: } fout.write(json.dumps(data) + "\n") +# title_markdown = (""" +# # 🌋 LLaVA: Large Language and Vision Assistant +# [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)] +# """) title_markdown = (""" -# 🌋 LLaVA: Large Language and Vision Assistant -[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)] +# URA x Vision: The new multimodal Large Language Models for Vietnamese """) tos_markdown = (""" @@ -314,7 +321,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: def build_demo(embed_mode, cur_dir=None, concurrency_count=10): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) - with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: + with gr.Blocks(title="URA x Vision", theme=gr.themes.Default(), css=block_css) as demo: state = gr.State() if not embed_mode: @@ -345,8 +352,9 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): with gr.Accordion("Parameters", open=False) as parameter_row: temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) - top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) - max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, interactive=False, label="Top P",) + top_k = 
gr.Slider(minimum=1, maximum=50, value=50, step=1, interactive=False, label="Top K",) + max_output_tokens = gr.Slider(minimum=0, maximum=2048, value=512, step=64, interactive=True, label="Max output tokens",) with gr.Column(scale=8): chatbot = gr.Chatbot( @@ -397,7 +405,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, - [state, model_selector, temperature, top_p, max_output_tokens], + [state, model_selector, temperature, top_p, top_k, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) @@ -416,7 +424,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): queue=False ).then( http_bot, - [state, model_selector, temperature, top_p, max_output_tokens], + [state, model_selector, temperature, top_p, top_k, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) @@ -427,7 +435,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, - [state, model_selector, temperature, top_p, max_output_tokens], + [state, model_selector, temperature, top_p, top_k, max_output_tokens], [state, chatbot] + btn_list, concurrency_limit=concurrency_count ) diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py index 914432989..43aedee64 100644 --- a/llava/serve/model_worker.py +++ b/llava/serve/model_worker.py @@ -157,15 +157,15 @@ def generate_stream(self, params): temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) max_context_length = getattr(model.config, 'max_position_embeddings', 2048) max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) - stop_str = params.get("stop", None) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + stop_str = params.get("stop", "") do_sample = True if temperature > 0.001 else False input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) - keywords = [stop_str] - # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) - streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=15) max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) @@ -178,6 +178,8 @@ def generate_stream(self, params): do_sample=do_sample, temperature=temperature, top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, max_new_tokens=max_new_tokens, streamer=streamer, use_cache=True, @@ -190,6 +192,9 @@ def generate_stream(self, params): generated_text += new_text if generated_text.endswith(stop_str): generated_text = generated_text[:-len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + break + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" def generate_stream_gate(self, params): diff --git a/scripts/upload_model.py b/scripts/upload_model.py new file mode 100644 index 000000000..6e371991b --- /dev/null +++ b/scripts/upload_model.py @@ -0,0 +1,36 @@ +import argparse +from llava.model.builder import load_pretrained_model +from llava.mm_utils import get_model_name_from_path + + +def upload(args): + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = 
load_pretrained_model(args.model_path, None, model_name, device_map='cpu') + + if args.export_hub_model_id is not None: + model.push_to_hub( + args.export_hub_model_id, + token=args.hf_hub_token, + max_shard_size="{}GB".format(args.export_size), + safe_serialization=(not args.export_legacy_format), + ) + + try: + tokenizer.padding_side = "left" # restore padding side + tokenizer.init_kwargs["padding_side"] = "left" + if args.export_hub_model_id is not None: + tokenizer.push_to_hub(args.export_hub_model_id, token=args.hf_hub_token) + except Exception: + print("Cannot save tokenizer, please copy the files manually.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--export_hub_model_id", type=str, required=True) + parser.add_argument("--hf_hub_token", type=str, required=True) + parser.add_argument("--export_size", type=int, default=5) + parser.add_argument("--export_legacy_format", type=bool, default=False) + args = parser.parse_args() + + upload(args) From 02b89a6cc3702666775d4d1470f617e63e09544b Mon Sep 17 00:00:00 2001 From: sangttruong Date: Tue, 19 Mar 2024 06:51:43 -0700 Subject: [PATCH 6/6] Fix deployment --- llava/model/builder.py | 2 +- llava/serve/model_worker.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llava/model/builder.py b/llava/model/builder.py index f0ccba3f0..a3c7538e9 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -226,7 +226,7 @@ def load_from_hf(repo_id, filename, subfolder=None): image_processor = None - if 'llava' in model_name.lower(): + if 'llava' in model_name.lower() or 'surav' in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py index 43aedee64..bc40198ae 100644 --- a/llava/serve/model_worker.py +++ b/llava/serve/model_worker.py @@ -64,7 +64,7 @@ def __init__(self, controller_addr, worker_addr, logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn) - self.is_multimodal = 'llava' in self.model_name.lower() + self.is_multimodal = 'llava' in self.model_name.lower() or 'surav' in self.model_name.lower() if not no_register: self.register_to_controller() @@ -127,6 +127,7 @@ def generate_stream(self, params): ori_prompt = prompt images = params.get("images", None) num_image_tokens = 0 + if images is not None and len(images) > 0 and self.is_multimodal: if len(images) > 0: if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
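
A minimal single-image inference sketch against the classes and templates these patches add, mirroring the flow in llava/serve/cli.py. The checkpoint path, image file, and question below are placeholders: the directory name is assumed to contain "gem" plus "surav" (or "llava") so that the patched load_pretrained_model dispatches to LlavaGemmaForCausalLM and treats the checkpoint as multimodal; dtype and device handling are left at the builder defaults.

import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

# Placeholder checkpoint: a merged GemSUraV directory; the "gem" + "surav" substrings
# route loading to LlavaGemmaForCausalLM in the patched builder.
model_path = "./checkpoints/GemSUraV-7B-sft-merged"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, device="cuda")

# Build a single-turn prompt with the gemma_instruct template registered above.
conv = conv_templates["gemma_instruct"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Preprocess the image (a tensor for pad-style configs) and splice the image token id.
image = Image.open("example.jpg").convert("RGB")
image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=True,
        temperature=0.2,
        top_p=1.0,
        top_k=50,
        max_new_tokens=512,
        use_cache=True,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True).strip())

The Mixtral checkpoints follow the same flow, with the mistral_instruct template and a directory name containing "mix".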