## Set up the environment

Mount Google Drive and install the libraries needed to fine-tune vision-language models (VLMs).


In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset loaded further down is read
# from /content/drive/MyDrive/large/.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%pip install transformers peft bitsandbytes accelerate scikit-learn

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [3]:
# bitsandbytes is already installed by the previous cell; -U additionally
# upgrades it to the newest release available at run time.
# NOTE(review): this makes the environment depend on the day the cell is run —
# consider dropping this cell in favour of a single pinned install.
%pip install -U bitsandbytes



In [4]:
import transformers
import peft
import bitsandbytes
import accelerate
import sklearn

# Report the installed version of every core dependency, one line per package.
for display_name, module in (
    ("Transformers", transformers),
    ("PEFT", peft),
    ("Bitsandbytes", bitsandbytes),
    ("Accelerate", accelerate),
    ("sklearn", sklearn),
):
    print(f"{display_name} version: {module.__version__}")

Transformers version: 4.57.2
PEFT version: 0.18.0
Bitsandbytes version: 0.48.2
Accelerate version: 1.12.0
sklearn version: 1.6.1


## Load the pre-trained model


In [5]:
from transformers import LlavaForConditionalGeneration, LlavaProcessor, BitsAndBytesConfig
import torch

# Checkpoint to fine-tune (LLaVA 1.5, 7B, Hugging Face format).
model_name = "llava-hf/llava-1.5-7b-hf"

# Configure 8-bit quantization via bitsandbytes.
# `load_in_8bit=True` is the only option needed here: BitsAndBytesConfig has no
# `bnb_8bit_compute_dtype` / `bnb_8bit_use_double_downcast` parameters (the
# compute-dtype and double-quantization knobs exist for 4-bit loading only), so
# passing them had no effect on how the model was loaded.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

processor = LlavaProcessor.from_pretrained(model_name, use_fast=True)

# `dtype` replaces the deprecated `torch_dtype` argument; it sets the dtype of
# the modules that are not quantized.
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    dtype=torch.float16,
)

print(f"Model and processor for {model_name} loaded successfully.")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/674 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Model and processor for llava-hf/llava-1.5-7b-hf loaded successfully.


## Load and Prepare the Dataset

Load the dataset and split it into training, validation, and test sets.


In [6]:
import json
import os
from PIL import Image
from sklearn.model_selection import train_test_split

# Annotation file on the mounted Drive, and the folder its relative image
# paths are resolved against.
dataset_path = "/content/drive/MyDrive/large/large.json"
image_base_dir = "/content/drive/MyDrive/large/"

# Each class label is mapped to the single answer letter the model is trained to emit.
label_to_token_ids = {
    "asleep": "A",
    "awake/peaceful": "B",
    "awake/crying": "C",
    "not-present": "D"
}

with open(dataset_path, "r") as annotations_file:
    dataset = json.load(annotations_file)

# Enrich every record in place: absolute image path plus its answer letter.
for record in dataset:
    record["image"] = os.path.join(image_base_dir, record["image"])
    record["answer"] = label_to_token_ids[record["label"]]

# 80% train; the remaining 20% is halved into validation and test (10% each).
train_data, temp_data = train_test_split(
    dataset, test_size=0.2, random_state=42
)
val_data, test_data = train_test_split(
    temp_data, test_size=0.5, random_state=42
)

for split_name, split_rows in (
    ("Training", train_data),
    ("Validation", val_data),
    ("Test", test_data),
):
    print(f"{split_name} data size: {len(split_rows)}")

# Show a small sample of the training records as a sanity check.
display(train_data[:4])

Training data size: 384
Validation data size: 48
Test data size: 48


