## Imports & Setup

In [11]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from tensorflow.python.keras.backend import dtype
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import re
import ast
from datasets import load_dataset
from tqdm import tqdm
#import helpers
#import prefix_tuning
import importlib
from PIL import Image
from qwen_vl_utils import process_vision_info
#importlib.reload(helpers)
#importlib.reload(prefix_tuning)

<module 'prefix_tuning' from '/Users/floriandreyer/Library/Mobile Documents/com~apple~CloudDocs/Python Projekte/foundation_models/prefix_tuning.py'>

### Helper functions (from `helpers.py`, defined inline below)

In [None]:
from PIL import Image


def get_question_text(problem):
    """Return the raw question string of a ScienceQA sample.

    Args:
        problem: mapping (dict or DataFrame row) with a ``question`` entry.
    """
    return problem['question']


def get_choice_text(probelm, options):
    """Format the answer choices as ``"(label) choice"`` entries joined by single spaces.

    Args:
        probelm: mapping with a ``choices`` sequence (parameter name kept as-is for callers).
        options: one label per choice, indexed positionally; must be at least as long
            as ``choices``, otherwise an IndexError is raised.
    Returns:
        e.g. ``"(0) cat (1) dog"``; an empty string when there are no choices.
    """
    labelled = [
        "({}) {}".format(options[position], choice)
        for position, choice in enumerate(probelm['choices'])
    ]
    return " ".join(labelled)


def get_context_text(problem, use_caption):
    """Combine the textual hint and (optionally) the image caption into one context string.

    Args:
        problem: mapping with a ``hint`` entry and, when ``use_caption`` is true, a ``caption`` entry.
        use_caption: whether to append ``problem['caption']``; the caption is not read otherwise.
    Returns:
        The space-joined, stripped context, or ``"N/A"`` when it is empty.
    """
    hint_text = problem['hint']
    caption_text = problem['caption'] if use_caption else ""
    combined = " ".join((hint_text, caption_text)).strip()
    return combined or "N/A"


def build_prompt(question_data, use_lecture=False, use_solution=False):
    """Build the text prompt for one ScienceQA sample.

    Args:
        question_data: mapping (dict or DataFrame row) with ``question``, ``choices``,
            ``hint`` and ``task`` entries, plus ``lecture`` / ``solution`` when the
            corresponding flags are set.
        use_lecture: append the lecture text.
        use_solution: append the solution text (only when it is non-empty).
    Returns:
        A one-element list holding the prompt string.
    """
    question = get_question_text(question_data)
    # One numeric label per choice. The previous hard-coded range(5) raised IndexError
    # inside get_choice_text for questions with more than five choices.
    option_labels = list(range(len(question_data['choices'])))
    choices = get_choice_text(question_data, option_labels)
    hint = get_context_text(question_data, False)
    task = question_data['task']
    input_prompt = f'Question: {question}\n Task: {task}\n Choices: {choices}\n Hint: {hint}'
    if use_lecture:
        input_prompt += f'\n Lecture: {question_data["lecture"]}'
    if use_solution and question_data["solution"]:
        input_prompt += f'\n Solution: {question_data["solution"]}'
    # The image is passed separately via build_message, not embedded in the prompt.
    return [input_prompt]

def build_message(row):
    """Wrap one sample as a single-turn chat whose user turn is the image followed by the prompt.

    Args:
        row: mapping with an ``image`` entry (a PIL image in this notebook) and an
            ``input`` entry holding the prompt string.
    Returns:
        A list with one ``{"role": "user", "content": [image part, text part]}`` message.
    """
    image_part = {"type": "image", "image": row["image"]}
    text_part = {"type": "text", "text": row['input']}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]

In [None]:
import torch
from torch import nn

class PrefixTuning(nn.Module):
    """Learnable sequence of ``prefix_length`` embedding vectors prepended to every input.

    Args:
        config: model config; only ``config.hidden_size`` is read.
        prefix_length: number of virtual prefix tokens.
    """

    def __init__(self, config, prefix_length=10):
        super().__init__()
        self.prefix_length = prefix_length
        self.hidden_size = config.hidden_size
        # One trainable vector per virtual token, standard-normal initialised.
        self.prefix_embeddings = nn.Parameter(torch.randn(prefix_length, config.hidden_size))

    def forward(self, inputs_embeds):
        """Prepend the prefix along dim 1.

        Args:
            inputs_embeds: token embeddings, assumed (batch, seq, hidden_size).
        Returns:
            Tensor of shape (batch, prefix_length + seq, hidden_size) with the dtype and
            device of ``inputs_embeds``.
        """
        batch_size = inputs_embeds.size(0)
        # Match the embeddings' dtype/device: otherwise torch.cat promotes a half-precision
        # sequence to float32 (or fails across devices). .to() keeps gradients flowing
        # back to the float32 parameter.
        prefix = self.prefix_embeddings.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
        prefix = prefix.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat([prefix, inputs_embeds], dim=1)


