Commit

[NeuralChat] support llama series model for llava finetuning. (#948)
lkk12014402 committed Jan 15, 2024
1 parent a09f92d commit d753cb8
Showing 4 changed files with 164 additions and 17 deletions.
@@ -3,6 +3,14 @@

Large Language and Vision Assistant (LLaVA) is a multi-modal training framework proposed in [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) and [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/abs/2310.03744). This example demonstrates how to train a multi-modal model on Intel Gaudi2.

## Validated Model List
| Pretrained model | LLaVA |
|------------------------------------|---|
| Mistral series | ✅ |
| LLaMA series | ✅ |

**Note:** Salesforce/codegen25-7b-* series models share the LLaMA architecture; to fine-tune them, install `transformers==4.33.2` (`pip install transformers==4.33.2`), as described in [this issue](https://github.com/salesforce/CodeGen/issues/82).

## Train

LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
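
The parameter-freezing scheme behind these two stages can be summarized with the short PyTorch sketch below. This is illustrative only, not part of the commit; the `get_model()`, `mm_projector`, and `vision_tower` attribute names follow the common LLaVA convention and are assumptions here rather than guaranteed APIs of this example.

```python
# Illustrative sketch of the two-stage freezing scheme described above.
# Attribute names (get_model(), mm_projector, vision_tower) are assumptions
# based on the usual LLaVA layout, not guaranteed APIs of this repository.
import torch.nn as nn


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    for param in module.parameters():
        param.requires_grad = flag


def configure_stage(model, stage: int) -> None:
    if stage == 1:
        # Stage 1 (feature alignment): freeze everything, train only the
        # projector that maps vision features into the LLM embedding space.
        set_requires_grad(model, False)
        set_requires_grad(model.get_model().mm_projector, True)
    else:
        # Stage 2 (visual instruction tuning): unfreeze the LLM (and projector)
        # while keeping the vision encoder frozen.
        set_requires_grad(model, True)
        set_requires_grad(model.get_model().vision_tower, False)
```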
@@ -25,9 +25,8 @@

import transformers

from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig
from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig, AutoConfig
from transformers.integrations.deepspeed import is_deepspeed_available
from intel_extension_for_transformers.transformers.modeling.llava_models import LlavaMistralForCausalLM
from llava_utils import *

if is_hpu_available:
@@ -133,19 +132,46 @@ def train():
low_cpu_mem_usage = False
device_map = None


model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
trust_remote_code=model_args.trust_remote_code,
use_auth_token=model_args.use_auth_token
)
config_kwargs = {
"cache_dir": training_args.cache_dir,
"trust_remote_code": model_args.trust_remote_code,
}
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)

use_fast = True
if config.architectures[0] == "LlamaForCausalLM":
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_llama \
import LlavaLlamaForCausalLM
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
trust_remote_code=model_args.trust_remote_code,
use_auth_token=model_args.use_auth_token
)
use_fast = False
elif config.architectures[0] == "MistralForCausalLM":
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_mistral \
import LlavaMistralForCausalLM
model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
trust_remote_code=model_args.trust_remote_code,
use_auth_token=model_args.use_auth_token
)
else:
raise ValueError("No LLaVA implementation for the model {}".format(model_args.model_name_or_path))

# for training
model.config.use_cache = False
@@ -189,7 +215,8 @@ def make_inputs_require_grad(module, input, output):
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
# use_fast=False
trust_remote_code=model_args.trust_remote_code,
use_fast=use_fast
)

tokenizer.pad_token = tokenizer.eos_token
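
To make the new loading logic easier to follow, the snippet below is a standalone sketch (not part of the commit) of how a checkpoint's `config.architectures` entry selects the LLaVA wrapper class; the example model name at the end is only a placeholder.

```python
# Standalone sketch of the architecture-based dispatch used above (illustrative only).
from transformers import AutoConfig


def resolve_llava_class_name(model_name_or_path: str) -> str:
    config = AutoConfig.from_pretrained(model_name_or_path)
    arch = config.architectures[0]
    if arch == "LlamaForCausalLM":
        return "LlavaLlamaForCausalLM"   # this path also switches to the slow tokenizer (use_fast=False)
    if arch == "MistralForCausalLM":
        return "LlavaMistralForCausalLM"
    raise ValueError("No LLaVA implementation for the model {}".format(model_name_or_path))


# e.g. resolve_llava_class_name("meta-llama/Llama-2-7b-hf") would return
# "LlavaLlamaForCausalLM" (placeholder checkpoint name).
```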
@@ -15,4 +15,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .llava_mistral import LlavaMistralForCausalLM
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
config_class = LlavaConfig

def __init__(self, config: LlamaConfig):
super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaConfig

def __init__(self, config):
super(LlavaLlamaForCausalLM, self).__init__(config)
self.model = LlavaLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_model(self):
return self.model

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images
)

# pylint: disable=E1101
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
# pylint: disable=E1101
_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
_inputs['images'] = images
return _inputs
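
For context, a minimal usage sketch of the new wrapper follows. It is hypothetical and not part of the commit: the checkpoint path is a placeholder, the prompt is simplified (a real LLaVA prompt contains the image placeholder tokens expected by `prepare_inputs_labels_for_multimodal`), and a trained multimodal checkpoint with a configured vision tower is assumed.

```python
# Minimal usage sketch (illustrative only; checkpoint path and preprocessing are placeholders).
import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_llama import (
    LlavaLlamaForCausalLM,
)

ckpt = "path/to/llava-llama-checkpoint"  # placeholder; assumes a trained multimodal checkpoint
model = LlavaLlamaForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)

# Simplified prompt; a real prompt would include the image placeholder token(s).
input_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
images = torch.zeros(1, 3, 336, 336, dtype=torch.bfloat16)  # dummy image tensor

# `images` is forwarded through prepare_inputs_for_generation (defined above)
# into forward(), where the multimodal inputs are assembled.
output_ids = model.generate(input_ids, images=images, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```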
