In [4]:
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTTrainer

In [19]:
# Optional: uncomment to log in to the Hugging Face Hub from inside the notebook.
# Credentials are needed because this notebook sets push_to_hub=True and later
# calls trainer.push_to_hub(); skip only if a token is already configured.
# from huggingface_hub import notebook_login
# notebook_login()

In [20]:
# Hub checkpoint id; the model, tokenizer and processor below are all loaded from it.
model_id = "llava-hf/llava-1.5-7b-hf"

In [21]:
# Load the model weights in 4-bit precision.
# The compute dtype is pinned to float16 so the de-quantized matmuls run in the
# same precision the model is loaded with (torch_dtype=torch.float16) and trained
# in (fp16=True). Left unset it defaults to float32, which slows down training
# when the layer inputs are half precision.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

In [22]:
# Instantiate LLaVA from the Hub checkpoint, applying the quantization settings
# defined above and keeping the non-quantized parameters in float16.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
# Jinja chat template for LLaVA conversations. Adjacent string literals are
# concatenated by Python, so the resulting template is a single line:
#   - a fixed system preamble,
#   - "USER: " / "ASSISTANT: " prefixes per message role,
#   - text items rendered verbatim and image items rendered as "<image>",
#   - a trailing space after user turns and the EOS token after assistant turns.
LLAVA_CHAT_TEMPLATE = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
    "{% for item in message['content'] %}"
    "{% if item['type'] == 'text' %}{{ item['text'] }}"
    "{% elif item['type'] == 'image' %}<image>{% endif %}"
    "{% endfor %}"
    "{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}"
    "{% endfor %}"
)

In [24]:
# Load the processor and a tokenizer from the same checkpoint, then make the
# processor use the tokenizer that carries the custom LLaVA chat template.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor.tokenizer = tokenizer

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [25]:
class LLavaDataCollator:
    """Collate chat-formatted image/text examples into one LLaVA training batch.

    Every example must provide a ``"messages"`` list, which is rendered to a
    prompt string with the tokenizer's chat template, and an ``"images"``
    sequence holding exactly one image.
    """

    def __init__(self, processor):
        """Keep the processor; its ``tokenizer`` attribute renders the chat template."""
        self.processor = processor

    def __call__(self, examples):
        """Turn a list of dataset examples into processor outputs plus labels.

        Args:
            examples: iterable of mappings with ``"messages"`` and ``"images"`` keys.

        Returns:
            The batch returned by the processor, with an added ``"labels"`` entry
            that copies ``input_ids`` and replaces padding positions with -100.

        Raises:
            ValueError: if an example does not carry exactly one image. Previously
                extra images were silently dropped and an empty list surfaced as
                an opaque ``IndexError``.
        """
        texts = []
        images = []
        for example in examples:
            example_images = example["images"]
            if len(example_images) != 1:
                raise ValueError(
                    "LLavaDataCollator supports exactly one image per example, "
                    f"got {len(example_images)}"
                )
            texts.append(
                self.processor.tokenizer.apply_chat_template(
                    example["messages"], tokenize=False, add_generation_prompt=False
                )
            )
            images.append(example_images[0])

        # Pass text and images by keyword so the call does not depend on the
        # positional parameter order of the processor's __call__.
        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        # Labels mirror the inputs; padding positions are set to -100.
        labels = batch["input_ids"].clone()
        pad_token_id = self.processor.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        batch["labels"] = labels

        return batch

# Collator instance handed to the trainer; it wraps the template-aware processor.
data_collator = LLavaDataCollator(processor=processor)

In [26]:
# Fetch the instruction-tuning mix from the Hub and pick the splits used for
# training and evaluation. `load_dataset` is imported in the notebook's import cell.
raw_datasets = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

In [27]:
# Show the first training example to inspect the "messages" structure
# (role plus a list of text/image content items) that the collator consumes.
train_dataset[0]

{'messages': [{'content': [{'index': None,
     'text': 'Who wrote this book?\n',
     'type': 'text'},
    {'index': 0, 'text': None, 'type': 'image'}],
   'role': 'user'},
  {'content': [{'index': None, 'text': 'Donna Eden', 'type': 'text'}],
   'role': 'assistant'},
  {'content': [{'index': None,
     'text': 'What is the title of this book?',
     'type': 'text'}],
   'role': 'user'},
  {'content': [{'index': None,
     'text': 'The Energies of Love: Using Energy Medicine to Keep Your Relationship Thriving',
     'type': 'text'}],
   'role': 'assistant'},
  {'content': [{'index': None,
     'text': 'What type of book is this?',
     'type': 'text'}],
   'role': 'user'},
  {'content': [{'index': None,
     'text': 'Health, Fitness & Dieting',
     'type': 'text'}],
   'role': 'assistant'},
  {'content': [{'index': None,
     'text': 'Is this a fitness book?',
     'type': 'text'}],
   'role': 'user'},
  {'content': [{'index': None, 'text': 'Yes', 'type': 'text'}],
   'role': 'assist

In [29]:
training_args = TrainingArguments(
    # Output location, Hub upload and logging.
    output_dir="llava-1.5-7b-hf-ft-mix-vsft",
    push_to_hub=True,
    report_to="tensorboard",
    logging_steps=5,
    # Optimisation schedule.
    num_train_epochs=1,
    learning_rate=1.4e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    # Memory and precision settings.
    gradient_checkpointing=True,
    fp16=True,
    bf16=False,
    # Keep the raw "messages"/"images" columns: the custom collator reads them.
    remove_unused_columns=False,
)

In [30]:
# LoRA adapter settings: rank-64 adapters with alpha 16, using the
# "all-linear" target-module shortcut.
lora_config = LoraConfig(
    target_modules="all-linear",
    lora_alpha=16,
    r=64,
)

In [31]:
# Wire the quantized model, LoRA config, datasets and custom collator into the
# supervised fine-tuning trainer.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=lora_config,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Dataset preparation is skipped, so the text field below is only a dummy value.
    dataset_kwargs={"skip_prepare_dataset": True},
    dataset_text_field="text",
)

In [6]:
%load_ext tensorboard
%tensorboard --logdir="/llavaexp"

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 28796), started 0:20:28 ago. (Use '!kill 28796' to kill it.)

### Start the training!

In [33]:
# Run the fine-tuning loop with the trainer configured above.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


KeyboardInterrupt: 

### Push the model to the HF Hub

In [None]:
# Upload the trained model to the Hugging Face Hub.
# NOTE(review): presumably requires Hub credentials (see the commented-out
# notebook_login cell near the top) — confirm before running.
trainer.push_to_hub()