In [2]:
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig



In [18]:
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""


In [19]:

model_id = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer = tokenizer

model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [20]:
class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            messages = example["messages"]
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

data_collator = LLavaDataCollator(processor)


In [21]:
raw_datasets = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]


README.md:   0%|          | 0.00/868 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/20 [00:00<?, ?files/s]

train-00000-of-00020.parquet:   0%|          | 0.00/539M [00:00<?, ?B/s]

train-00001-of-00020.parquet:   0%|          | 0.00/547M [00:00<?, ?B/s]

train-00002-of-00020.parquet:   0%|          | 0.00/540M [00:00<?, ?B/s]

train-00003-of-00020.parquet:   0%|          | 0.00/542M [00:00<?, ?B/s]

train-00004-of-00020.parquet:   0%|          | 0.00/541M [00:00<?, ?B/s]

train-00005-of-00020.parquet:   0%|          | 0.00/541M [00:00<?, ?B/s]

train-00006-of-00020.parquet:   0%|          | 0.00/539M [00:00<?, ?B/s]

train-00007-of-00020.parquet:   0%|          | 0.00/540M [00:00<?, ?B/s]

train-00008-of-00020.parquet:   0%|          | 0.00/540M [00:00<?, ?B/s]

train-00009-of-00020.parquet:   0%|          | 0.00/537M [00:00<?, ?B/s]

train-00010-of-00020.parquet:   0%|          | 0.00/537M [00:00<?, ?B/s]

train-00011-of-00020.parquet:   0%|          | 0.00/544M [00:00<?, ?B/s]

train-00012-of-00020.parquet:   0%|          | 0.00/549M [00:00<?, ?B/s]

train-00013-of-00020.parquet:   0%|          | 0.00/543M [00:00<?, ?B/s]

train-00014-of-00020.parquet:   0%|          | 0.00/543M [00:00<?, ?B/s]

train-00015-of-00020.parquet:   0%|          | 0.00/547M [00:00<?, ?B/s]

train-00016-of-00020.parquet:   0%|          | 0.00/541M [00:00<?, ?B/s]

train-00017-of-00020.parquet:   0%|          | 0.00/541M [00:00<?, ?B/s]

train-00018-of-00020.parquet:   0%|          | 0.00/547M [00:00<?, ?B/s]

train-00019-of-00020.parquet:   0%|          | 0.00/540M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/285M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/284M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/259155 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/13640 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

In [35]:
train_dataset[12]

{'messages': [{'content': [{'index': 0, 'text': None, 'type': 'image'},
    {'index': None,
     'text': '\nWhat may be the purpose of this gathering in the field?',
     'type': 'text'}],
   'role': 'user'},
  {'content': [{'index': None,
     'text': 'The purpose of this gathering in the field is for a group of people to enjoy flying a giant lizard-shaped kite together. In the image, there are several individuals, along with the large kite dominating the scene. The lush green field provides an ample space for kite flying activities, allowing the participants to run and maneuver the kite in the open area. This event brings people together for recreational purposes, where they can bond and have fun while engaging in an outdoor activity.',
     'type': 'text'}],
   'role': 'assistant'}],
 'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=480x640>]}

In [31]:
args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        #max_steps = 30,
        num_train_epochs = 1, # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        #fp16 = not is_bf16_supported(),
        #bf16 = is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",     # For Weights and Biases

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "text",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        # max_seq_length = 2048,
    )

In [None]:

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    #dataset_text_field="text",  # need a dummy field
    tokenizer=tokenizer,
    #data_collator=data_collator,
    #dataset_kwargs={"skip_prepare_dataset": True},
)




  trainer = SFTTrainer(


In [37]:
trainer.train()

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['messages', 'images']