[{'id': 'synth_00133',
  'image': '/content/drive/MyDrive/large/images/synth_00133.png',
  'prompt': 'overhead crib view of a gentle smile Asian 6-month-old baby diaper only in a bassinet, awake but content, nighttime with nightlight, teddy bear in corner.',
  'label': 'awake/peaceful',
  'conversations': [{'from': 'human',
    'value': "<image>\nClassify the baby's state in the crib: asleep, awake/peaceful, awake/crying, or not-present."},
   {'from': 'gpt', 'value': 'awake/peaceful'}],
  'answer': 'B'},
 {'id': 'synth_00229',
  'image': '/content/drive/MyDrive/large/images/synth_00229.png',
  'prompt': 'angled from baby monitor camera of a gentle smile Hispanic 3-month-old baby wrapped in swaddle in a bassinet, awake but content, dim nursery light, teddy bear in corner.',
  'label': 'awake/peaceful',
  'conversations': [{'from': 'human',
    'value': "<image>\nClassify the baby's state in the crib: asleep, awake/peaceful, awake/crying, or not-present."},
   {'from': 'gpt', 'value': '

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoTokenizer
import os
import torch

class LlavaDataset(Dataset):
    """Map-style dataset that turns one record into one LLaVA training example.

    Every element of ``data`` must be a dict with an ``"image"`` entry (path
    to the image file) and an ``"answer"`` entry (the target string, e.g. the
    letter ``"A"``).

    ``__getitem__`` returns the processor outputs with the leading batch
    dimension removed, plus a ``"labels"`` tensor that is -100 everywhere
    except on the answer tokens.
    """

    # Instruction sent to the model together with the image.
    PROMPT = "Classify the baby's state, return a single letter: A (asleep), B (awake/peaceful), C (awake/crying), D (not-present)."
    # Fixed (width, height) every image is resized to before processing.
    IMAGE_SIZE = (336, 336)
    # Label value used for every position that is not part of the answer.
    IGNORE_INDEX = -100

    def __init__(self, data, processor):
        """Store the records and the processor used to encode them."""
        self.data = data
        self.processor = processor

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    @staticmethod
    def _find_last(sequence, pattern):
        """Return the start index of the last occurrence of ``pattern`` in ``sequence``, or None."""
        width = len(pattern)
        if width == 0:
            # An empty pattern would trivially "match" at the end of the sequence.
            return None
        for start in range(len(sequence) - width, -1, -1):
            if sequence[start:start + width] == pattern:
                return start
        return None

    def __getitem__(self, idx):
        """Encode record ``idx`` into model inputs plus answer-only labels.

        Returns:
            dict: the processor outputs (batch dimension squeezed out) with an
            added ``"labels"`` tensor shaped like ``"input_ids"``.
        """
        item = self.data[idx]
        answer = item["answer"]  # e.g. "A", "B", ...

        image = Image.open(item["image"]).convert("RGB")
        image = image.resize(self.IMAGE_SIZE)

        # Conversation with one user turn (image + instruction) and one
        # assistant turn (the answer).
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.PROMPT},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": answer}],
            },
        ]

        # Keep this a plain string: a trailing comma after the call would wrap
        # the prompt in a 1-tuple.
        text = self.processor.apply_chat_template(messages, add_generation_prompt=False)

        batch = self.processor(text=text, images=image, return_tensors="pt", padding=False)
        batch = {k: v.squeeze(0) for k, v in batch.items()}

        input_ids = batch["input_ids"]
        labels = torch.full_like(input_ids, self.IGNORE_INDEX)

        # Locate the answer tokens; the last occurrence is used because the
        # assistant turn comes at the end of the conversation.
        answer_ids = self.processor.tokenizer.encode(answer, add_special_tokens=False)
        id_list = input_ids.tolist()
        start_idx = self._find_last(id_list, answer_ids)

        if start_idx is None:
            # Best effort: keep the sample but leave every label masked.
            print(f"WARNING: Answer '{answer}' not found in:\n{self.processor.tokenizer.decode(id_list)}")
        else:
            end_idx = start_idx + len(answer_ids)
            labels[start_idx:end_idx] = input_ids[start_idx:end_idx]

        batch["labels"] = labels

        return batch



# Wrap each split so that records are encoded lazily, one item at a time.
train_dataset = LlavaDataset(train_data, processor)
val_dataset = LlavaDataset(val_data, processor)
test_dataset = LlavaDataset(test_data, processor)

# No custom collator or DataLoader is built in this cell (the Trainer below is
# created with data_collator=None), so the message only reports the datasets.
print("Datasets created successfully.")

Datasets, custom data collator, and DataLoaders created successfully.


## Set up LoRA and Prepare the Model

Define the LoRA configuration and apply it to the pre-trained LLaVA model.

In [8]:
from peft import LoraConfig, get_peft_model
import torch

from peft import PeftModel  # only needed for the re-run guard below

# LoRA hyper-parameters: rank-64 adapters on the attention query/value
# projections, with the output head kept fully trainable.
# NOTE(review): "q_proj"/"v_proj" are matched by module name, so any module
# with those names (possibly including the vision tower's attention layers)
# receives an adapter — confirm this is intended.
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head"]
)

