## Imports & Setup

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm
from helpers import build_prompt, build_message
from prefix_tuning import PrefixTuningModelPastKeyValues, PrefixDataset, prefix_collate
import prefix_tuning
import helpers
import importlib
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor
import gc
# Pick up on-disk edits to the local modules without restarting the kernel.
importlib.reload(helpers)
importlib.reload(prefix_tuning)
# reload() re-executes the module but does not update names that were copied
# into this namespace with `from ... import ...`, so re-bind them here;
# otherwise the notebook keeps using the pre-reload functions and classes.
from helpers import build_prompt, build_message
from prefix_tuning import PrefixTuningModelPastKeyValues, PrefixDataset, prefix_collate

In [None]:
# Ask PyTorch to release cached GPU memory it is not currently using.
torch.cuda.empty_cache()

In [None]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# On CUDA, also report the name of GPU 0 and its current memory footprint (in GiB).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2**30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2**30, 1), 'GB')

# CONSTANTS
NUM_EPOCHS_FT = 100  # epochs for fine-tuning on the labelled data
NUM_EPOCHS_KD = 100  # epochs for knowledge distillation
BATCH_SIZE = 16      # batch size shared by every DataLoader in this notebook

In [None]:
# Echo the selected torch device as the cell output.
device

In [None]:
# IPython shell escape: run the nvidia-smi command-line tool and show its output.
!nvidia-smi

## Get Data and preprocess it

In [None]:
### train data
# ScienceQA training split (labels and images) loaded into a DataFrame.
df_train_label = pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='train'))

# Keep only rows that come with a solution. reset_index() (without drop) moves the
# original row labels into an 'index' column, which the Gemini merge below joins on.
df_train_label = df_train_label.loc[df_train_label['solution'] != ''].reset_index()
# Rows without a picture get a black 224x224 RGB placeholder so every sample has an image.
df_train_label['image'] = df_train_label.apply(
    lambda row: row['image'] or Image.new("RGB", (224, 224), (0, 0, 0)),
    axis=1,
)
# build_prompt returns a sequence; only its first element is stored as the input.
df_train_label['input'] = df_train_label.apply(lambda row: build_prompt(row)[0], axis=1)
df_train_label['message'] = df_train_label.apply(build_message, axis=1)

# Gemini outputs used for knowledge distillation (tab-separated file).
df_train_gemini = pd.read_csv('gemini_1_5_flash_output_train.csv', sep="\t")[['index', 'input', 'answer', 'explanation']]
# Attach the matching image to each Gemini row via the shared 'index' column (inner join).
df_train_gemini = df_train_gemini.merge(df_train_label[['index', 'image']], on='index')
df_train_gemini['message'] = df_train_gemini.apply(build_message, axis=1)

In [None]:
### val data
# ScienceQA validation split loaded into a DataFrame.
df_val = pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='validation'))

# Keep only rows that come with a solution; the old row labels end up in an 'index' column.
df_val = df_val.loc[df_val['solution'] != ''].reset_index()
# Rows without a picture get a black 224x224 RGB placeholder so every sample has an image.
df_val['image'] = df_val.apply(
    lambda row: row['image'] or Image.new("RGB", (224, 224), (0, 0, 0)),
    axis=1,
)
# build_prompt returns a sequence; only its first element is stored as the input.
df_val['input'] = df_val.apply(lambda row: build_prompt(row)[0], axis=1)
df_val['message'] = df_val.apply(build_message, axis=1)

## Functions for model training

