In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only GPU 0

In [4]:
import torch

# Clear all cached memory
torch.cuda.empty_cache()

# Reset all allocated memory
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()

# If you want to reset the default CUDA device, you can specify it again
torch.cuda.set_device(0)  # Replace 0 with the desired GPU index if needed

# Print memory stats to confirm
print("CUDA memory cleared.")
print(f"Allocated memory: {torch.cuda.memory_allocated()} bytes")
print(f"Cached memory: {torch.cuda.memory_reserved()} bytes")


CUDA memory cleared.
Allocated memory: 0 bytes
Cached memory: 0 bytes


In [5]:
!nvidia-smi

Fri Jan 10 12:44:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:15:00.0 Off |                  Off |
| 30%   36C    P8              8W /  300W |      18MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               Off |   00

In [6]:
from transformers import logging

logging.set_verbosity_error()  # Suppress warnings and info logs

In [7]:
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
import torch
# from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
import gc
import time
import pandas as pd

In [8]:
# dataset_id = "HuggingFaceM4/ChartQA"
dataset_id = "derek-thomas/ScienceQA"
#TODO: DON'T FORGET TO HAVE THE ENTIRE DATASET
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train", "validation", "test"])

In [9]:
from PIL import Image


def get_question_text(problem):
    question = problem['question']
    return question


def get_choice_text(probelm, options):
    choices = probelm['choices']
    choice_list = []
    for i, c in enumerate(choices):
        choice_list.append("({}) {}".format(options[i], c))
    choice_txt = " ".join(choice_list)
    return choice_txt


def get_context_text(problem, use_caption):
    txt_context = problem['hint']
    img_context = problem['caption'] if use_caption else ""
    context = " ".join([txt_context, img_context]).strip()
    if context == "":
        context = "N/A"
    return context


def build_prompt(question_data, use_lecture=False, use_solution=False):
    question = get_question_text(question_data)
    choices = get_choice_text(question_data, [choice_num for choice_num in range(5)])
    hint = get_context_text(question_data, False)
    task = question_data['task']
    input_prompt = f'Question: {question}\n Task: {task}\n Choices: {choices}\n Hint: {hint}'
    if use_lecture:
        lecture = f'\n Lecture: {question_data["lecture"]}'
        input_prompt += lecture
    if use_solution and question_data["solution"]:
        solution = f'\n Solution: {question_data["solution"]}'
        input_prompt += solution
    return input_prompt

def build_message(row):
    row_input = build_prompt(row)
    image = row['image'] if row['image'] else Image.new("RGB", (224, 224), (0, 0, 0))
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": row_input },
            ],
        }
    ]
    return messages

In [10]:
train_dataset = train_dataset.filter(lambda example: example['solution']!="")
eval_dataset = eval_dataset.filter(lambda example: example['solution']!="")
test_dataset = test_dataset.filter(lambda example: (example['solution']!="") & (example['lecture']!=""))

In [11]:
train_dataset

Dataset({
    features: ['image', 'question', 'choices', 'answer', 'hint', 'task', 'grade', 'subject', 'topic', 'category', 'skill', 'lecture', 'solution'],
    num_rows: 11515
})

In [12]:
eval_dataset

Dataset({
    features: ['image', 'question', 'choices', 'answer', 'hint', 'task', 'grade', 'subject', 'topic', 'category', 'skill', 'lecture', 'solution'],
    num_rows: 3848
})

In [13]:
test_dataset

Dataset({
    features: ['image', 'question', 'choices', 'answer', 'hint', 'task', 'grade', 'subject', 'topic', 'category', 'skill', 'lecture', 'solution'],
    num_rows: 3172
})

In [14]:
train_dataset_gemini = pd.read_csv('gemini_1_5_flash_output_train.csv', sep="\t")[['index', 'input', 'answer', 'explanation']]
train_dataset_gemini['solution'] = train_dataset_gemini['explanation']
del train_dataset_gemini['explanation']
train_dataset_df = pd.DataFrame(train_dataset).reset_index()
train_dataset_gemini = pd.merge(train_dataset_gemini, train_dataset_df[['index', 'image']], on='index')

In [15]:
# train_dataset_qwen_gemini = [(sample[1]["input"], sample[1]["solution"]) for sample in train_dataset_gemini.iterrows()]
# train_dataset_qwen = [(build_message(sample), sample["solution"]) for sample in train_dataset]
# eval_dataset_qwen = [(build_message(sample), sample["solution"]) for sample in eval_dataset]
# test_dataset_qwen = [(build_message(sample), sample["solution"]) for sample in test_dataset]

