In [None]:
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
import torch
from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor
import gc
import time
import pandas as pd

In [None]:
# Hub identifier of the ScienceQA dataset. The three splits are requested in a
# single call and unpacked in the same order they were asked for.
dataset_id = "derek-thomas/ScienceQA"
scienceqa_splits = load_dataset(dataset_id, split=["train", "validation", "test"])
train_dataset, eval_dataset, test_dataset = scienceqa_splits

In [None]:
from PIL import Image


def get_question_text(problem):
    """Return the question string stored under ``problem['question']``."""
    return problem['question']


def get_choice_text(probelm, options):
    """Render a problem's answer choices as one labelled line.

    Each entry of ``probelm['choices']`` is paired with the entry of
    ``options`` at the same position and formatted as ``"(label) choice"``;
    the pieces are joined with single spaces.  ``options`` must be at least
    as long as the choice list, otherwise indexing raises ``IndexError``.
    """
    labelled_choices = (
        "({}) {}".format(options[position], choice)
        for position, choice in enumerate(probelm['choices'])
    )
    return " ".join(labelled_choices)


def get_context_text(problem, use_caption):
    """Build the context string for a problem.

    The text hint (``problem['hint']``) and, when ``use_caption`` is truthy,
    the image caption (``problem['caption']``) are joined with a space and
    stripped of surrounding whitespace.  ``"N/A"`` is returned when nothing
    is left after stripping.
    """
    hint_text = problem['hint']
    # The caption key is only read when captions are requested.
    caption_text = problem['caption'] if use_caption else ""
    merged = " ".join([hint_text, caption_text]).strip()
    return merged if merged != "" else "N/A"


def build_prompt(question_data, use_lecture=False, use_solution=False):
    """Assemble the text prompt for one ScienceQA row.

    Args:
        question_data: Mapping with the keys ``question``, ``choices``,
            ``hint`` and ``task``; ``lecture`` / ``solution`` are read only
            when the corresponding flag is set.
        use_lecture: Append a ``Lecture:`` section when true.
        use_solution: Append a ``Solution:`` section when true and the
            row's solution is non-empty.

    Returns:
        The prompt string; sections are separated by ``"\\n "`` (a newline
        followed by one space).
    """
    question = get_question_text(question_data)
    # Label the choices 0..n-1 from the actual number of choices instead of a
    # fixed five labels, so a row with more than five choices no longer
    # raises IndexError.  Output is unchanged for rows with <= 5 choices.
    option_labels = list(range(len(question_data['choices'])))
    choices = get_choice_text(question_data, option_labels)
    hint = get_context_text(question_data, False)
    task = question_data['task']

    sections = [
        f'Question: {question}',
        f'Task: {task}',
        f'Choices: {choices}',
        f'Hint: {hint}',
    ]
    if use_lecture:
        sections.append(f'Lecture: {question_data["lecture"]}')
    if use_solution and question_data["solution"]:
        sections.append(f'Solution: {question_data["solution"]}')
    return '\n '.join(sections)

def build_message(row):
    """Wrap one dataset row as a single-turn chat conversation.

    The user turn carries the row's image followed by the text prompt from
    ``build_prompt(row)``.  When ``row['image']`` is falsy (e.g. ``None``) a
    black 224x224 RGB placeholder image is used instead.

    Returns:
        A one-element list containing the user message dict.
    """
    prompt_text = build_prompt(row)
    picture = row['image']
    if not picture:
        picture = Image.new("RGB", (224, 224), (0, 0, 0))
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": picture},
            {"type": "text", "text": prompt_text},
        ],
    }
    return [user_turn]

In [None]:
# Keep only rows that have a reference solution; the test split additionally
# requires a non-empty lecture.
train_dataset = train_dataset.filter(lambda example: example['solution'] != "")
eval_dataset = eval_dataset.filter(lambda example: example['solution'] != "")
test_dataset = test_dataset.filter(
    lambda example: example['solution'] != "" and example['lecture'] != ""
)

In [None]:
# Gemini outputs are stored tab-separated; only these columns are used.
gemini_columns = ['index', 'input', 'answer', 'explanation']
gemini_outputs = pd.read_csv('gemini_1_5_flash_output_train.csv', sep="\t")[gemini_columns]

# reset_index() exposes each (filtered) training row's position as an 'index'
# column, which is the join key against the Gemini file.
train_dataset_df = pd.DataFrame(train_dataset).reset_index()
train_dataset_gemini = gemini_outputs.merge(
    train_dataset_df[['index', 'image', 'solution']],
    on='index',
)

In [None]:
def _as_qwen_examples(dataset):
    """Pair each row's chat message (see build_message) with its solution."""
    return [(build_message(row), row["solution"]) for row in dataset]


# NOTE(review): unlike the three lists below, the first tuple element here is
# the raw prompt string from the CSV, not a chat-message list, while
# collate_fn_qwen passes that element to apply_chat_template -- confirm this
# list is converted before it is used with that collator.
train_dataset_qwen_gemini = list(
    zip(train_dataset_gemini["input"], train_dataset_gemini["solution"])
)
train_dataset_qwen = _as_qwen_examples(train_dataset)
eval_dataset_qwen = _as_qwen_examples(eval_dataset)
test_dataset_qwen = _as_qwen_examples(test_dataset)

In [None]:
def _as_paligemma_examples(dataset):
    """Return (prompt, image, solution) triples for every row of ``dataset``."""
    return [(build_prompt(row), row["image"], row["solution"]) for row in dataset]


