<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/Untitled24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -qU bitsandbytes datasets accelerate loralib peft transformers

In [None]:
import io

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, TrainingArguments, Trainer



In [None]:
# --- Configuration ---
# Checkpoint name handed to `from_pretrained` for both processor and model.
base_model_name = "google/paligemma-3b-pt-224"
# Directory name handed to `TrainingArguments(output_dir=...)`.
output_dir = "paligemma-cnmc-validation"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:
# --- Load Model ---
# The processor and the model are both loaded from the same checkpoint name.
processor = AutoProcessor.from_pretrained(base_model_name)

# Load the weights as bfloat16, then move the model to the selected device.
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
model = model.to(device)

In [None]:
# --- Load Dataset ---
# `DatasetDict` is needed below to bundle the splits; it was previously used
# without being imported, which raised NameError on a fresh kernel.
from datasets import DatasetDict, load_dataset

# Fixed seed so the random sub-sampling and splits below are reproducible
# under Restart Kernel -> Run All.
split_seed = 42

ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")
# Filter records to only include those from fold 0
ds_fold_0 = ds.filter(lambda example: example['fold'] == 0)
# Fraction of fold 0 to keep (e.g., 10%)
percentage = 0.10
# Use train_test_split to carve out that fraction; only its "test" part is kept
cnmc_ds = ds_fold_0.train_test_split(test_size=percentage, seed=split_seed)["test"]
# Columns to remove
cols_remove = ["subject_id", "image_number", "cell_count", "class_label", "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

# First split: 80% train / 20% held out.
# NOTE: despite its name, `train_ds` holds both the "train" and "test" parts.
train_ds = cnmc_ds.train_test_split(test_size=0.2, seed=split_seed)
# Second split: halve the held-out 20% into validation and test (10% each).
test_val_ds = train_ds["test"].train_test_split(test_size=0.5, seed=split_seed)
cnmc_ds_dict = DatasetDict({
    "train" : train_ds["train"],
    "test" : test_val_ds["test"],
    "validation" : test_val_ds["train"]
})

In [None]:
# --- Data Collation ---
def collate_fn(batch):
    """Turn a batch of dataset rows into inputs for the PaliGemma model.

    Args:
        batch: Either a list of row dicts (the form a DataLoader hands to a
            collator) or a single dict of columns. Every row must provide an
            ``image`` and a ``label`` entry.

    Returns:
        The processor output for the batch, including ``labels`` built from
        the ``label`` column, with floating-point tensors cast to bfloat16
        and all tensors moved to ``device``.

    Raises:
        KeyError: If a row (or the column dict) lacks ``image`` or ``label``.
    """
    # Normalise both accepted batch layouts to two parallel lists.
    if hasattr(batch, "keys"):
        raw_images, raw_labels = batch["image"], batch["label"]
    else:
        raw_images = [row["image"] for row in batch]
        raw_labels = [row["label"] for row in batch]

    images = []
    for img in raw_images:
        if isinstance(img, dict):
            # Undecoded image entry: the encoded file is stored under "bytes".
            img = Image.open(io.BytesIO(img["bytes"]))
        # NOTE(review): anything that is not a dict is assumed to already be a
        # PIL image -- confirm against the dataset's image feature.
        images.append(img.convert("RGB"))

    texts = ["Are these cells healthy or cancerous?"] * len(images)

    # The processor builds the training `labels` from `suffix`, which must be
    # text. Without it the batch carries no labels and no loss is computed.
    # NOTE(review): if `label` is an integer class id the target text is "0" /
    # "1"; map it to class names here if that is the intended answer -- confirm.
    suffixes = [str(label) for label in raw_labels]

    tokens = processor(
        text=texts,
        images=images,
        suffix=suffixes,
        return_tensors="pt",
        padding="longest",
    )
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens

In [None]:
# --- Training Arguments ---
args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,            # Adjust as needed
    per_device_train_batch_size=4, # Reduce batch size if needed
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=5,               # Log every 5 steps
    eval_strategy="steps",         # Evaluate every 'eval_steps' (replaces deprecated 'evaluation_strategy')
    eval_steps=5,                  # Adjust as needed
    save_strategy="steps",
    save_steps=100,                # Adjust as needed
    push_to_hub=False,             # Set to True if you want to push to Hub
    save_total_limit=1,
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
    # The collator reads the raw 'image' / 'label' columns itself. With the
    # default (True) the Trainer strips columns that are not arguments of
    # model.forward before the collator ever sees them.
    remove_unused_columns=False,
)

# --- Trainer ---
# Trains on the "train" split and evaluates on the "validation" split; batches
# are assembled by `collate_fn`.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=cnmc_ds_dict['train'],
    eval_dataset=cnmc_ds_dict['validation'],
    data_collator=collate_fn
)

In [None]:
# --- Training ---
# Run the training loop on the `trainer` built in the previous cell. As the
# last expression of the cell, the call's return value is displayed.
trainer.train()

# --- (Optional) Push to Hub ---
# `push_to_hub` is False in the TrainingArguments above. If it is set to True,
# the call below will push to the Hub; it is left commented out.
# trainer.push_to_hub()