# `model` is rebound to its wrapped version, so re-running this cell would wrap
# an already-wrapped model a second time. Guard against that.
if isinstance(model, PeftModel):
    print("Model is already a PeftModel; skipping get_peft_model.")
else:
    model = get_peft_model(model, lora_config)

# Print the trainable parameters
model.print_trainable_parameters()

print("LoRA configuration set up and model prepared.")



trainable params: 171,180,032 || all params: 7,234,607,104 || trainable%: 2.3661
LoRA configuration set up and model prepared.


## Define Training Arguments and Set up the Trainer

Define the arguments for the training process and set up the `Trainer` object.

In [9]:
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import torch # Import torch here

def compute_metrics(eval_pred):
    """Compute answer-token accuracy for a causal language model.

    Args:
        eval_pred: a ``(logits, labels)`` pair of NumPy arrays. ``logits`` has
            shape ``(batch, seq, vocab)`` and may arrive wrapped in a tuple;
            ``labels`` has shape ``(batch, seq)`` with -100 on every position
            that should be ignored.

    Returns:
        dict: ``{"accuracy": float}``; 0.0 when no position is labelled.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]  # Trainer sometimes wraps in tuple

    # Convert to torch tensors
    logits = torch.from_numpy(logits)
    labels = torch.from_numpy(labels)

    # In a causal LM the logits at position t score the token at position t+1,
    # so predictions must be compared against the labels shifted left by one.
    # Comparing them unshifted would score the token *after* the answer.
    pred_ids = torch.argmax(logits, dim=-1)[..., :-1]
    shifted_labels = labels[..., 1:]
    mask = shifted_labels != -100

    predicted = pred_ids[mask]
    truth = shifted_labels[mask]

    acc = 0.0
    if len(predicted) > 0:
        # Decode token by token and strip whitespace so that variants of the
        # same letter that differ only in surrounding spaces compare equal.
        pred_labels = [processor.tokenizer.decode(p.item(), skip_special_tokens=True).strip() for p in predicted]
        true_labels = [processor.tokenizer.decode(t.item(), skip_special_tokens=True).strip() for t in truth]

        acc = accuracy_score(true_labels, pred_labels)

        # Optional: print first few for sanity
        print(f"\nMetrics computed on {len(predicted)} valid answer tokens.")
        print(f"First 8 → Pred: {pred_labels[:8]}")
        print(f"          True: {true_labels[:8]}\n")
    else:
        print("\nNo valid answer tokens found for metric computation in this batch.\n")

    return {"accuracy": acc}


# Training configuration, grouped by concern.
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./llava-finetuned",
    logging_dir="./logs",

    # Batching: one sample per device step, gradients accumulated over 32 steps.
    gradient_accumulation_steps=32,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,

    # Optimisation schedule.
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    num_train_epochs=20,

    # Precision flags: bf16 enabled, fp16 disabled.
    bf16=True,
    fp16=False,

    # Logging and evaluation both happen every 5 steps.
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=5,

    # Checkpointing: once per epoch, keeping at most three checkpoints.
    save_strategy="epoch",
    save_total_limit=3,

    # Model-selection metric (higher accuracy is better).
    metric_for_best_model="accuracy",
    greater_is_better=True,

    # Keep all dataset columns; no Hub upload; no external experiment tracker.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
)

# Bind model, datasets and metric function into a Trainer. No custom collator
# is supplied (data_collator=None).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    data_collator=None,
)

print("Training arguments defined and Trainer set up successfully.")

Training arguments defined and Trainer set up successfully.


## Start Training

Initiate the fine-tuning process using the configured `Trainer` object.

In [11]:
# Run the fine-tuning loop with the Trainer configured in the previous section.
# The message below is printed only if train() returns normally; the saved
# output of this cell shows a KeyboardInterrupt instead.
trainer.train()

print("Training finished.")

KeyboardInterrupt: 