# Fine-tuning notebook
- Before you go through this code: as mentioned earlier, this project was inspired by Moondream, so the fine-tuning code is similar.
- The difference is that here we fine-tune both the language model and the MLP (the multimodal projector); the vision encoder's weights are not updated.

In [None]:
%pip install torch transformers timm einops datasets bitsandbytes accelerate

In [None]:
from torch.utils.data import Dataset
from datasets import load_dataset

class CaptchaDataset(Dataset):
    """Single-turn image-description dataset backed by a Hugging Face split.

    Despite its name, this wraps the DOCCI dataset (``google/docci``) by default.
    Every sample becomes one QA pair: ``question`` paired with the image's
    ``description`` column as the answer. Items use the ``{"image", "qa"}``
    layout that ``collate_fn`` consumes.

    Args:
        split: Split to load, e.g. ``"train"`` or ``"test"``.
        dataset_name: Hugging Face dataset id. Defaults to ``"google/docci"``.
        question: Prompt paired with every description.
    """

    def __init__(self, split='train', dataset_name="google/docci",
                 question="Describe this image."):
        # Request just the needed split rather than indexing the full DatasetDict.
        # NOTE: trust_remote_code executes the dataset repo's loading script;
        # only keep it for sources you trust.
        self.data = load_dataset(dataset_name, split=split, trust_remote_code=True)
        self.question = question

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return {
            "image": sample["image"],  # expected to be a PIL image
            "qa": [
                {
                    "question": self.question,
                    "answer": sample["description"],
                }
            ],
        }

# One dataset wrapper per split, keyed by split name.
# NOTE: this name shadows the `datasets` package; that is harmless here because
# only `load_dataset` was imported from it.
datasets = {split: CaptchaDataset(split) for split in ("train", "test")}

In [None]:
from transformers import AutoModelForCausalLM
from PIL import Image

# Load the pretrained vision-language checkpoint. It relies on custom modeling code
# from the Hub repo (it exposes .vision_encoder, .mlp, .language_model and
# .tokenizer, which later cells use), hence trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    "damerajee/GPTVision-1-ft",
    trust_remote_code=True,
)
# The custom model carries its tokenizer as an attribute.
tokenizer = model.tokenizer

In [None]:
# Number of times to repeat the training dataset. Increasing this may cause the model to overfit or
# lose generalization due to catastrophic forgetting. Decreasing it may cause the model to underfit.
EPOCHS = 1

# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory.
BATCH_SIZE = 8

# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
# The effective batch size is BATCH_SIZE * GRAD_ACCUM_STEPS.
GRAD_ACCUM_STEPS = 2

# Learning rate for the (8-bit) Adam optimizer. Needs to be tuned on a case-by-case basis. As a
# general rule of thumb, increase it by 1.4 times each time you double the effective batch size.
#
# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
#
# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
# training run, and then decay it back to 0.1 * LR over the last 90% of the training run using a
# cosine schedule.
LR = 1e-5

# Device that the model and batches are moved to for training and evaluation.
DEVICE = "cuda"

In [None]:
# Move the whole model (vision encoder, MLP and language model) to the configured
# training device instead of hard-coding "cuda" a second time.
model.to(DEVICE)

In [None]:
# Train only what the optimizer updates: the language model and the MLP projector.
# The vision encoder is not in the optimizer's parameter groups. Letting it require
# grad would only spend memory and compute on gradients that are never applied.
# Those gradients would also never be cleared, because optimizer.zero_grad() only
# touches the optimizer's own parameters.
for param in model.parameters():
    param.requires_grad = False
for module in (model.language_model, model.mlp):
    for param in module.parameters():
        param.requires_grad = True

In [None]:
from torch.utils.data import DataLoader
from bitsandbytes.optim import Adam8bit
import math
import torch
from einops import rearrange
# `tqdm.notebooks` does not exist (the notebook front-end is `tqdm.notebook`).
# `tqdm.auto` uses the notebook widget when available and falls back to a console bar.
from tqdm.auto import tqdm

# Constants
# Re-declared so this cell can run on its own; keep in sync with the hyperparameter cell.
DEVICE = "cuda"
# Appended to every answer; collate_fn keeps it in the labels, so it is a training target.
ANSWER_EOS = "<|endoftext|>"
# Number of image embeddings that compute_loss splices in after BOS. collate_fn reserves
# this many masked label positions (+1 for BOS). Presumably 196 patches + 1 CLS token for
# a 224px / patch-16 ViT — TODO confirm against model.vision_encoder's output length.
IMG_TOKENS = 197