# The CSV's "input" column already holds the output of build_prompt, so it is
# used as the prompt directly.
train_dataset_paligemma_gemini = list(
    zip(
        train_dataset_gemini["input"],
        train_dataset_gemini["image"],
        train_dataset_gemini["solution"],
    )
)
train_dataset_paligemma = _as_paligemma_examples(train_dataset)
eval_dataset_paligemma = _as_paligemma_examples(eval_dataset)
test_dataset_paligemma = _as_paligemma_examples(test_dataset)

In [None]:
def collate_fn_qwen(examples):
    """Collate ``(messages, solution)`` pairs into a Qwen2-VL training batch.

    Args:
        examples: List of ``(messages, solution)`` tuples, where ``messages``
            is the chat conversation produced by ``build_message`` and
            ``solution`` is the reference answer text.

    Returns:
        The processor output for the batch with an added ``"labels"`` tensor
        of the same shape as ``"input_ids"``.

    The solution is appended to each conversation as the assistant turn so
    that the target tokens are part of the model input, and the labels are a
    copy of ``input_ids``.  Previously the labels were a separately tokenized
    solution padded to the prompt length: they were not aligned with the
    input positions and the padding was not excluded from the loss.
    """
    # Full conversations: the original user turn(s) plus the reference answer.
    conversations = [
        list(messages) + [
            {"role": "assistant", "content": [{"type": "text", "text": solution}]}
        ]
        for (messages, solution) in examples
    ]

    # Render the chat template to text and collect the images of each sample.
    texts = [
        processor.apply_chat_template(conversation, tokenize=False)
        for conversation in conversations
    ]
    image_inputs = [process_vision_info(conversation)[0] for conversation in conversations]

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    labels = batch["input_ids"].clone()
    # -100 is the index ignored by the loss: skip padded positions ...
    labels[batch["attention_mask"] == 0] = -100
    # ... and the vision placeholder tokens, which carry no text to predict.
    vision_token_ids = processor.tokenizer.convert_tokens_to_ids(
        ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>"]
    )
    for token_id in vision_token_ids:
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels  # Add labels to the batch
    return batch  # Return the prepared batch

In [None]:
def collate_fn_paligemma(examples):
    """Collate ``(prompt, image, solution)`` triples into a PaliGemma batch.

    Args:
        examples: List of ``(text, image, label)`` tuples; ``image`` may be
            ``None`` for rows without a picture.

    Returns:
        The processor output for the batch, including a ``"labels"`` tensor.

    Two fixes over the previous version:
    * rows without an image get a black 224x224 placeholder (matching
      ``build_message``) instead of failing on ``None.resize``;
    * the solution is passed as ``suffix`` so the processor appends it to the
      input and builds labels aligned with ``input_ids``, rather than
      tokenizing the solution separately and padding it to the prompt length.

    NOTE(review): assumes the module-level ``processor`` is a PaliGemma
    processor when this is called -- it is not defined in this cell; confirm.
    """
    texts = [text for (text, _image, _label) in examples]
    image_inputs = [
        (image if image is not None else Image.new("RGB", (224, 224), (0, 0, 0))).resize((224, 224))
        for (_text, image, _label) in examples
    ]
    suffixes = [label for (_text, _image, label) in examples]

    # Tokenize prompts + target suffixes and process the images; the returned
    # batch already contains "labels".
    batch = processor(
        text=texts,
        images=image_inputs,
        suffix=suffixes,
        return_tensors="pt",
        padding=True,
    )
    return batch  # Return the prepared batch

### Qwen

In [None]:
# Checkpoint used for the model weights, the tokenizer and the processor.
model_name = "Qwen/Qwen2-VL-2B-Instruct"

processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name, torch_dtype="auto")

In [None]:
from peft import PrefixTuningConfig, get_peft_model

# Prefix-tuning adapter configuration: 30 virtual tokens, causal-LM task.
peft_config = PrefixTuningConfig(num_virtual_tokens=30, task_type="CAUSAL_LM")

# Wrap the base model with the adapter and report how many parameters train.
# NOTE(review): the trainer cell is given `model` together with `peft_config`,
# not `peft_model`, so this wrapped copy appears to serve only the report.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

In [None]:
from trl import SFTConfig

# Training configuration for the Qwen2-VL ScienceQA run.
training_args = SFTConfig(
    # Output location and run size
    output_dir="Qwen/Qwen2-VL-2B-Instruct-ScienceQA",
    num_train_epochs=20,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    # Memory: gradient checkpointing, non-reentrant variant
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Optimizer and learning-rate schedule
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    # NOTE(review): the plain "constant" schedule presumably applies no
    # warmup, in which case this ratio has no effect -- confirm, or switch to
    # "constant_with_warmup" if warmup is intended.
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Numeric precision
    bf16=True,
    tf32=True,
    # Logging, evaluation and checkpoint cadence (step intervals)
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=20,
    # Best-checkpoint selection: lowest eval_loss, reloaded after training
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
    # Hub upload and metric reporting
    push_to_hub=False,
    report_to="wandb",
    # Data handling: skip the trainer's own dataset preparation and keep all
    # columns, since batches are built by the custom collate function.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
)

In [None]:
from trl import SFTTrainer

# Supervised fine-tuning trainer: base model plus the prefix-tuning config,
# with batches assembled by collate_fn_qwen.
# NOTE(review): newer TRL releases may expect `processing_class` in place of
# `tokenizer` -- confirm against the installed version.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=collate_fn_qwen,
    train_dataset=train_dataset_qwen,
    eval_dataset=eval_dataset_qwen,
)

In [None]:
# Run training with the trainer configured above; the call's return value is
# displayed as the cell output.
trainer.train()