## Imports & Setup

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from tensorflow.python.keras.backend import dtype
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import re
import ast
import gc
from datasets import load_dataset
from tqdm import tqdm
#import helpers
#import prefix_tuning
import importlib
from PIL import Image
from qwen_vl_utils import process_vision_info
#importlib.reload(helpers)
#importlib.reload(prefix_tuning)

In [2]:
import torch
# Release memory held by PyTorch's CUDA caching allocator before loading models.
torch.cuda.empty_cache()

In [3]:
from PIL import Image


def get_question_text(problem):
    """Return the question string stored under the 'question' key of a ScienceQA record."""
    return problem['question']


def get_choice_text(probelm, options):
    """Format the answer choices as "(<label>) <choice>" entries separated by single spaces.

    Args:
        probelm: record with a 'choices' sequence (parameter name kept as-is for
            keyword-argument compatibility).
        options: labels; ``options[i]`` labels the i-th choice, so it must be at least
            as long as the choices (an IndexError is raised otherwise).

    Returns:
        The formatted string, or "" when there are no choices.
    """
    formatted = (
        "({}) {}".format(options[position], choice)
        for position, choice in enumerate(probelm['choices'])
    )
    return " ".join(formatted)


def get_context_text(problem, use_caption):
    """Join the hint and (optionally) the image caption into one context string.

    The caption is only read when ``use_caption`` is true. Surrounding whitespace is
    stripped, and "N/A" is returned when nothing remains.
    """
    caption = problem['caption'] if use_caption else ""
    combined = " ".join([problem['hint'], caption]).strip()
    return combined if combined != "" else "N/A"


def build_prompt(question_data, use_lecture=False, use_solution=False):
    """Build the text prompt for one ScienceQA record, wrapped in a one-element list.

    The prompt contains the question, task, choices (labelled 0-4) and hint (caption
    not included). The lecture is appended when ``use_lecture`` is set. The solution is
    appended when ``use_solution`` is set and the record's solution is non-empty.
    """
    question = get_question_text(question_data)
    choices = get_choice_text(question_data, list(range(5)))
    hint = get_context_text(question_data, False)
    task = question_data['task']

    appendices = []
    if use_lecture:
        appendices.append(f'\n Lecture: {question_data["lecture"]}')
    if use_solution and question_data["solution"]:
        appendices.append(f'\n Solution: {question_data["solution"]}')

    return [f'Question: {question}\n Task: {task}\n Choices: {choices}\n Hint: {hint}' + ''.join(appendices)]

def build_message(row):
    """Wrap one example as a single user chat turn: the row's image, then its prompt text.

    ``row`` only needs the 'image' and 'input' keys.
    """
    image_part = {"type": "image", "image": row["image"]}
    text_part = {"type": "text", "text": row['input']}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]

