# Task: Fine-Tuning a Vision Language Model with QLoRA

**Selected Model:** [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2) – a small Vision Language Model with approximately 2 billion parameters.  

**Dataset:**  
- Lego Star Wars figures with question-answer pairs  
- Our own data  

**Split:**  
  - **Train:** 575 instances, of which only the first 200 are used (to fit within Colab limits)  
  - **Validation:** 30 instances – Images seen by the model, but not the exact question-answer pairs  
  - **Test:** 30 instances – Neither the images nor the question-answer pairs were seen by the model  


# Install the required dependencies (takes approximately 4 minutes to run, don't worry :D)

In [None]:
!pip install datasets==3.5.0 -q
!pip install bitsandbytes==0.45.5 -q
!pip install transformers==4.51.3 -q
!pip install pillow -q
!pip install torchvision==0.21.0+cu124 -q
!pip install einops -q
!pip install accelerate==1.5.2 -q
!pip install flash-attn==2.7.4.post1 --no-build-isolation -q
!pip install peft==0.14.0 -q
!pip install matplotlib -q
!pip install gdown -q

# Download the data

In [None]:
!gdown 1rSI7swvqU2ZhZqqlMkQYYQiD6Gdq6WQT
!unzip compressed_files.zip

# Import required dependencies

In [None]:
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset,Features,Value
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from peft import get_peft_model, LoraConfig, PeftModel
import gc
import time
from bitsandbytes.optim import Adam32bit
import math
from einops import rearrange
from tqdm import tqdm
import matplotlib.pyplot as plt

# Clear GPU memory (run only when your GPU is full; `%reset` also deletes all variables and imports, so re-run the import cell afterwards)

In [None]:
# NOTE: %reset deletes every user variable in the notebook namespace, including the
# modules imported above (gc, time, torch, ...); re-run the import cell before
# running cells that rely on them.
%reset

In [None]:
def clear_memory():
    """Drop large notebook globals and release cached CUDA memory.

    Removes the listed names from the notebook namespace (if present), runs the
    garbage collector and, when CUDA is available, synchronizes, empties
    PyTorch's CUDA cache and prints the allocated/reserved GPU memory.
    """
    # Imported locally so the helper still works right after `%reset`, which
    # also removes the notebook-level imports.
    import gc

    import torch

    # The models in this notebook live in `moondream` and `base_model` (and the
    # optimizer keeps references to trainable parameters), so they must be listed
    # alongside the generic names for any GPU memory to actually be freed.
    names = (
        "inputs", "model", "processor", "trainer", "peft_model", "bnb_config",
        "moondream", "base_model", "optimizer",
    )
    namespace = globals()
    for name in names:
        namespace.pop(name, None)

    gc.collect()

    if not torch.cuda.is_available():
        # Nothing to release (and synchronize() would raise) without a CUDA device.
        return

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    gc.collect()

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


clear_memory()

# Initialize the datasets

In [None]:
class LegoStarWarsDataset(Dataset):
    """Question/answer pairs about Lego Star Wars figure images.

    Each JSON record has the string fields ``question``, ``answer`` and ``image``
    (an image file name relative to ``image_dir``).

    Args:
        json_name: path of the JSON file passed to ``load_dataset("json", ...)``.
        max_samples: keep only the first ``max_samples`` records (200 by default,
            to fit a Colab session); ``None`` keeps every record.
        image_dir: directory the ``image`` file names are relative to.
    """

    def __init__(self, json_name, max_samples=200, image_dir="./images_refined"):
        context_feat = Features({
            'question': Value(dtype='string', id=None),
            'answer': Value(dtype='string', id=None),
            'image': Value(dtype='string', id=None),
        })
        self.data = load_dataset("json", data_files=json_name, features=context_feat)["train"]
        self.image_dir = image_dir
        # Cap the split size (only needed on Colab) to limit training time and memory.
        if max_samples is not None and len(self.data) > max_samples:
            self.data = self.data.select(range(max_samples))

    def __len__(self):
        return len(self.data)

    def read_image(self, image_path):
        """Load ``image_path`` as an RGB PIL image and close the underlying file."""
        # Image.open is lazy and keeps the file open; convert() loads the pixels into
        # a new image, so the context manager can safely close the file afterwards.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, idx):
        """Return ``{"image": PIL RGB image, "qa": [{"question", "answer"}]}``."""
        sample = self.data[idx]
        image = self.read_image(f"{self.image_dir}/{sample['image']}")
        return {
            "image": image,  # Should be a PIL image
            "qa": [
                {
                    "question": sample["question"],
                    "answer": sample["answer"],
                }
            ],
        }