In [16]:
train_dataset_paligemma_gemini = [(sample[1]["input"], sample[1]["image"], sample[1]["solution"]) for sample in train_dataset_gemini.iterrows()] # sample["input"] is the output of build_prompt
train_dataset_paligemma = [(build_prompt(sample), sample["image"], sample["solution"]) for sample in train_dataset]
eval_dataset_paligemma = [(build_prompt(sample), sample["image"], sample["solution"]) for sample in eval_dataset]
test_dataset_paligemma = [(build_prompt(sample), sample["image"], sample["solution"]) for sample in test_dataset]

In [17]:
train_dataset[0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429>,
 'question': 'Which of these states is farthest north?',
 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'],
 'answer': 0,
 'hint': '',
 'task': 'closed choice',
 'grade': 'grade2',
 'subject': 'social science',
 'topic': 'geography',
 'category': 'Geography',
 'skill': 'Read a map: cardinal directions',
 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.',
 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'}

In [18]:
def clear_memory():
    # Delete variables if they exist in the current global scope
    if "inputs" in globals():
        del globals()["inputs"]
    if "model" in globals():
        del globals()["model"]
    if "processor" in globals():
        del globals()["processor"]
    if "trainer" in globals():
        del globals()["trainer"]
    if "peft_model" in globals():
        del globals()["peft_model"]
    if "bnb_config" in globals():
        del globals()["bnb_config"]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


clear_memory()

GPU allocated memory: 0.00 GB
GPU reserved memory: 0.00 GB


In [19]:
from transformers import BitsAndBytesConfig
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-pt-224"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = PaliGemmaProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

2025-01-10 12:45:36.309214: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-10 12:45:36.322251: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1736509536.338399  783314 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736509536.343458  783314 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 12:45:36.360485: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [20]:
from peft import LoraConfig, get_peft_model

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()

trainable params: 2,592,768 || all params: 3,034,835,184 || trainable%: 0.0854


In [21]:
#TODO: SET EXPERIMENTS IN A LOOP AND MAKE IT RUN BEFORE THE FLIGHT

### -> both qwen and paligemma for the normal "label" data and the gemini data please

In [22]:
from trl import SFTConfig

# Configure training arguments
training_args = SFTConfig(
    output_dir="LORA-Paligemma-ScienceQA",  # Directory to save the model
    num_train_epochs=10,  # Number of training epochs
    per_device_train_batch_size=4,  # Batch size for training
    per_device_eval_batch_size=4,  # Batch size for evaluation
    gradient_accumulation_steps=8,  # Steps to accumulate gradients
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
    # Optimizer and scheduler settings
    optim="adamw_torch_fused",  # Optimizer type
    learning_rate=2e-4,  # Learning rate for training
    lr_scheduler_type="constant",  # Type of learning rate scheduler
    # Logging and evaluation
    logging_steps=10,  # Steps interval for logging
    eval_steps=10,  # Steps interval for evaluation
    eval_strategy="steps",  # Strategy for evaluation
    save_strategy="steps",  # Strategy for saving the model
    save_steps=20,  # Steps interval for saving
    metric_for_best_model="eval_loss",  # Metric to evaluate the best model
    greater_is_better=False,  # Whether higher metric values are better
    load_best_model_at_end=True,  # Load the best model after training
    # Mixed precision and gradient settings
    bf16=True,  # Use bfloat16 precision
    tf32=True,  # Use TensorFloat-32 precision
    max_grad_norm=0.3,  # Maximum norm for gradient clipping
    warmup_ratio=0.03,  # Ratio of total steps for warmup
    # Hub and reporting
    push_to_hub=False,  # Whether to push model to Hugging Face Hub
    report_to="wandb",  # Reporting tool for tracking metrics
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
    # Dataset configuration
    dataset_text_field="",  # Text field in dataset
    dataset_kwargs={"skip_prepare_dataset": True},  # Additional dataset options
    # max_seq_length=1024  # Maximum sequence length for input
)

training_args.remove_unused_columns = False  # Keep unused columns in dataset
training_args.eval_strategy = "epoch"
training_args.save_strategy = "epoch"

In [23]:
import wandb

wandb.init(
    project="LORA-Paligemma-ScienceQA",  # change this
    name="LORA-Paligemma-ScienceQA",  # change this
    config=training_args,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmatyashpr[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [24]:
# Create a data collator to encode text and image pairs
def collate_fn_qwen(examples):
    
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example, tokenize=False) for (example,_) in examples
    ]  # Prepare texts for processing
    image_inputs = [process_vision_info(example)[0] for (example,_) in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, padding="longest", return_tensors="pt"
    ) 
    max_length = batch["input_ids"].size(1)
    example_labels = [label for (x, label) in examples]
    labels = tokenizer(example_labels, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    batch["labels"] = labels  # Add labels to the batch
    return batch  # Return the prepared batch

In [25]:
# Create a data collator to encode text and image pairs
def collate_fn_paligemma(examples):
    texts = [text for (text, image, label) in examples]
    image_inputs = [image.resize((224, 224)) if image else Image.new("RGB", (224, 224), (0, 0, 0)) for (text, image, label) in examples]

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, padding="longest", return_tensors="pt"
    )
    max_length = batch["input_ids"].size(1)
    example_labels = [label for (text, image, label) in examples]
    labels = processor.tokenizer(example_labels, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    batch["labels"] = labels  # Add labels to the batch
    return batch  # Return the prepared batch

In [26]:
from trl import SFTTrainer


trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_paligemma,
    eval_dataset=eval_dataset_paligemma,
    data_collator=collate_fn_paligemma,
    peft_config=peft_config,
    tokenizer=tokenizer,
)

  trainer = SFTTrainer(


In [27]:
trainer.train()

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


{'loss': 150.7346, 'grad_norm': 245.60951232910156, 'learning_rate': 0.0002, 'epoch': 0.027787426189649182}
{'loss': 67.5996, 'grad_norm': 56.05059814453125, 'learning_rate': 0.0002, 'epoch': 0.055574852379298365}
{'loss': 42.0873, 'grad_norm': 47.7778434753418, 'learning_rate': 0.0002, 'epoch': 0.08336227856894755}
{'loss': 34.7834, 'grad_norm': 24.753293991088867, 'learning_rate': 0.0002, 'epoch': 0.11114970475859673}
{'loss': 30.4227, 'grad_norm': 48.8653450012207, 'learning_rate': 0.0002, 'epoch': 0.13893713094824592}
{'loss': 26.4546, 'grad_norm': 147.55799865722656, 'learning_rate': 0.0002, 'epoch': 0.1667245571378951}
{'loss': 23.7438, 'grad_norm': 71.36222839355469, 'learning_rate': 0.0002, 'epoch': 0.1945119833275443}
{'loss': 22.6143, 'grad_norm': 33.79683303833008, 'learning_rate': 0.0002, 'epoch': 0.22229940951719346}
{'loss': 21.2488, 'grad_norm': 320.13128662109375, 'learning_rate': 0.0002, 'epoch': 0.25008683570684265}
{'loss': 19.7692, 'grad_norm': 108.7056655883789, 'l

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


{'loss': 15.0793, 'grad_norm': 81.245361328125, 'learning_rate': 0.0002, 'epoch': 1.002778742618965}
{'loss': 13.463, 'grad_norm': 59.241058349609375, 'learning_rate': 0.0002, 'epoch': 1.030566168808614}
{'loss': 13.4047, 'grad_norm': 284.3699645996094, 'learning_rate': 0.0002, 'epoch': 1.0583535949982632}
{'loss': 13.2811, 'grad_norm': 96.35693359375, 'learning_rate': 0.0002, 'epoch': 1.0861410211879126}
{'loss': 13.4356, 'grad_norm': 82.69366455078125, 'learning_rate': 0.0002, 'epoch': 1.1139284473775617}
{'loss': 13.6338, 'grad_norm': 460.0487060546875, 'learning_rate': 0.0002, 'epoch': 1.1417158735672108}
{'loss': 13.5995, 'grad_norm': 59.260372161865234, 'learning_rate': 0.0002, 'epoch': 1.16950329975686}
{'loss': 13.0844, 'grad_norm': 72.29168701171875, 'learning_rate': 0.0002, 'epoch': 1.1972907259465093}
{'loss': 13.0267, 'grad_norm': 50.122657775878906, 'learning_rate': 0.0002, 'epoch': 1.2250781521361584}
{'loss': 13.948, 'grad_norm': 165.10292053222656, 'learning_rate': 0.00

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


{'loss': 12.3078, 'grad_norm': 89.18677520751953, 'learning_rate': 0.0002, 'epoch': 2.00555748523793}
{'loss': 11.8662, 'grad_norm': 58.191749572753906, 'learning_rate': 0.0002, 'epoch': 2.0333449114275792}
{'loss': 12.388, 'grad_norm': 58.572391510009766, 'learning_rate': 0.0002, 'epoch': 2.061132337617228}
{'loss': 13.1126, 'grad_norm': 457.8995056152344, 'learning_rate': 0.0002, 'epoch': 2.0889197638068775}
{'loss': 12.9864, 'grad_norm': 144.6287384033203, 'learning_rate': 0.0002, 'epoch': 2.1167071899965264}
{'loss': 12.2096, 'grad_norm': 107.73517608642578, 'learning_rate': 0.0002, 'epoch': 2.1444946161861758}
{'loss': 11.5676, 'grad_norm': 43.075313568115234, 'learning_rate': 0.0002, 'epoch': 2.172282042375825}
{'loss': 11.088, 'grad_norm': 31.716068267822266, 'learning_rate': 0.0002, 'epoch': 2.200069468565474}
{'loss': 12.1889, 'grad_norm': 29.80038070678711, 'learning_rate': 0.0002, 'epoch': 2.2278568947551234}
{'loss': 11.2747, 'grad_norm': 85.05868530273438, 'learning_rate':

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


{'loss': 11.9082, 'grad_norm': 30.381389617919922, 'learning_rate': 0.0002, 'epoch': 3.0083362278568946}
{'loss': 11.5827, 'grad_norm': 80.5597152709961, 'learning_rate': 0.0002, 'epoch': 3.036123654046544}
{'loss': 10.0318, 'grad_norm': 130.72198486328125, 'learning_rate': 0.0002, 'epoch': 3.0639110802361933}
{'loss': 10.469, 'grad_norm': 33.25981903076172, 'learning_rate': 0.0002, 'epoch': 3.091698506425842}
{'loss': 11.4779, 'grad_norm': 22.58572006225586, 'learning_rate': 0.0002, 'epoch': 3.1194859326154916}
{'loss': 10.1526, 'grad_norm': 112.7008285522461, 'learning_rate': 0.0002, 'epoch': 3.1472733588051405}
{'loss': 10.3681, 'grad_norm': 29.747241973876953, 'learning_rate': 0.0002, 'epoch': 3.17506078499479}
{'loss': 10.4871, 'grad_norm': 85.02025604248047, 'learning_rate': 0.0002, 'epoch': 3.202848211184439}
{'loss': 10.4419, 'grad_norm': 57.546329498291016, 'learning_rate': 0.0002, 'epoch': 3.230635637374088}
{'loss': 10.8694, 'grad_norm': 24.745731353759766, 'learning_rate': 

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


{'loss': 10.1289, 'grad_norm': 65.82814025878906, 'learning_rate': 0.0002, 'epoch': 4.01111497047586}
{'loss': 9.7837, 'grad_norm': 44.156494140625, 'learning_rate': 0.0002, 'epoch': 4.038902396665509}
{'loss': 9.9447, 'grad_norm': 83.00977325439453, 'learning_rate': 0.0002, 'epoch': 4.0666898228551585}
{'loss': 9.1739, 'grad_norm': 43.1068229675293, 'learning_rate': 0.0002, 'epoch': 4.094477249044807}
{'loss': 9.3519, 'grad_norm': 83.94661712646484, 'learning_rate': 0.0002, 'epoch': 4.122264675234456}
{'loss': 9.1861, 'grad_norm': 34.68849182128906, 'learning_rate': 0.0002, 'epoch': 4.150052101424105}
{'loss': 9.6047, 'grad_norm': 59.67850112915039, 'learning_rate': 0.0002, 'epoch': 4.177839527613755}
{'loss': 9.7801, 'grad_norm': 74.53050231933594, 'learning_rate': 0.0002, 'epoch': 4.205626953803404}
{'loss': 9.8157, 'grad_norm': 67.88227844238281, 'learning_rate': 0.0002, 'epoch': 4.233414379993053}
{'loss': 10.105, 'grad_norm': 685.7171630859375, 'learning_rate': 0.0002, 'epoch': 4

KeyboardInterrupt: 

In [28]:
trainer.save_model()