In [4]:
import torch
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class PrefixTuning(nn.Module):
    """Trainable prefix prepended to a model's input embeddings.

    A reduced-width parameter matrix is expanded by a small MLP into ``prefix_length``
    vectors of size ``config.hidden_size``. These vectors are concatenated in front of
    the input embeddings.

    Args:
        config: model config; only ``config.hidden_size`` is read.
        prefix_length: number of prefix positions to prepend.
        dtype: dtype of the prefix parameter and MLP weights (bfloat16 by default).
    """

    def __init__(self, config, prefix_length=10, dtype=torch.bfloat16):
        super().__init__()
        self.prefix_length = prefix_length
        self.hidden_size = config.hidden_size
        # P'_theta: reduced-width prefix parameters. Created on the default device; the
        # owning module's `.to(device)` moves them together with the MLP. Hard-coding
        # 'cuda' here would break CPU-only runs.
        self.prefix_param = nn.Parameter(torch.randn(prefix_length, config.hidden_size // 2, dtype=dtype))
        # MLP_theta: maps each prefix row from hidden_size // 2 up to hidden_size.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size // 2, config.hidden_size, dtype=dtype),
            nn.Tanh(),
            nn.Linear(config.hidden_size, config.hidden_size, dtype=dtype),
        )

    def forward(self, inputs_embeds):
        """Prepend the prefix to ``inputs_embeds``.

        Args:
            inputs_embeds: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, prefix_length + seq_len, hidden_size). The same
            prefix is shared by every batch element.
        """
        batch_size = inputs_embeds.size(0)
        prefix = self.mlp(self.prefix_param)
        # Match the embeddings so the concatenation works whatever dtype the backbone uses.
        prefix = prefix.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        prefix = prefix.unsqueeze(0).expand(batch_size, -1, -1)
        # The prefix is injected only as extra input embeddings. The prefix-tuning paper
        # instead feeds per-layer prefixes as past_key_values.
        return torch.cat([prefix, inputs_embeds], dim=1)


class PrefixTuningModel(nn.Module):
    """Frozen pretrained model with a trainable prefix prepended to its input embeddings.

    Every parameter of the wrapped ``model`` has ``requires_grad`` set to False. Only the
    parameters of ``self.prefix_tuning`` stay trainable.

    Args:
        model: pretrained model exposing ``config``, ``get_input_embeddings()`` and a
            forward accepting ``inputs_embeds``, ``attention_mask``, ``pixel_values`` and ``labels``.
        tokenizer: stored on the wrapper for convenience; not used in ``forward``.
        prefix_length: number of prefix positions to prepend.
    """

    def __init__(self, model, tokenizer, prefix_length=10):
        super().__init__()
        self.model = model
        self.freeze_main_model()
        self.tokenizer = tokenizer
        self.prefix_tuning = PrefixTuning(self.model.config, prefix_length)

    def freeze_main_model(self):
        """Disable gradients for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, inputs, labels):
        """Run the wrapped model with the prefix prepended to the embedded ``input_ids``.

        Args:
            inputs: processor output; ``input_ids``, ``attention_mask`` and ``pixel_values`` are read.
            labels: target ids passed straight to the wrapped model. Presumably they must
                span prefix_length + input length positions to match the logits (TODO confirm).

        Returns:
            Whatever the wrapped model returns.
        """
        input_ids = inputs["input_ids"]
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # Add Prefix
        inputs_embeds = self.prefix_tuning(inputs_embeds)

        # Prefix positions are always attended to. Build their mask with the same dtype and
        # device as the incoming mask so the concatenated mask keeps the original dtype.
        base_mask = inputs["attention_mask"]
        prefix_mask = base_mask.new_ones((input_ids.size(0), self.prefix_tuning.prefix_length))
        attention_mask = torch.cat([prefix_mask, base_mask], dim=1)

        # NOTE(review): some HF vision-language models only merge image features into the
        # embeddings they compute from `input_ids` themselves. When `inputs_embeds` is passed,
        # `pixel_values` may therefore be ignored. Verify against the wrapped model's forward.
        return self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, pixel_values=inputs["pixel_values"], labels=labels)

    # Disabled draft of a generation helper, kept as a no-op string literal (never executed).
    """def generate(self, inputs, max_new_tokens):
        inputs_embeds = self.model.get_input_embeddings()(inputs["input_ids"])
        inputs_embeds = self.prefix_tuning(inputs_embeds)
        prefix_mask = torch.ones((inputs["input_ids"].size(0), self.prefix_tuning.prefix_length),
                                 device=inputs["input_ids"].device)
        attention_mask = torch.cat([prefix_mask, inputs["attention_mask"]], dim=1)
        return self.model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, pixel_values=inputs["pixel_values"], max_new_tokens=max_new_tokens)"""

In [5]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

# CONSTANTS
NUM_EPOCHS_FT = 100
NUM_EPOCHS_KD = 100
BATCH_SIZE = 4
SEED = 42

# Seed the RNGs so the random prefix initialisation (torch.randn) and the
# DataLoader shuffling are repeatable across runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

Using device: cpu



In [6]:
# Display the torch device selected in the setup cell.
device

device(type='cpu')

In [7]:
# Show NVIDIA GPU status via an IPython shell escape; the recorded output shows nvidia-smi is not installed on this machine.
!nvidia-smi

zsh:1: command not found: nvidia-smi


## Load and preprocess data

In [8]:
### train data
# Labelled ScienceQA training split:
# - keep only rows that have a reference solution;
# - substitute a black 224x224 RGB placeholder for questions without an image;
# - attach the text prompt and the chat-style message.
# reset_index() keeps the original row label in an 'index' column.
df_train_label = (
    pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='train'))
    .loc[lambda frame: frame['solution'] != '']
    .reset_index()
    .assign(
        image=lambda frame: frame['image'].map(
            lambda img: img if img else Image.new("RGB", (224, 224), (0, 0, 0))
        ),
        input=lambda frame: frame.apply(lambda row: build_prompt(row)[0], axis=1),
        message=lambda frame: frame.apply(build_message, axis=1),
    )
)

# Gemini 1.5 Flash outputs on the training split (teacher data for knowledge distillation),
# joined with the labelled frame's images on the 'index' column.
df_train_gemini = (
    pd.read_csv('gemini_1_5_flash_output_train.csv', sep="\t")[['index', 'input', 'answer', 'explanation']]
    .merge(df_train_label[['index', 'image']], on='index')
    .assign(message=lambda frame: frame.apply(build_message, axis=1))
)

In [9]:
### val data
# Labelled ScienceQA validation split:
# - keep only rows that have a reference solution;
# - substitute a black 224x224 RGB placeholder for questions without an image;
# - attach the text prompt and the chat-style message.
df_val = (
    pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='validation'))
    .loc[lambda frame: frame['solution'] != '']
    .reset_index()
    .assign(
        image=lambda frame: frame['image'].map(
            lambda img: img if img else Image.new("RGB", (224, 224), (0, 0, 0))
        ),
        input=lambda frame: frame.apply(lambda row: build_prompt(row)[0], axis=1),
        message=lambda frame: frame.apply(build_message, axis=1),
    )
)

## Functions for model training

In [10]:
def preprocess_input_qwen(tokenizer, processor, prompts, texts, images, y, device, prefix_length=10):
    """Turn one batch into Qwen2-VL processor inputs and label ids.

    Args:
        tokenizer: tokenizer used to encode the targets ``y``.
        processor: Qwen2-VL processor (chat template plus image/video preprocessing).
        prompts: unused; kept so every preprocess function shares one signature.
        texts: chat messages, one per sample, as built by ``build_message``.
        images: unused; images are taken from the messages via ``process_vision_info``.
        y: target strings, one per sample.
        device: device the returned tensors are moved to.
        prefix_length: number of positions the prefix-tuning wrapper prepends. Labels are
            padded/truncated to input length + prefix_length so their length equals the
            model's sequence length after the prefix is added.

    Returns:
        ``(inputs, labels)``. Label padding positions are set to -100.
    """
    chat_texts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        for conversation in texts
    ]
    image_inputs, video_inputs = process_vision_info(texts)
    inputs = processor(
        text=chat_texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    max_length = inputs["input_ids"].size(1) + prefix_length
    encoded = tokenizer(list(y), padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    labels = encoded["input_ids"]
    # Padding must not contribute to the loss; -100 is the default ignore_index of
    # torch.nn.CrossEntropyLoss.
    label_mask = encoded.get("attention_mask")
    if label_mask is not None:
        labels = labels.masked_fill(label_mask == 0, -100)
    # NOTE(review): the targets are tokenised independently of the prompt, so they are not
    # positioned after it in the sequence. Confirm this is the intended training objective.
    return inputs.to(device), labels.to(device)

In [11]:
def preprocess_input_paligemma(tokenizer, processor, prompts, texts, images, y, device, prefix_length=10):
    """Turn one batch into PaliGemma processor inputs and label ids.

    Args:
        tokenizer: tokenizer used to encode the targets ``y``.
        processor: PaliGemma processor.
        prompts: text prompts, one per sample.
        texts: unused; kept so every preprocess function shares one signature.
        images: images (objects with ``.convert``, e.g. PIL images), one per sample.
        y: target strings, one per sample.
        device: device the returned tensors are moved to.
        prefix_length: number of positions the prefix-tuning wrapper prepends. Labels are
            padded/truncated to input length + prefix_length so their length equals the
            model's sequence length after the prefix is added.

    Returns:
        ``(inputs, labels)``. Label padding positions are set to -100.
    """
    # The processor's image processor does its own resizing and pixel normalisation, so
    # images are passed through un-normalised. Pre-scaling to [-1, 1] here would normalise
    # them twice. Converting to RGB guards against palette/RGBA images.
    rgb_images = [image.convert("RGB") for image in images]
    inputs = processor(
        text=list(prompts),
        images=rgb_images,
        return_tensors="pt",
        padding="longest",
    )
    max_length = inputs["input_ids"].size(1) + prefix_length
    encoded = tokenizer(list(y), padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    labels = encoded["input_ids"]
    # Padding must not contribute to the loss; -100 is the default ignore_index of
    # torch.nn.CrossEntropyLoss.
    label_mask = encoded.get("attention_mask")
    if label_mask is not None:
        labels = labels.masked_fill(label_mask == 0, -100)
    return inputs.to(device), labels.to(device)

In [12]:
def _run_epoch(model, tokenizer, processor, dataloader, preprocess_func, optimizer=None):
    """Make one pass over ``dataloader`` and return the per-sample mean loss.

    With ``optimizer``, every batch takes a gradient step. Without it, the pass only
    measures the loss; the caller is responsible for eval mode and ``torch.no_grad()``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    num_samples = 0
    for prompts, texts, images, y in dataloader:
        inputs, labels = preprocess_func(tokenizer, processor, prompts, texts, images, y, device)
        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(inputs, labels=labels)
        loss = outputs.loss
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        # Weight by batch size so a short final batch does not skew the mean.
        total_loss += loss.item() * len(texts)
        num_samples += len(texts)
        # Drop batch tensors before the next batch to keep peak memory down.
        del inputs, labels, outputs, loss
        gc.collect()
        torch.cuda.empty_cache()
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / num_samples


