Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Adding support for Mixtral and Gemma models #1247

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions llava/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class SeparatorStyle(Enum):
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
GEMMA = auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -70,6 +71,16 @@ def get_prompt(self):
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.GEMMA:
ret = ""
for i, (role, message) in enumerate(messages):
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
Expand Down Expand Up @@ -369,6 +380,16 @@ def dict(self):
sep="<|im_end|>",
)

conv_gemma_instruct = Conversation(
system="",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
version="gemma",
messages=(),
offset=0,
sep_style=SeparatorStyle.GEMMA,
sep="<end_of_turn>\n"
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
Expand All @@ -377,6 +398,7 @@ def dict(self):
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"gemma_instruct": conv_gemma_instruct,
"chatml_direct": conv_chatml_direct,
"mistral_direct": conv_chatml_direct,

Expand Down
2 changes: 2 additions & 0 deletions llava/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
from .language_model.llava_mixtral import LlavaMixtralForCausalLM, LlavaMixtralConfig
from .language_model.llava_gemma import LlavaGemmaForCausalLM, LlavaGemmaConfig
except:
pass
97 changes: 90 additions & 7 deletions llava/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,58 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
if use_flash_attn:
kwargs['attn_implementation'] = 'flash_attention_2'

if 'llava' in model_name.lower():
if 'llava' in model_name.lower() or 'surav' in model_name.lower():
# Load LLaVA model
if 'lora' in model_name.lower() and model_base is None:
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
if 'lora' in model_name.lower() and model_base is not None:
from llava.model.language_model.llava_llama import LlavaConfig
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
print('Loading LLaVA from base model...')
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
if 'mpt' in model_name.lower():
from llava.model.language_model.llava_mpt import LlavaMptConfig
lora_cfg_pretrained = LlavaMptConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
model = LlavaMptForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=lora_cfg_pretrained,
**kwargs
)
elif 'mistral' in model_name.lower():
from llava.model.language_model.llava_mistral import LlavaMistralConfig
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaMistralForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=lora_cfg_pretrained,
**kwargs
)
elif 'mix' in model_name.lower():
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaMixtralForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=lora_cfg_pretrained,
**kwargs
)
elif 'gem' in model_name.lower():
from llava.model.language_model.llava_gemma import LlavaGemmaConfig
lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaGemmaForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=lora_cfg_pretrained,
**kwargs
)
else:
from llava.model.language_model.llava_llama import LlavaConfig
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)

token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
if model.lm_head.weight.shape[0] != token_num:
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
Expand Down Expand Up @@ -93,6 +135,33 @@ def load_from_hf(repo_id, filename, subfolder=None):
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
elif 'mistral' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=cfg_pretrained,
**kwargs
)
elif 'mix' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaMixtralForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=cfg_pretrained,
**kwargs
)
elif 'gem' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaGemmaForCausalLM.from_pretrained(
model_base,
low_cpu_mem_usage=True,
config=cfg_pretrained,
**kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
Expand All @@ -106,12 +175,26 @@ def load_from_hf(repo_id, filename, subfolder=None):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
elif 'mistral' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaMistralForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs
)
elif 'mix' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaMixtralForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs
)
elif 'gem' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaGemmaForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(
Expand Down Expand Up @@ -143,7 +226,7 @@ def load_from_hf(repo_id, filename, subfolder=None):

image_processor = None

if 'llava' in model_name.lower():
if 'llava' in model_name.lower() or 'surav' in model_name.lower():
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
Expand Down
160 changes: 160 additions & 0 deletions llava/model/language_model/llava_gemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 Duc Q. Nguyen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
GemmaConfig, GemmaModel, GemmaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaGemmaConfig(GemmaConfig):
model_type = "llava_gemma"


class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
config_class = LlavaGemmaConfig

def __init__(self, config: GemmaConfig):
super(LlavaGemmaModel, self).__init__(config)


class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaGemmaConfig

def __init__(self, config):
super(GemmaForCausalLM, self).__init__(config)
self.model = LlavaGemmaModel(config)

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_model(self):
return self.model

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
)

return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position
)

@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")

if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)

return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs

AutoConfig.register("llava_gemma", LlavaGemmaConfig)
AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
Loading