class PrefixTuningModel(nn.Module):
    """Wrap a frozen model and train only a learnable input prefix.

    Args:
        model: base model exposing ``config.hidden_size``, ``get_input_embeddings()`` and a
            forward accepting ``inputs_embeds`` / ``attention_mask`` / ``labels`` (here Qwen2-VL).
        tokenizer: stored for callers; not used by ``forward``.
        prefix_length: number of virtual prefix tokens.
    """

    def __init__(self, model, tokenizer, prefix_length=10):
        super().__init__()
        self.model = model
        self.freeze_main_model()
        self.tokenizer = tokenizer
        self.prefix_tuning = PrefixTuning(self.model.config, prefix_length)

    def freeze_main_model(self):
        """Disable gradients for every base-model parameter so only the prefix is trained."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, inputs, labels=None):
        """Run the base model on ``[prefix; token embeddings]``.

        Args:
            inputs: processor output (dict-like) with ``input_ids`` and ``attention_mask``,
                and optionally ``pixel_values`` / ``image_grid_thw``.
            labels: optional LongTensor. If its length equals the unprefixed sequence length,
                ``prefix_length`` ignore positions (-100) are prepended; if it already covers the
                prefix (length seq + prefix_length) it is passed through unchanged.
        Returns:
            The base model's output object.
        """
        input_ids = inputs["input_ids"]
        prefix_length = self.prefix_tuning.prefix_length

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        inputs_embeds = self.prefix_tuning(inputs_embeds)

        # Prefix positions are always attended to; keep the mask's own dtype/device.
        attention_mask = inputs["attention_mask"]
        prefix_mask = torch.ones(
            (input_ids.size(0), prefix_length),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        if labels is not None and labels.size(1) == input_ids.size(1):
            # Labels were built for the raw sequence: never predict the virtual prefix.
            prefix_ignore = torch.full(
                (labels.size(0), prefix_length), -100, dtype=labels.dtype, device=labels.device
            )
            labels = torch.cat([prefix_ignore, labels], dim=1)

        # Qwen2-VL needs image_grid_thw alongside pixel_values; forward only what is present.
        # NOTE(review): since input_ids are not passed, some transformers versions skip merging
        # the image features into inputs_embeds and fall back to 1-D positions — verify against
        # the installed Qwen2-VL forward.
        vision_inputs = {
            key: inputs[key] for key in ("pixel_values", "image_grid_thw") if key in inputs
        }
        return self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **vision_inputs,
        )

In [2]:
# Swap the two lines below to use Apple's MPS backend when it is available.
#device = torch.device('mps' if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else 'cpu')
device = torch.device('cpu')

# Seed torch so the random prefix initialisation (torch.randn) and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# CONSTANTS
NUM_EPOCHS_FT = 100  # epochs of prefix tuning on the ScienceQA labels
NUM_EPOCHS_KD = 100  # epochs of knowledge distillation on the Gemini outputs
BATCH_SIZE = 32

In [3]:
device  # display the torch device selected in the config cell

device(type='cpu')

## Load and preprocess the data

In [4]:
### train data
# ScienceQA train split with label and image data; keep only rows that have a reference
# solution, since the solution is the training target.
df_train_label = pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='train'))

# reset_index() keeps each row's original position in an 'index' column, used below as the
# join key for the Gemini outputs.
df_train_label = df_train_label[df_train_label['solution'] != ''].reset_index()
# Questions without an image get a black 224x224 placeholder so every sample carries one image.
df_train_label['image'] = df_train_label.apply(lambda row: row['image'] if row['image'] else Image.new("RGB", (224, 224), (0, 0, 0)), axis=1)
# build_prompt / build_message are defined in this notebook (`import helpers` is commented
# out), so call them directly; `helpers.` raised NameError on a fresh kernel.
df_train_label['input'] = df_train_label.apply(lambda row: build_prompt(row)[0], axis=1)
df_train_label['message'] = df_train_label.apply(build_message, axis=1)

# data from Gemini for KD
# NOTE(review): assumes the CSV's 'index' column refers to the same original ScienceQA row
# positions as df_train_label['index'] — TODO confirm.
df_train_gemini = pd.read_csv('gemini_1_5_flash_output_train.csv', sep="\t")[['index', 'input', 'answer', 'explanation']]
df_train_gemini = pd.merge(df_train_gemini, df_train_label[['index', 'image']], on='index')
df_train_gemini['message'] = df_train_gemini.apply(build_message, axis=1)

In [5]:
### val data
# Filtered like the train split: rows without a solution would give empty validation targets.
df_val = pd.DataFrame(load_dataset('derek-thomas/ScienceQA', split='validation'))
df_val = df_val[df_val['solution'] != ''].reset_index()
# Same black 224x224 placeholder as the train split for questions without an image.
df_val['image'] = df_val.apply(lambda row: row['image'] if row['image'] else Image.new("RGB", (224, 224), (0, 0, 0)), axis=1)
# build_prompt / build_message are defined in this notebook (`import helpers` is commented out).
df_val['input'] = df_val.apply(lambda row: build_prompt(row)[0], axis=1)
df_val['message'] = df_val.apply(build_message, axis=1)

## Functions for model training

In [15]:
def _build_training_batch(processor, texts, targets, prefix_length, device):
    """Tokenize (conversation, target) pairs and build masked causal-LM labels.

    Each conversation (a list of chat messages, see ``build_message``) is extended with an
    assistant turn holding its target text. Labels copy the token ids of that assistant turn
    only; prompt tokens, padding and the ``prefix_length`` prefix positions are -100, the
    ignore index of the Hugging Face causal-LM loss.

    Returns:
        (inputs, labels): the processor output moved to ``device`` and a LongTensor of
        shape (batch, prefix_length + seq_len).
    """
    conversations = list(texts)
    full_conversations = [
        list(conversation) + [{"role": "assistant", "content": [{"type": "text", "text": str(target)}]}]
        for conversation, target in zip(conversations, targets)
    ]
    full_texts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        for conversation in full_conversations
    ]
    # The prompt-only rendering ends with the assistant header, so its token count marks where
    # the answer starts. Assumes the chat template renders it as an exact prefix of the full
    # conversation (true for Qwen2-VL's template) — verify if the template changes.
    prompt_texts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in conversations
    ]
    image_inputs, video_inputs = process_vision_info(conversations)
    inputs = processor(
        text=full_texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    prompt_inputs = processor(
        text=prompt_texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    prompt_lengths = prompt_inputs["attention_mask"].sum(dim=1).tolist()

    input_ids = inputs["input_ids"]
    labels = torch.full_like(input_ids, -100)
    for row, prompt_length in enumerate(prompt_lengths):
        # Only real (mask == 1) positions count, so this works with left or right padding.
        real_positions = inputs["attention_mask"][row].nonzero(as_tuple=True)[0]
        answer_positions = real_positions[int(prompt_length):]
        labels[row, answer_positions] = input_ids[row, answer_positions]

    # The model prepends the prefix to the embeddings, so labels need matching leading
    # positions that are never predicted.
    prefix_labels = torch.full((labels.size(0), prefix_length), -100, dtype=labels.dtype)
    labels = torch.cat([prefix_labels, labels], dim=1)
    return inputs.to(device), labels.to(device)


def _evaluate(model, processor, dataloader, prefix_length, device):
    """Return the sample-weighted mean loss over ``dataloader`` without tracking gradients."""
    was_training = model.training
    model.eval()
    total_loss, num_samples = 0.0, 0
    try:
        with torch.no_grad():
            for texts, _images, y in dataloader:
                inputs, labels = _build_training_batch(processor, texts, y, prefix_length, device)
                total_loss += model(inputs, labels=labels).loss.item() * len(texts)
                num_samples += len(texts)
    finally:
        model.train(was_training)
    return total_loss / num_samples if num_samples else float("nan")


def train(model, tokenizer, processor, optimizer, dataloader_train, dataloader_val, num_epochs=None, eval_every=10):
    """Prefix-tune ``model`` and periodically report the validation loss.

    Args:
        model: wrapper called as ``model(inputs, labels=labels)`` returning an object with a
            ``.loss`` (e.g. ``PrefixTuningModel``).
        tokenizer: kept for interface compatibility; labels are built from the processor output.
        processor: multimodal processor providing the chat template and tokenization.
        optimizer: optimizer over the trainable (prefix) parameters.
        dataloader_train, dataloader_val: yield ``(messages, images, targets)`` batches as
            produced by ``prefix_collate``.
        num_epochs: number of epochs; defaults to ``NUM_EPOCHS_FT``.
        eval_every: validate every ``eval_every`` epochs (starting at epoch 0) and after the
            last epoch.
    Returns:
        (train_errors, val_errors): lists of ``(epoch, mean loss)`` tuples.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS_FT
    # Read the prefix length from the wrapper so labels line up with the prepended prefix;
    # 10 is the PrefixTuningModel default that used to be hard-coded here.
    prefix_length = getattr(getattr(model, "prefix_tuning", None), "prefix_length", 10)
    model_device = next(model.parameters()).device
    train_errors = []
    val_errors = []
    model.train()
    for epoch in tqdm(range(num_epochs)):
        error = 0.0
        num_samples = 0
        for texts, _images, y in dataloader_train:
            inputs, labels = _build_training_batch(processor, texts, y, prefix_length, model_device)
            optimizer.zero_grad()
            outputs = model(inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            error += loss.item() * len(texts)
            num_samples += len(texts)
        error = error / num_samples if num_samples else float("nan")
        print(f'Error after epoch {epoch}: {error}')
        train_errors.append((epoch, error))
        # The former `if epoch % 10:` validated on every epoch EXCEPT multiples of 10.
        if epoch % eval_every == 0 or epoch == num_epochs - 1:
            val_error = _evaluate(model, processor, dataloader_val, prefix_length, model_device)
            print(f'Validation error after epoch {epoch}: {val_error}')
            val_errors.append((epoch, val_error))
    return train_errors, val_errors

In [7]:
def visualize_error(train_errors, val_errors):
    """Plot the train and validation loss curves against the epoch.

    Args:
        train_errors: sequence of ``(epoch, loss)`` pairs.
        val_errors: sequence of ``(epoch, loss)`` pairs; either series may be empty, in
            which case it is skipped.
    """
    fig, ax = plt.subplots()
    for series, label in ((train_errors, "Train Error"), (val_errors, "Validation Error")):
        if not series:
            continue  # nothing recorded for this curve
        # Unzip the (epoch, loss) pairs into x and y; passing the zip object itself to
        # plot() does not produce two series.
        epochs, errors = zip(*series)
        ax.plot(epochs, errors, label=label, marker="o", linestyle="-")
    ax.set_title("Train and Validation Error over Epochs")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Error")
    if ax.lines:
        ax.legend()  # without a legend the curve labels are never shown
    plt.show()

In [8]:
class PrefixDataset(Dataset):
    """Map-style dataset over a DataFrame yielding ``(message, image, target)`` triples.

    Args:
        df: DataFrame with ``message`` and ``image`` columns plus the target column.
        target_column: column returned as the training target. Defaults to ``"solution"``
            (ScienceQA); e.g. the Gemini KD frame has no ``solution`` column and needs
            ``"explanation"`` or ``"answer"`` instead.
    """

    def __init__(self, df, target_column="solution"):
        self.df = df
        self.target_column = target_column

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup, so any DataFrame index (e.g. after filtering) works.
        row = self.df.iloc[idx]
        return row['message'], row['image'], row[self.target_column]

In [9]:
def prefix_collate(batch):
    """Transpose a list of ``(message, image, target)`` samples into three parallel tuples.

    The samples are regrouped as plain Python objects, so chat-message lists and PIL
    images pass through untouched instead of being stacked into tensors.
    """
    transposed = zip(*batch)
    messages, images, targets = transposed
    return (messages, images, targets)

## Prefix tuning with the ScienceQA labels

In [13]:
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration

model_name = "Qwen/Qwen2-VL-2B-Instruct"

model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype="float32",
    device_map={"": str(device)},  # follow the device chosen in the config cell
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# PrefixTuningModel is defined earlier in this notebook; `import prefix_tuning` is commented
# out, so `prefix_tuning.PrefixTuningModel` raised NameError on a fresh kernel.
model_prefix = PrefixTuningModel(model, tokenizer)
# Only the prefix parameters are optimised; the base model is frozen by PrefixTuningModel.
optimizer = torch.optim.AdamW(model_prefix.prefix_tuning.parameters(), lr=5e-5)
# DataLoader for train data
dataset_label_train = PrefixDataset(df_train_label)
dataloader_label_train = DataLoader(dataset_label_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
# DataLoader for val data: evaluation does not need shuffling
dataset_label_val = PrefixDataset(df_val)
dataloader_label_val = DataLoader(dataset_label_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=False)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
# Prefix tuning on the ScienceQA solutions with the model, optimizer and loaders built above.
train_errors_ft, val_errors_ft = train(model_prefix, tokenizer, processor, optimizer, dataloader_label_train, dataloader_label_val)

  0%|          | 0/100 [00:00<?, ?it/s]

Input_ids dims: torch.Size([3, 125])
Input_ids dims: torch.Size([3, 462])


  0%|          | 0/100 [01:26<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Plot the loss curves of the label-based prefix tuning run.
visualize_error(train_errors_ft, val_errors_ft)

## Knowledge Distillation

In [None]:
# Knowledge distillation: train a fresh prefix on the Gemini 1.5 Flash outputs instead of the
# ScienceQA labels. The previous cell referenced `soft_prompting`, `SoftPromptingDataset` and
# `model_fine_tuned`, none of which exist in this notebook; reuse the prefix-tuning pieces.
model_knowledge_distillation = PrefixTuningModel(model, tokenizer)
# PrefixDataset reads the target from a 'solution' column, which the Gemini frame lacks, so
# expose Gemini's explanation under that name.
# TODO confirm whether 'answer', 'explanation' or both should be the distillation target.
df_train_gemini_targets = df_train_gemini.assign(solution=df_train_gemini['explanation'])
# DataLoader for train data (custom collate keeps PIL images and message lists as-is)
dataset_gemini_train = PrefixDataset(df_train_gemini_targets)
dataloader_gemini_train = DataLoader(dataset_gemini_train, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=True)
# DataLoader for val data, scored against the ScienceQA reference solutions
dataset_gemini_val = PrefixDataset(df_val)
dataloader_gemini_val = DataLoader(dataset_gemini_val, collate_fn=prefix_collate, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# A separate optimizer over the KD prefix only; the shared base model stays frozen.
# NOTE(review): train() runs NUM_EPOCHS_FT epochs by default; NUM_EPOCHS_KD is not used yet.
optimizer_kd = torch.optim.AdamW(model_knowledge_distillation.prefix_tuning.parameters(), lr=5e-5)
train_errors_kd, val_errors_kd = train(model_knowledge_distillation, tokenizer, processor, optimizer_kd, dataloader_gemini_train, dataloader_gemini_val)

In [None]:
# Plot the loss curves of the knowledge-distillation run.
visualize_error(train_errors_kd, val_errors_kd)

In [None]:
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the base Qwen2-VL-2B-Instruct model on the CPU for the inference demo below.
# NOTE(review): this loads a second copy of the weights and rebinds the global `model`.
# model_prefix keeps its reference to the earlier instance, but later cells that use
# `model` get this separately loaded base model, which has no learned prefix.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="float32", device_map={"": "cpu"}
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-2B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

In [None]:
def tokenize_dataset(df, tokenizer, input_column="input"):
    """Encode every entry of ``input_column`` and store the ids in a new ``input_ids`` column.

    Args:
        df: DataFrame to annotate; it is modified in place.
        tokenizer: object whose ``encode(text, return_tensors="pt")`` returns a tensor with a
            leading batch dimension of size 1 (squeezed away here).
        input_column: name of the text column to encode.
    Returns:
        The same DataFrame object, now with an ``input_ids`` column.
    """
    df["input_ids"] = [
        tokenizer.encode(text, return_tensors="pt").squeeze(0)
        for text in df[input_column]
    ]
    return df

In [None]:
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

# Demo: ask the base model one training question (sample 10) with its image.
sample_row = df_train_label.iloc[10]
question_text = sample_row['question'] + " " + ' '.join(sample_row['choices'])
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": sample_row['image']},
            {"type": "text", "text": question_text},
        ],
    }
]

# Preparation for inference: render the chat template and collect the vision inputs.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(device)

# Inference: generate, then strip the echoed prompt tokens from each sequence before decoding.
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    generated[len(prompt_ids):]
    for prompt_ids, generated in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

In [None]:
"""model_name = "Qwen2-VL-2B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Datensatz tokenisieren
tokenized_data = tokenize_dataset(df_train_gemini, tokenizer, input_column="input", label_column="answer")

# Dataset erstellen
dataset = SoftPromptingDataset(tokenized_data)

# Zugriff auf ein Beispiel
print(dataset[0])"""