def train(model, tokenizer, processor, optimizer, dataloader_train, dataloader_val, preprocess_func,
          num_epochs=None, val_every=10):
    """Train ``model`` and periodically measure the validation loss.

    Args:
        model: callable as ``model(inputs, labels=labels)`` and returning an object with ``.loss``.
        tokenizer, processor: forwarded to ``preprocess_func``.
        optimizer: optimizer over the trainable parameters.
        dataloader_train, dataloader_val: yield ``(prompts, texts, images, y)`` batches.
        preprocess_func: ``(tokenizer, processor, prompts, texts, images, y, device) -> (inputs, labels)``.
        num_epochs: number of epochs; defaults to ``NUM_EPOCHS_FT`` (read at call time).
        val_every: validate on epoch 0, every ``val_every`` epochs, and after the final epoch.

    Returns:
        ``(train_errors, val_errors)``: lists of ``(epoch, mean_loss)`` pairs.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS_FT
    train_errors = []
    val_errors = []
    for epoch in tqdm(range(num_epochs)):
        model.train()
        error = _run_epoch(model, tokenizer, processor, dataloader_train, preprocess_func, optimizer)
        print(f'Error after epoch {epoch}: {error}')
        train_errors.append((epoch, error))
        if epoch % val_every == 0 or epoch == num_epochs - 1:
            # Evaluation: no dropout, no autograd graph.
            model.eval()
            with torch.no_grad():
                val_error = _run_epoch(model, tokenizer, processor, dataloader_val, preprocess_func)
            print(f'Validation error after epoch {epoch}: {val_error}')
            val_errors.append((epoch, val_error))
    # Leave the model in training mode.
    model.train()
    return train_errors, val_errors

In [13]:
def visualize_error(train_errors, val_errors):
    """Plot training and validation loss curves against the epoch number.

    Args:
        train_errors: sequence of ``(epoch, error)`` pairs.
        val_errors: sequence of ``(epoch, error)`` pairs; may be empty if validation never ran.
    """
    fig, ax = plt.subplots()
    for series, label in ((train_errors, "Train Error"), (val_errors, "Validation Error")):
        if not series:
            continue
        # Split the (epoch, error) pairs into x and y sequences; plot() cannot consume a zip object.
        epochs, errors = zip(*series)
        ax.plot(epochs, errors, label=label, marker="o", linestyle="-")
    ax.set_title("Train and Validation Error over Epochs")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Error")
    if ax.lines:
        ax.legend()
    plt.show()

In [14]:
class PrefixDataset(Dataset):
    """Map-style dataset over a DataFrame of prepared examples.

    Item ``idx`` is the tuple ``(input, message, image, solution)``, read from the row at
    position ``idx`` of ``df`` (positional lookup via ``iloc``).
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        wanted = ('input', 'message', 'image', 'solution')
        return tuple(record[column] for column in wanted)