def collate_fn(batch, max_question_tokens=104, max_answer_tokens=720):
    """Collate dataset samples into ``(images, tokens, labels, attn_mask)``.

    Each sequence is laid out as ``[BOS] [image] question answer<EOS>``. The image
    embeddings are inserted later by ``compute_loss``. Here they only reserve
    ``IMG_TOKENS`` label positions, masked with -100 like BOS and the question,
    so the labels line up with the spliced embedding sequence.

    Questions and answers are *truncated* (never padded) to the given limits.
    Padding happens once, at the end, up to the longest sequence in the batch,
    with attention-mask 0 and label -100. Pad tokens therefore never sit between
    the question and the answer, are never attended to, and never receive loss.
    The default limits give 1 + IMG_TOKENS + 104 + 720 positions for one QA pair,
    presumably chosen to fit the language model's context window — TODO confirm.

    Args:
        batch: List of ``{"image": PIL.Image, "qa": [{"question", "answer"}, ...]}``.
        max_question_tokens: Maximum tokens kept per formatted question.
        max_answer_tokens: Maximum tokens kept per answer, including the EOS marker.

    Returns:
        ``(images, tokens, labels, attn_mask)``. ``images`` is a list of RGB PIL
        images. The rest are tensors of dtype long, long and bool.
    """
    images = [sample['image'].convert("RGB") for sample in batch]

    # Ids of the end-of-answer marker; re-appended after truncation so that a long
    # answer still ends with it.
    eos_ids = tokenizer(ANSWER_EOS, add_special_tokens=False).input_ids

    labels_acc = []
    tokens_acc = []

    for sample in batch:
        toks = [tokenizer.bos_token_id]
        # BOS + the image embeddings spliced in by compute_loss: no loss on either.
        labs = [-100] * (IMG_TOKENS + 1)

        # Handle multiple QA pairs
        for qa in sample['qa']:
            q_t = tokenizer(
                f"\n\nQuestion: {qa['question']}\n\nAnswer:",
                add_special_tokens=False,
                truncation=True,
                max_length=max_question_tokens,
            ).input_ids
            toks.extend(q_t)
            labs.extend([-100] * len(q_t))  # prompt tokens are not predicted

            a_t = tokenizer(
                f" {qa['answer']}{ANSWER_EOS}",
                add_special_tokens=False,
            ).input_ids
            if len(a_t) > max_answer_tokens:
                # Truncate the answer text but keep the EOS marker at the end.
                a_t = a_t[: max(max_answer_tokens - len(eos_ids), 0)] + eos_ids
            toks.extend(a_t)
            labs.extend(a_t)  # answer tokens (incl. EOS) are the training targets

        tokens_acc.append(toks)
        labels_acc.append(labs)

    max_len = max(len(labels) for labels in labels_acc)

    attn_mask_acc = []
    for toks, labs in zip(tokens_acc, labels_acc):
        len_i = len(labs)
        pad_i = max_len - len_i
        # tokens are IMG_TOKENS shorter than labels, so padding both by the same
        # amount keeps them aligned once compute_loss inserts the image embeddings.
        labs.extend([-100] * pad_i)
        toks.extend([tokenizer.eos_token_id] * pad_i)
        attn_mask_acc.append([1] * len_i + [0] * pad_i)

    return (
        images,
        torch.stack([torch.tensor(t, dtype=torch.long) for t in tokens_acc]),
        torch.stack([torch.tensor(l, dtype=torch.long) for l in labels_acc]),
        torch.stack([torch.tensor(a, dtype=torch.bool) for a in attn_mask_acc]),
    )


def evaluate_model(dataloader, max_batches=None):
    """Return the mean per-batch loss of the model over ``dataloader``.

    Runs under ``torch.no_grad()`` with the language model and MLP in eval mode.
    Afterwards their previous train/eval mode is restored, so a call from inside
    the training loop does not leave the trainable modules in eval mode.

    Args:
        dataloader: Yields ``(images, tokens, labels, attn_mask)`` batches as
            produced by ``collate_fn``.
        max_batches: If given, stop after this many batches (for a quick estimate).
            ``None`` evaluates the whole loader.

    Returns:
        Average of the per-batch losses, as a float.

    Raises:
        ValueError: If no batch was evaluated (empty loader or ``max_batches=0``).
    """
    was_training = (model.language_model.training, model.mlp.training)
    model.language_model.eval()
    model.mlp.eval()

    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                if max_batches is not None and num_batches >= max_batches:
                    break
                # Same forward pass as training, so the two losses are comparable.
                total_loss += compute_loss(batch).item()
                num_batches += 1
    finally:
        model.language_model.train(was_training[0])
        model.mlp.train(was_training[1])

    if num_batches == 0:
        raise ValueError("evaluate_model did not evaluate any batches.")
    return total_loss / num_batches