In [None]:
# One dataset per split, each read from "<split>.json".
datasets = {
    split_name: LegoStarWarsDataset(f"{split_name}.json")
    for split_name in ("train", "val", "test")
}

# Load the quantized model
1. Define the quantization config with BitsAndBytes
2. Load the model with this config using the Hugging Face Auto classes

In [None]:
# 4-bit NF4 weight quantization; the quantized layers compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
model_id = "vikhyatk/moondream2"

DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
# Pinned Hub revision; with trust_remote_code this also pins the modeling code used.
MD_REVISION = "2024-08-26"

# Use model_id everywhere instead of repeating the repository literal.
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=MD_REVISION)
moondream = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
    quantization_config=bnb_config,
)

In [None]:
# Trade compute for memory in the text decoder. Non-reentrant checkpointing is
# requested because the decoder inputs built in compute_loss (token embeddings,
# presumably frozen once LoRA is applied, plus image embeddings computed under
# no_grad) do not require grad; with reentrant checkpointing the LoRA weights inside
# checkpointed layers could then receive no gradient.
moondream.text_model.transformer.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

# Define LoRA config
1. Define Modules from the architecture, which will be fine-tuned
   - The model authors report that fine-tuning the vision encoder often makes results worse, so we only fine-tune the `text_model`
2. Define LoraConfig
3. Update the model using the LoraConfig

In [None]:
from torch.nn import Linear

# Qualified names of every torch.nn.Linear instance inside the text decoder;
# LoRA adapters are attached to each of them.
target_modules = [
    qualified_name
    for qualified_name, layer in moondream.text_model.named_modules()
    if isinstance(layer, Linear)
]


In [None]:
# LoRA adapters of rank 32 on every collected Linear layer; biases are not trained.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    r=32,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

In [None]:
# Wrap the text decoder with the LoRA adapters and report how many parameters train.
lora_text_model = get_peft_model(moondream.text_model, config)
moondream.text_model = lora_text_model
lora_text_model.print_trainable_parameters()


#### Because we do not want to train the vision encoder, we disable gradients for all vision_encoder parameters

In [None]:
# Freeze every vision-encoder parameter. Assigning `module.requires_grad = False`
# only creates a plain attribute on the Module and leaves the parameters trainable;
# requires_grad_(False) actually clears the flag on each parameter.
# (Trailing ";" suppresses the notebook's display of the returned module.)
moondream.vision_encoder.requires_grad_(False);

# Defining the collate_fn and compute_loss function for the model

- collate_fn: takes a batch of samples and converts it into padded tensors the model can process
- compute_loss: takes a converted batch, runs a forward pass, and returns the loss

In [None]:
ANSWER_EOS = "<|endoftext|>"

# Number of tokens used to represent each image.
IMG_TOKENS = 729