In [15]:
def prefix_collate(batch):
    """Transpose a list of ``(input, message, image, target)`` samples into four parallel tuples.

    The inputs are kept as plain Python objects (no tensor stacking), since
    preprocessing happens later in the training loop.
    """
    transposed = zip(*batch)
    inputs, messages, images, targets = transposed
    return inputs, messages, images, targets

## Prefix tuning on ground-truth labels

In [None]:
# Labelled ScienceQA batches: shuffled for training, fixed order for validation.
# The validation loss is a sample-weighted average, so shuffling it buys nothing and
# would only consume RNG draws.
dataset_label_train = PrefixDataset(df_train_label)
dataloader_label_train = DataLoader(dataset_label_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
dataset_label_val = PrefixDataset(df_val)
dataloader_label_val = DataLoader(dataset_label_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=False)

### Qwen

In [22]:
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
import gc

model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Qwen2-VL backbone plus its text and vision preprocessing.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name, torch_dtype="auto")

# The wrapper freezes the backbone; only the prefix parameters go to the optimizer.
model_prefix = PrefixTuningModel(model, tokenizer).to(device)
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [24]:
# Prefix-tune Qwen2-VL on the labelled ScienceQA training split, validating on the labelled validation split.
train_errors_ft_qwen, val_errors_ft_qwen = train(model_prefix, tokenizer, processor, optimizer, dataloader_label_train, dataloader_label_val, preprocess_input_qwen)

  0%|          | 0/100 [02:40<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Plot the training/validation loss curves of the Qwen label fine-tuning run.
visualize_error(train_errors_ft_qwen, val_errors_ft_qwen)

### PaliGemma

In [None]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-pt-224"  # gated checkpoint: requires granted access on the Hugging Face Hub

# Loaded without device_map so that the wrapper's `.to(device)` alone decides placement,
# the same way the Qwen cell does.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Use PaliGemma's own tokenizer (bundled in the processor); `model_name` refers to the Qwen checkpoint.
tokenizer = processor.tokenizer

# The wrapper freezes the backbone; only the prefix parameters go to the optimizer.
model_prefix = PrefixTuningModel(model, tokenizer).to(device)
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Prefix-tune PaliGemma 2 on the labelled ScienceQA training split, validating on the labelled validation split.
train_errors_ft_paligemma, val_errors_ft_paligemma = train(model_prefix, tokenizer, processor, optimizer, dataloader_label_train, dataloader_label_val, preprocess_input_paligemma)

In [None]:
# Plot the training/validation loss curves of the PaliGemma label fine-tuning run.
visualize_error(train_errors_ft_paligemma, val_errors_ft_paligemma)

In [16]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-pt-224"  # gated checkpoint: requires granted access on the Hugging Face Hub

# Smoke test: generate text for a sample image with the pretrained PaliGemma 2 checkpoint.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# PaliGemma's own tokenizer (bundled in the processor), not the Qwen one named by `model_name`.
tokenizer = processor.tokenizer

# Leaving the prompt blank for pre-trained models
prompt = ""
# Cast to the dtype the model was actually loaded in rather than assuming bfloat16.
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.dtype).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Keep only the newly generated tokens (drop the echoed prompt).
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)


OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/paligemma2-3b-pt-224.
403 Client Error. (Request ID: Root=1-6776d825-0dcebebc236e29cf2bda920b;c83f6a1c-e043-4ccc-94f2-e377569cd004)

Cannot access gated repo for url https://huggingface.co/google/paligemma2-3b-pt-224/resolve/main/config.json.
Your request to access model google/paligemma2-3b-pt-224 is awaiting a review from the repo authors.

## Knowledge Distillation

In [None]:
# Gemini-distillation batches for training; validation still uses the labelled split.
# PrefixDataset reads its target from a 'solution' column, which df_train_gemini lacks
# (its columns are index/input/answer/explanation/image/message). Gemini's explanation is
# therefore exposed under that name.
# NOTE(review): assumes 'explanation' (not 'answer') is the intended distillation target.
dataset_gemini_train = PrefixDataset(df_train_gemini.assign(solution=df_train_gemini['explanation']))
dataloader_gemini_train = DataLoader(dataset_gemini_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not affect the averaged loss, so it is not shuffled.
dataset_label_val = PrefixDataset(df_val)
dataloader_label_val = DataLoader(dataset_label_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=False)

### Qwen

In [None]:
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
import gc

model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Fresh backbone wrapped in a new PrefixTuningModel, so the distillation run starts
# from a newly initialised prefix and a new optimizer state.
model = AutoModelForImageTextToText.from_pretrained(model_name, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model_prefix = PrefixTuningModel(model, tokenizer).to(device)
# Only the prefix parameters are optimised; the wrapper freezes the backbone.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Prefix-tune Qwen2-VL on the Gemini-distillation batches, validating on the labelled validation split.
# train() takes the model, tokenizer, processor, optimizer, both loaders and the preprocess function.
train_errors_kd, val_errors_kd = train(model_prefix, tokenizer, processor, optimizer, dataloader_gemini_train, dataloader_label_val, preprocess_input_qwen)

In [None]:
# Plot the training/validation loss curves of the Qwen knowledge-distillation run.
visualize_error(train_errors_kd, val_errors_kd)

### PaliGemma

In [None]:
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-pt-224"  # gated checkpoint: requires granted access on the Hugging Face Hub

# Loaded without device_map so that the wrapper's `.to(device)` alone decides placement,
# the same way the Qwen cells do.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Use PaliGemma's own tokenizer (bundled in the processor); `model_name` refers to the Qwen checkpoint.
tokenizer = processor.tokenizer

# Fresh wrapper: new prefix initialisation and optimizer state for the distillation run.
model_prefix = PrefixTuningModel(model, tokenizer).to(device)
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)

In [None]:
# Prefix-tune PaliGemma 2 on the Gemini-distillation batches, validating on the labelled validation split.
train_errors_kd_paligemma, val_errors_kd_paligemma = train(model_prefix, tokenizer, processor, optimizer, dataloader_gemini_train, dataloader_label_val, preprocess_input_paligemma)

In [None]:
# Plot the training/validation loss curves of the PaliGemma knowledge-distillation run.
visualize_error(train_errors_kd_paligemma, val_errors_kd_paligemma)