In [None]:
def preprocess_input_qwen(tokenizer, processor, prompts, texts, images, y, device):
    """Turn one batch into Qwen model inputs and label token ids.

    Args:
        tokenizer: Callable tokenizer used to encode the target strings ``y``.
        processor: Processor providing ``apply_chat_template`` and batching of
            text, images and videos.
        prompts: Unused here; kept so every preprocess function shares one signature.
        texts: Per-sample chat messages. They are rendered with the chat template
            and also passed to ``process_vision_info`` to extract vision inputs.
        images: Unused here; the vision inputs come from ``texts``.
        y: Target strings for the batch.
        device: Device the returned tensors are moved to.

    Returns:
        ``(inputs, labels)``: the processor output moved to ``device`` with
        ``dtype=torch.bfloat16``, and the label ids moved to ``device``.
    """
    # Render each conversation to a plain string, without a trailing generation prompt.
    chat_strings = []
    for conversation in texts:
        rendered = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        chat_strings.append(rendered)

    image_inputs, video_inputs = process_vision_info(texts)
    batch = processor(
        text=chat_strings,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Labels are padded/truncated to exactly the padded input length so both
    # tensors share the same sequence dimension.
    # NOTE(review): padding positions keep the tokenizer's pad id rather than an
    # ignore index; whether they count towards the loss depends on the model
    # wrapper, which is not visible here — confirm.
    seq_len = batch["input_ids"].size(1)
    encoded_targets = tokenizer(
        y,
        padding="max_length",
        truncation=True,
        max_length=seq_len,
        return_tensors="pt",
    )
    target_ids = encoded_targets["input_ids"]
    return batch.to(device, dtype=torch.bfloat16), target_ids.to(device)

In [None]:
def preprocess_input_paligemma(tokenizer, processor, prompts, texts, images, y, device):
    """Turn one batch into PaliGemma model inputs and label token ids.

    Args:
        tokenizer: Callable tokenizer used to encode the target strings ``y``.
        processor: Callable processor that batches prompts and images.
        prompts: Prompt strings passed to the processor as ``text``.
        texts: Unused here; kept so every preprocess function shares one signature.
        images: Objects exposing ``resize((w, h))`` (e.g. PIL images).
        y: Target strings for the batch.
        device: Device the returned tensors are moved to.

    Returns:
        ``(inputs, labels)``, both moved to ``device``.
    """
    # Resize to 224x224 and rescale with x / 127.5 - 1, which maps 8-bit pixel
    # values 0..255 onto [-1, 1].
    # NOTE(review): if the processor also rescales/normalises images, this is
    # applied twice — confirm against the processor configuration.
    scaled_images = []
    for picture in images:
        pixels = np.array(picture.resize((224, 224)))
        scaled_images.append(pixels / 127.5 - 1)

    batch = processor(
        text=prompts,
        images=scaled_images,
        return_tensors="pt",
        padding="longest",
    )
    encoded_targets = tokenizer(y, padding="longest", return_tensors="pt")
    target_ids = encoded_targets["input_ids"]
    return batch.to(device), target_ids.to(device)

In [None]:
def _mean_validation_loss(model, tokenizer, processor, dataloader, preprocess_func):
    """Return the sample-weighted mean loss over ``dataloader`` without touching the weights."""
    total_loss = 0.0
    num_samples = 0
    model.eval()
    # No gradients are needed for evaluation; skipping the autograd graph saves memory.
    with torch.no_grad():
        for prompts, texts, images, y in dataloader:
            inputs, labels = preprocess_func(tokenizer, processor, prompts, texts, images, y, device)
            outputs = model(
                inputs=inputs,
                labels=labels,
            )
            total_loss += outputs.loss.item() * len(texts)
            num_samples += len(texts)
            del labels, inputs, outputs
            gc.collect()
            torch.cuda.empty_cache()
    return total_loss / num_samples if num_samples else float("nan")


def train(model, tokenizer, processor, optimizer, dataloader_train, dataloader_val, preprocess_func,
          num_epochs=None, val_every=10):
    """Train ``model`` and periodically evaluate it on the validation loader.

    Args:
        model: Callable as ``model(inputs=..., labels=...)`` returning an object
            with a ``loss`` attribute.
        tokenizer: Passed through to ``preprocess_func``.
        processor: Passed through to ``preprocess_func``.
        optimizer: Optimizer stepped once per training batch.
        dataloader_train: Yields ``(prompts, texts, images, y)`` training batches.
        dataloader_val: Yields ``(prompts, texts, images, y)`` validation batches.
        preprocess_func: ``f(tokenizer, processor, prompts, texts, images, y, device)``
            returning ``(inputs, labels)``.
        num_epochs: Number of epochs; defaults to the notebook constant
            ``NUM_EPOCHS_FT`` when ``None``.
        val_every: Validate on epochs where ``epoch % val_every == 0``.

    Returns:
        ``(train_errors, val_errors)``: two lists of ``(epoch, mean_loss)`` tuples.
        Means are weighted by batch size; an empty loader yields ``nan``.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS_FT
    train_errors = []
    val_errors = []
    for epoch in tqdm(range(num_epochs)):
        # Re-enter training mode each epoch because validation switches to eval mode.
        model.train()
        error = 0.0
        num_samples = 0
        for prompts, texts, images, y in dataloader_train:
            inputs, labels = preprocess_func(tokenizer, processor, prompts, texts, images, y, device)
            optimizer.zero_grad()
            outputs = model(
                inputs=inputs,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the epoch mean.
            error += loss.item() * len(texts)
            num_samples += len(texts)
            # Drop references before emptying the cache so the memory can really be freed.
            del labels, inputs, outputs, loss
            gc.collect()
            torch.cuda.empty_cache()
        error = error / num_samples if num_samples else float("nan")
        print(f'Error after epoch {epoch}: {error}')
        train_errors.append((epoch, error))
        # `epoch % 10` alone is truthy on every epoch *except* multiples of 10;
        # compare against zero so validation runs once every `val_every` epochs.
        if epoch % val_every == 0:
            val_error = _mean_validation_loss(model, tokenizer, processor, dataloader_val, preprocess_func)
            print(f'Validation error after epoch {epoch}: {val_error}')
            val_errors.append((epoch, val_error))
    # Return the full validation history, not just the last scalar.
    return train_errors, val_errors

In [None]:
def visualize_error(train_errors, val_errors):
    """Plot training and validation error curves against the epoch number.

    Args:
        train_errors: Iterable of ``(epoch, error)`` pairs for the training set.
        val_errors: Iterable of ``(epoch, error)`` pairs for the validation set.

    An empty series is skipped; the legend is drawn only if something was plotted.
    """
    plotted_any = False
    for series, label in ((train_errors, "Train Error"), (val_errors, "Validation Error")):
        if not series:
            continue
        # Split [(epoch, error), ...] into separate x and y sequences for plt.plot.
        epochs, errors = zip(*series)
        plt.plot(epochs, errors, label=label, marker="o", linestyle="-")
        plotted_any = True
    plt.title("Train and Validation Error over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Error")
    if plotted_any:
        plt.legend()
    plt.show()

## PrefixTuning using labels

In [None]:
# DataLoader for train data
dataset_label_train = PrefixDataset(df_train_label)
dataloader_label_train=DataLoader(dataset_label_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
# DataLoader for val data
dataset_label_val = PrefixDataset(df_val)
dataloader_label_val=DataLoader(dataset_label_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)

### Qwen

In [None]:
# Checkpoint shared by the model, tokenizer and processor below.
model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Load the base vision-language model; torch_dtype="auto" lets from_pretrained pick the dtype.
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Sizes handed to the prefix-tuning wrapper, read from the model config.
match_n_layer = model.config.num_hidden_layers
match_n_head = model.config.num_key_value_heads
# NOTE(review): the divisor 6 is a magic number; presumably it reduces hidden_size to
# the combined key/value head width for this checkpoint — confirm against prefix_tuning.
n_embd = model.config.hidden_size // 6
model_prefix = PrefixTuningModelPastKeyValues(model, match_n_layer, match_n_head, n_embd, device).to(device)
# Only the parameters under model_prefix.prefix_tuning are given to the optimizer.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Prefix-tune Qwen on the labelled training loader, validating on the label validation loader.
train_errors_ft_qwen, val_errors_ft_qwen = train(model_prefix, tokenizer, processor, optimizer, dataloader_label_train, dataloader_label_val, preprocess_input_qwen)

In [None]:
# Plot the error values collected by the Qwen fine-tuning run.
visualize_error(train_errors_ft_qwen, val_errors_ft_qwen)

### Paligemma

In [None]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image

# Checkpoint shared by the model, processor and tokenizer below.
model_id = "google/paligemma2-3b-pt-224"

# Sample image fetched over the network; it is not used by the rest of this cell.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Load the tokenizer from the PaliGemma checkpoint. `model_name` still holds the Qwen
# checkpoint name from the earlier section, which would encode the labels with the
# wrong vocabulary for this model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# NOTE(review): the Qwen cells construct PrefixTuningModelPastKeyValues with
# (model, match_n_layer, match_n_head, n_embd, device); this call passes
# (model, tokenizer) — confirm which signature prefix_tuning currently expects.
model_prefix = PrefixTuningModelPastKeyValues(model, tokenizer).to(device)
# Only the parameters under model_prefix.prefix_tuning are given to the optimizer.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Prefix-tune PaliGemma on the labelled training loader, validating on the label validation loader.
train_errors_ft_paligemma, val_errors_ft_paligemma = train(model_prefix, tokenizer, processor, optimizer, dataloader_label_train, dataloader_label_val, preprocess_input_paligemma)

In [None]:
# Plot the error values collected by the PaliGemma fine-tuning run.
visualize_error(train_errors_ft_paligemma, val_errors_ft_paligemma)

In [None]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

# Checkpoint shared by the model, processor and tokenizer below.
model_id = "google/paligemma2-3b-pt-224"

# Sample image used for the generation smoke test at the end of this cell.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Load the tokenizer from the PaliGemma checkpoint. `model_name` still holds the Qwen
# checkpoint name, so using it would rebind the global `tokenizer` to the wrong vocabulary.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Leaving the prompt blank for pre-trained models
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
# Length of the encoded input, used below to slice the generated continuation off the output.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    # do_sample=False requests deterministic (non-sampled) decoding.
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Keep only the newly generated tokens of the first (and only) sequence.
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)


## Knowledge Distillation

In [None]:
# DataLoader for train data
dataset_gemini_train = PrefixDataset(df_train_gemini)
dataloader_gemini_train=DataLoader(dataset_gemini_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
# DataLoader for val data
dataset_label_val = PrefixDataset(df_val)
dataloader_label_val=DataLoader(dataset_label_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)

### Qwen

In [None]:
# Checkpoint shared by the model, tokenizer and processor below.
model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Load a fresh base model for the distillation run; torch_dtype="auto" lets from_pretrained pick the dtype.
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Sizes handed to the prefix-tuning wrapper, read from the model config.
match_n_layer = model.config.num_hidden_layers
match_n_head = model.config.num_key_value_heads
# NOTE(review): the divisor 6 is a magic number; presumably it reduces hidden_size to
# the combined key/value head width for this checkpoint — confirm against prefix_tuning.
n_embd = model.config.hidden_size // 6
model_prefix = PrefixTuningModelPastKeyValues(model, match_n_layer, match_n_head, n_embd, device).to(device)
# Only the parameters under model_prefix.prefix_tuning are given to the optimizer.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Distil into Qwen: train on the Gemini-generated data and validate on the labelled
# validation loader. train() takes (model, tokenizer, processor, optimizer,
# dataloader_train, dataloader_val, preprocess_func); the previous call passed only a
# dataset and a dataloader and could not run.
# NOTE: train() reads its epoch count from NUM_EPOCHS_FT; NUM_EPOCHS_KD is not consulted.
train_errors_kd, val_errors_kd = train(model_prefix, tokenizer, processor, optimizer, dataloader_gemini_train, dataloader_label_val, preprocess_input_qwen)

In [None]:
# Plot the error values collected by the Qwen distillation run.
visualize_error(train_errors_kd, val_errors_kd)

### Paligemma

In [None]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

# Checkpoint shared by the model, processor and tokenizer below.
model_id = "google/paligemma2-3b-pt-224"

# Sample image fetched over the network; it is not used by the rest of this cell.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Load the tokenizer from the PaliGemma checkpoint. `model_name` holds the Qwen
# checkpoint name set in the preceding Qwen section, which would encode the labels
# with the wrong vocabulary for this model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# NOTE(review): the Qwen cells construct PrefixTuningModelPastKeyValues with
# (model, match_n_layer, match_n_head, n_embd, device); this call passes
# (model, tokenizer) — confirm which signature prefix_tuning currently expects.
model_prefix = PrefixTuningModelPastKeyValues(model, tokenizer).to(device)
# Only the parameters under model_prefix.prefix_tuning are given to the optimizer.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Distil into PaliGemma: train on the Gemini-generated loader, validating on the label validation loader.
train_errors_kd_paligemma, val_errors_kd_paligemma = train(model_prefix, tokenizer, processor, optimizer, dataloader_gemini_train, dataloader_label_val, preprocess_input_paligemma)

In [None]:
# Plot the error values collected by the PaliGemma distillation run.
visualize_error(train_errors_kd_paligemma, val_errors_kd_paligemma)