<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/paligemma_cnmc_finetune_v7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

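# PaliGemma Fine-Tuning on the C-NMC Leukemia Dataset

Fine-tunes a 4-bit quantized LoRA adapter for `google/paligemma-3b-pt-224` to answer whether cell images are healthy or cancerous. Training uses a small subset of the C-NMC 2019 dataset.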

### Setting Up

In [None]:
# Base checkpoint to fine-tune, plus the name of the trained adapter and its Hub repo.
base_model_name = "google/paligemma-3b-pt-224"
adapter_version = "paligemma-cnmc-ft"
adapter_model_name = "/".join(("dwb2023", adapter_version))

In [None]:
# Install the training stack. transformers comes from its git main branch and no package
# is pinned, so library APIs (e.g. TrainingArguments keywords) can change between runs.
# NOTE(review): consider `%pip install` with pinned versions for reproducibility.
!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate bitsandbytes peft hf_transfer

In [None]:
import os
from google.colab import userdata

# Read the Hugging Face access token from Colab's secret store.
HF_TOKEN = userdata.get('HF_TOKEN')
# Export it so huggingface_hub (used by datasets, transformers and Trainer's
# push_to_hub) authenticates explicitly instead of relying on implicit token lookup.
if HF_TOKEN:
    os.environ["HF_TOKEN"] = HF_TOKEN

# Use the hf_transfer backend for faster Hub downloads. huggingface_hub reads this
# variable when it is imported, so set it before importing datasets/transformers.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

## Load Dataset

In [None]:
from datasets import load_dataset, DatasetDict, Image

# Load the CNMC leukemia dataset (train split only; it is subset and re-split in later cells).
# Decoding of the "image" column is enabled later via cast_column on the final train/eval splits.
ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")

In [None]:
# Inspect the first raw record to see the available columns.
ds[0]

In [None]:
# Filter records to only include those from fold 0
ds_fold_0 = ds.filter(lambda example: example['fold'] == 0)

# Define the percentage you want to retrieve (e.g., 10%)
percentage = 0.10

# Use train_test_split to get the subset
cnmc_ds = ds_fold_0.train_test_split(test_size=percentage)["test"]

# Columns to remove
cols_remove = ["subject_id", "image_number", "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

In [None]:
# create train test split with test_size=0.2
train_ds = cnmc_ds.train_test_split(test_size=0.2)

# create test val split
test_val_ds = train_ds["test"].train_test_split(test_size=0.5)

cnmc_ds_dict = DatasetDict({
    "train" : train_ds["train"],
    "test" : test_val_ds["test"],
    "validation" : test_val_ds["train"]
})

cnmc_ds_dict

## Collate Data

In [None]:
from transformers import AutoProcessor

# Processor (tokenizer + image preprocessor) matching the base checkpoint.
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=base_model_name)

In [None]:
# NOTE(review): neither name is referenced by later cells here. TrainingArguments passes
# adapter_version directly, and collate_fn does not truncate to max_seq_length.
max_seq_length, output_dir = 128, adapter_version

In [None]:
import torch
device = "cuda"

def collate_fn(examples):
    """Build one PaliGemma training batch from a list of dataset rows.

    Every row gets the same question prompt. Its "label" is passed to the processor
    as the target ``suffix``, and its "image" is converted to RGB. The batch is padded
    to the longest sequence, cast to bfloat16, and moved to ``device``.
    """
    prompt = "Are these cells healthy or cancerous?"
    batch_size = len(examples)
    texts = [prompt] * batch_size
    # NOTE(review): `suffix` is target text; assumes 'label' holds strings — confirm.
    labels = [row['label'] for row in examples]
    images = [row["image"].convert("RGB") for row in examples]

    batch = processor(
        text=texts,
        images=images,
        suffix=labels,
        return_tensors="pt",
        padding="longest",
    )
    # BatchFeature.to(dtype) casts only floating-point tensors (e.g. pixel_values);
    # integer ids and masks keep their dtype.
    return batch.to(torch.bfloat16).to(device)

## Load and Quantize the Base Model (bitsandbytes)

In [None]:
import torch

from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

# 4-bit NF4 quantization of the frozen base weights, with computation in bfloat16.
# The keyword must be `bnb_4bit_compute_dtype`. An unrecognized kwarg such as
# `bnb_4bit_compute_type` is not applied, which leaves the default float32 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Rank-8 LoRA adapters on the attention and MLP projections. PEFT matches these
# strings against module-name suffixes, so every module ending in one of them is adapted.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# device_map={"": 0} places the entire model on GPU 0.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    base_model_name, quantization_config=bnb_config, device_map={"": 0}
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output recorded in an earlier run of this notebook:
# trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


## Train the Adapter Model (transformers `Trainer`)

In [None]:
# Display the PEFT-wrapped model's module structure.
model

In [None]:
# Display the model configuration.
model.config

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir=adapter_version,
    # Upper bound on epochs; the EarlyStoppingCallback passed to Trainer may stop sooner.
    num_train_epochs=20,
    # Keep the raw "image"/"label" columns so collate_fn can read them.
    remove_unused_columns=False,
    per_device_train_batch_size=16,
    # Effective batch size per device: 16 * 4 = 64 examples per optimizer step.
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    # NOTE(review): logs every 100 optimizer steps. On a small subset a full run may take
    # fewer steps than that, in which case no training loss is logged — confirm.
    logging_steps=100,
    # `adamw_hf` selects transformers' own, deprecated AdamW; use torch.optim.AdamW.
    optim="adamw_torch",
    save_strategy="epoch",
    # NOTE(review): without hub_model_id the repo name is derived from output_dir in the
    # token owner's namespace; adapter_model_name from the config cell is not passed here.
    push_to_hub=True,
    bf16=True,
    report_to=["tensorboard"],
    # collate_fn returns tensors already on the GPU, and pinning only applies to CPU tensors.
    dataloader_pin_memory=False,
    # Required for EarlyStoppingCallback; needs matching save and eval strategies.
    load_best_model_at_end=True,
    # `evaluation_strategy` was renamed to `eval_strategy`, and the old name has been
    # removed from recent transformers releases (this notebook installs transformers from main).
    eval_strategy="epoch",
)


In [None]:
# Enable decoding of the "image" column into PIL images for the two splits used in training.
decode_images = Image(decode=True)
ds_train, ds_eval = (
    cnmc_ds_dict[split].cast_column("image", decode_images)
    for split in ("train", "validation")
)

In [None]:
# Sanity check: the first training record should now carry a decoded image.
ds_train[0]

In [None]:
# Sanity check: the first validation record should now carry a decoded image.
ds_eval[0]

In [None]:
from transformers import Trainer, EarlyStoppingCallback

# Stop training after one evaluation that fails to improve the monitored metric
# (eval loss by default) by at least 0.01.
early_stopping = EarlyStoppingCallback(
    early_stopping_threshold=0.01,
    early_stopping_patience=1,
)

# Trainer wiring: the custom collator builds PaliGemma batches from raw rows.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    callbacks=[early_stopping],
)


In [None]:
# Start fine-tuning
# Start fine-tuning. With load_best_model_at_end=True, the best checkpoint is reloaded
# into the model when training stops.
# NOTE(review): push_to_hub=True uploads checkpoints as they are saved. To make sure the
# reloaded best adapter is what ends up on the Hub, consider calling trainer.push_to_hub()
# after this cell.
trainer.train()