def collate_fn(batch):
    """Turn a list of dataset samples into padded model inputs.

    Each sample is a dict with an ``"image"`` and a ``"qa"`` list of
    ``{"question", "answer"}`` dicts. Token ids are
    ``BOS + (question prompt + answer + ANSWER_EOS) per QA pair``. The label
    sequence additionally reserves ``IMG_TOKENS`` slots right after BOS for the
    image embeddings that ``compute_loss`` splices in. Only answer tokens are
    supervised; every other position is labelled -100 (the ignore index).

    Returns:
        Tuple ``(images, tokens, labels, attn_mask)``: the preprocessed images,
        a long tensor of token ids padded with ``tokenizer.eos_token_id``, a long
        tensor of labels padded with -100, and a bool attention mask with the same
        length as the labels (image slots included).
    """
    images = [moondream.vision_encoder.preprocess(sample["image"]) for sample in batch]

    token_rows = []
    label_rows = []
    for sample in batch:
        ids = [tokenizer.bos_token_id]
        # BOS and the IMG_TOKENS image positions are never predicted.
        targets = [-100] * (1 + IMG_TOKENS)
        for pair in sample["qa"]:
            prompt_ids = tokenizer(
                f"\n\nQuestion: {pair['question']}\n\nAnswer:",
                add_special_tokens=False,
            ).input_ids
            answer_ids = tokenizer(
                f" {pair['answer']}{ANSWER_EOS}",
                add_special_tokens=False,
            ).input_ids
            ids += prompt_ids
            ids += answer_ids
            targets += [-100] * len(prompt_ids)
            targets += answer_ids
        token_rows.append(ids)
        label_rows.append(targets)

    longest = max((len(row) for row in label_rows), default=-1)

    attn_rows = []
    for ids, targets in zip(token_rows, label_rows):
        real_len = len(targets)
        gap = longest - real_len
        targets += [-100] * gap
        ids += [tokenizer.eos_token_id] * gap
        attn_rows.append([1] * real_len + [0] * gap)

    def to_batch(rows, dtype):
        return torch.stack([torch.tensor(row, dtype=dtype) for row in rows])

    return (
        images,
        to_batch(token_rows, torch.long),
        to_batch(label_rows, torch.long),
        to_batch(attn_rows, torch.bool),
    )

In [None]:
def compute_loss(batch, model):
    """Run one forward pass on a collated batch and return the model's loss.

    ``batch`` is the ``(images, tokens, labels, attn_mask)`` tuple produced by
    ``collate_fn``. Image embeddings are computed without gradient tracking and
    inserted between the BOS token embedding and the remaining token embeddings;
    the result is fed to ``model.text_model`` together with labels and mask.
    """
    images, tokens, labels, attn_mask = batch
    tokens, labels, attn_mask = (t.to(DEVICE) for t in (tokens, labels, attn_mask))

    with torch.no_grad():
        image_embeds = model.vision_encoder(images)

    text_model = model.text_model
    token_embeds = text_model.get_input_embeddings()(tokens)
    bos_embed, rest_embeds = token_embeds[:, :1, :], token_embeds[:, 1:, :]
    fused = torch.cat([bos_embed, image_embeds, rest_embeds], dim=1)

    result = text_model(inputs_embeds=fused, labels=labels, attention_mask=attn_mask)
    return result.loss

# Defining hyperparameters

In [None]:
# Training hyper-parameters.
EPOCHS = 5             # maximum number of passes over the training split
BATCH_SIZE = 2         # micro-batch size of the training DataLoader
GRAD_ACCUM_STEPS = 8   # micro-batches per optimizer step (effective batch of 16)
LR = 0.001             # peak learning rate used by lr_schedule
USE_WANDB = False      # not used by the cells shown here
PATIENCE = 3           # epochs without validation improvement before stopping early

In [None]:
def lr_schedule(step, max_steps, base_lr=None):
    """Return the learning rate for optimizer step ``step`` of ``max_steps``.

    Linear warmup from ``0.1 * base_lr`` to ``base_lr`` over the first 10% of
    training, then cosine decay back to ``0.1 * base_lr``, reached exactly at
    ``step == max_steps`` (and held after that).

    Args:
        step: current optimizer step (may be fractional).
        max_steps: total number of optimizer steps in the run.
        base_lr: peak learning rate; defaults to the notebook-level ``LR``.

    Returns:
        The learning rate as a float.
    """
    if base_lr is None:
        base_lr = LR
    if max_steps <= 0:
        # Degenerate run (fewer micro-batches than one accumulation window):
        # there is no schedule to follow, and step / max_steps would divide by zero.
        return base_lr

    warmup_frac = 0.1
    floor = 0.1 * base_lr
    x = step / max_steps
    if x < warmup_frac:
        return floor + (base_lr - floor) * x / warmup_frac
    # Normalize the decay phase to [0, 1] so the cosine reaches its minimum at the
    # final step (the previous form used x - 0.1 and stopped at cos(0.9 * pi)).
    progress = min((x - warmup_frac) / (1 - warmup_frac), 1.0)
    return floor + (base_lr - floor) * (1 + math.cos(math.pi * progress)) / 2