def compute_loss(batch):
    """Return the language-model loss for one collated batch.

    The image embeddings (vision encoder, then the trainable MLP projector) are
    spliced in right after the BOS token embedding. This matches the layout that
    ``collate_fn`` assumes when it masks the first ``IMG_TOKENS + 1`` label positions.

    Args:
        batch: ``(images, tokens, labels, attn_mask)`` as returned by ``collate_fn``.

    Returns:
        The scalar loss tensor from ``model.language_model``.

    Raises:
        ValueError: If the spliced embedding sequence and the attention mask differ
            in length, e.g. when the vision encoder emits a number of embeddings
            other than ``IMG_TOKENS``.
    """
    images, tokens, labels, attn_mask = batch

    tokens = tokens.to(DEVICE)
    labels = labels.to(DEVICE)
    attn_mask = attn_mask.to(DEVICE)

    # The vision encoder is not among the optimizer's parameter groups, so building
    # an autograd graph through it only costs memory. Gradients still reach the MLP.
    with torch.no_grad():
        img_embs = model.vision_encoder(images, device=DEVICE)

    # Apply MLP to image embeddings
    img_embs = model.mlp(img_embs)

    tok_embs = model.language_model.get_input_embeddings()(tokens)

    # Layout: [BOS] [image embeddings] [question + answer tokens ...]
    inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)

    if inputs_embeds.shape[1] != attn_mask.shape[1]:
        raise ValueError(
            "Mismatch between embeddings and attention mask length: "
            f"{inputs_embeds.shape[1]} vs {attn_mask.shape[1]}."
        )

    outputs = model.language_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attn_mask,
    )

    return outputs.loss





def lr_schedule(step, max_steps, lr=None):
    """Return the learning rate for optimizer step ``step`` out of ``max_steps``.

    The first 10% of training is a linear warm-up from 0.1 * peak to peak. The
    remaining 90% is a half-cosine decay from peak back to exactly 0.1 * peak.

    Args:
        step: Current optimizer step (may be fractional).
        max_steps: Total number of optimizer steps in the run; must be > 0.
        lr: Peak learning rate. Defaults to the module-level ``LR``.

    Returns:
        The learning rate as a float.
    """
    peak = LR if lr is None else lr
    x = step / max_steps
    if x < 0.1:
        # Warm-up: 0.1 * peak at x = 0, rising linearly to peak at x = 0.1.
        return 0.1 * peak + 0.9 * peak * x / 0.1
    # Decay: map the last 90% of training onto [0, 1] so the cosine completes a
    # half period. Without the /0.9 it stopped short, at about 0.12 * peak.
    progress = (x - 0.1) / 0.9
    return 0.1 * peak + 0.9 * peak * (1 + math.cos(math.pi * progress)) / 2

# One DataLoader per split, sharing collate_fn so batches have the same structure.
# Only the training split is shuffled.
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        collate_fn=collate_fn,
    )
    for split in ("train", "test")
}


# Number of optimizer updates (not micro-batches) in the run; this drives lr_schedule.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
model.language_model.train()
model.mlp.train()
model.language_model.transformer.gradient_checkpointing_enable()

# Optimize the language model and the MLP projector (the vision encoder is not included).
optimizer = Adam8bit(
    [
        {"params": model.language_model.parameters()},
        {"params": model.mlp.parameters()},
    ],
    lr=LR * 0.1,  # matches lr_schedule's starting value (0.1 * LR)
    betas=(0.9, 0.95),
    eps=1e-6,
)

# Cadences are counted in micro-batches (`i`).
LOG_EVERY = 10
# A full pass over the test split is expensive. Running it every LOG_EVERY
# micro-batches can dominate training time, so evaluate on the checkpoint cadence.
EVAL_EVERY = 500
SAVE_EVERY = 500

print("WE ARE STRATING TO TRAIN")
i = 0
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1

        loss = compute_loss(batch)
        # Scale so the accumulated gradient is the mean over GRAD_ACCUM_STEPS
        # micro-batches rather than their sum.
        (loss / GRAD_ACCUM_STEPS).backward()

        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if i % LOG_EVERY == 0:
            print(f"Loss: {loss.item()}")

        if i % EVAL_EVERY == 0:
            print("Evaluating model...")
            test_loss = evaluate_model(dataloaders["test"])
            print(f"Test Loss: {test_loss}")
            # evaluate_model puts these modules in eval mode; switch them back so
            # training keeps dropout (and, for HF models, gradient checkpointing) active.
            model.language_model.train()
            model.mlp.train()

        if i % SAVE_EVERY == 0:
            model.save_pretrained(f"checkpoints/gptvision-ft_{i}")
            
     

In [None]:
model.eval()
from IPython.display import display

sample = datasets['train'][0]
display(sample['image'])

# Resize to 224x224 and force 3-channel RGB before handing the image to the model.
# NOTE(review): training batches are only converted to RGB (no resize); confirm that
# the vision encoder resizes internally so that train and inference inputs match.
image = sample['image'].resize((224, 224)).convert("RGB")

gen_kwargs = {
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.6,
    "repetition_penalty": 1.6,
}

for qa in sample['qa']:
    print('Question:', qa['question'])
    print('Ground Truth:', qa['answer'])
    # Ask the question that is printed above, so the generated answer corresponds
    # to the displayed ground truth.
    print('GPT-Vision:', model.generate(
        image=image,
        question=qa['question'],
        max_new_tokens=256,
        **gen_kwargs,
    ))