# Initialize the DataLoaders

In [None]:
# Training batches are reshuffled every epoch. The evaluation loaders iterate in a
# fixed order: shuffling gains nothing for an averaged loss, and each shuffled pass
# draws from the global torch RNG, which would perturb the next training shuffle.
dataloaders = {
    "train": DataLoader(
        datasets["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    ),
    "val": DataLoader(
        datasets["val"],
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    ),
    "test": DataLoader(
        datasets["test"],
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    ),
}


# Initializing the optimizer
- we pass only the parameters that have gradients enabled

In [None]:
# Number of optimizer steps (not micro-batches) over the whole run; used by lr_schedule.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS

# Only parameters that still require grad after get_peft_model are optimized.
trainable_params = [param for param in moondream.text_model.parameters() if param.requires_grad]
optimizer = Adam32bit(
    [{"params": trainable_params}],
    lr=LR * 0.1,  # same value lr_schedule returns at step 0 (start of the warmup)
    betas=(0.9, 0.95),
    eps=1e-6,
)
moondream.text_model.train()

# Start the training

- Go through each batch in the training data
- Calculate the loss and do a backward step
- Calculate the loss on the validation set
- Iterate until early stopping triggers or the number of epochs reaches its limit

In [None]:
# Sub-directory of ./checkpoints/ where the best LoRA adapter is saved and later reloaded.
checkpoint_name = "lego_lora_fine_tune_3_lr"

### Approximately one hour for 5 epochs

In [None]:
# Training loop: gradient accumulation over GRAD_ACCUM_STEPS micro-batches, the
# lr_schedule warmup/cosine learning rate, per-epoch validation, and early stopping
# that saves the LoRA adapter whenever the validation loss improves.
i = 0  # global micro-batch counter; keeps counting across epochs
train_losses = []
val_losses = []
best_val_loss = float('inf')
early_stopping_counter = 0

for epoch in range(EPOCHS):
    moondream.text_model.train()
    actual_train_loss = 0
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1
        loss = compute_loss(batch, moondream)
        # Scale so the accumulated gradient is the mean (not the sum) over the
        # accumulation window; the unscaled loss is still what gets logged.
        (loss / GRAD_ACCUM_STEPS).backward()
        actual_train_loss += loss.item()

        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            # Write the scheduled LR into the optimizer; computing it alone has no
            # effect, which previously left the LR fixed at its initial value.
            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    avg_train_loss = actual_train_loss / len(dataloaders['train'])
    print(f"Train Loss: {avg_train_loss:.4f}")
    train_losses.append(avg_train_loss)

    actual_val_loss = 0
    moondream.text_model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloaders["val"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
            val_loss = compute_loss(batch, moondream)
            actual_val_loss += val_loss.item()

    avg_val_loss = actual_val_loss / len(dataloaders["val"])
    print(f"Validation Loss: {avg_val_loss:.4f}")
    val_losses.append(avg_val_loss)

    # Early stopping: keep the adapter with the lowest validation loss so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stopping_counter = 0
        moondream.text_model.save_pretrained(f'checkpoints/{checkpoint_name}')
        print("Model improved and saved.")
    else:
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{PATIENCE}")
        if early_stopping_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

print("Training complete.")



In [None]:
# Per-epoch loss curves; the x axis starts at epoch 1.
for losses, curve_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    plt.plot(range(1, len(losses) + 1), losses, label=curve_label)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

# Evaluation
- Load the original model and calculate the loss on each dataset
- Load the LoRA weights, merge them into the original model, then calculate the same losses
- Run inference with both models

In [None]:
def test_model(model, dataloader):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Puts ``model`` in eval mode and evaluates ``compute_loss`` on every batch
    without gradient tracking.

    Args:
        model: model object accepted by ``compute_loss``.
        dataloader: sized iterable of collated batches.

    Returns:
        The average loss as a float, or ``nan`` if ``dataloader`` is empty.
    """
    model.eval()
    num_batches = len(dataloader)
    if num_batches == 0:
        # Avoid ZeroDivisionError on an empty split; nan signals "no data".
        return float("nan")

    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Test model"):
            total_loss += compute_loss(batch, model).item()

    return total_loss / num_batches

In [None]:
# Unquantized copy of the original moondream2 weights (loaded in DTYPE), used as the
# evaluation baseline and later as the base into which the LoRA adapter is merged.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

In [None]:
# Average loss of the un-tuned base model on each split (evaluated train, val, test).
base_train_loss, base_val_loss, base_test_loss = (
    test_model(base_model, dataloaders[split]) for split in ("train", "val", "test")
)


In [None]:
# Load the probe image the same way LegoStarWarsDataset.read_image does (RGB, file closed).
with Image.open("./images_refined/A4-D.png") as _probe_image:
    image = _probe_image.convert("RGB")

# Pass only the bare question. NOTE(review): moondream's answer_question() (pinned
# revision) appears to build its own "Question: ...\n\nAnswer:" prompt, which is also
# the format collate_fn trains on; pre-wrapping it here duplicated the template.
question = "Who is this character?"
gt = "A4-D"

In [None]:
# Answer the probe question with the base model (beam search, 5 beams).
base_image_embeds = base_model.encode_image(image)
base_model_answer = base_model.answer_question(
    base_image_embeds,
    question,
    tokenizer=tokenizer,
    num_beams=5,
    no_repeat_ngram_size=5,
    early_stopping=True,
)

In [None]:
# Attach the best saved LoRA adapter to the base text model, frozen for inference.
adapter_dir = f'./checkpoints/{checkpoint_name}/'
base_model.text_model = PeftModel.from_pretrained(
    base_model.text_model,
    adapter_dir,
    is_trainable=False,
)

In [None]:
# Merge the loaded LoRA adapter into the text model's weights and unwrap it.
merged_text_model = base_model.text_model.merge_and_unload()
base_model.text_model = merged_text_model

In [None]:
# Average loss of the merged fine-tuned model on each split (evaluated train, val, test).
peft_train_loss, peft_val_loss, peft_test_loss = (
    test_model(base_model, dataloaders[split]) for split in ("train", "val", "test")
)


In [None]:
# Answer the same probe question with the merged fine-tuned model (same decoding settings).
merged_image_embeds = base_model.encode_image(image)
peft_model_answer = base_model.answer_question(
    merged_image_embeds,
    question,
    tokenizer=tokenizer,
    num_beams=5,
    no_repeat_ngram_size=5,
    early_stopping=True,
)

In [None]:
# Side-by-side comparison of both answers against the ground truth.
comparison_lines = [
    f"PROMPT: \n\n {question} \n",
    f"BASE Model answer: {base_model_answer} \n",
    f"LoRA fine-tuned Model answer: {peft_model_answer} \n",
    f"Ground Truth: {gt}",
]
for text_line in comparison_lines:
    print(text_line)

In [None]:
# Per-split loss summary: fine-tuned model first, then the base model.
print("Fine-tuned MODEL")
for split_name, loss_value in (("train", peft_train_loss), ("val", peft_val_loss), ("test", peft_test_loss)):
    print(f"LoRA fine-tune model {split_name} loss: {loss_value}")

print("Base MODEL")
for split_name, loss_value in (("train", base_train_loss), ("val", base_val_loss), ("test", base_test_loss)):
    print(f"Base model {split_name} loss: {loss_value}")