<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/paligemma_cnmc_finetune_v11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Paligemma Fine Tuning using CNMC dataset v10

### Setting Up

In [1]:
!pip install -q -U git+https://github.com/huggingface/transformers.git bitsandbytes datasets accelerate peft hf_transfer

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [2]:
import torch
from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig
from transformers import AutoProcessor
from datasets import load_dataset, DatasetDict, Image
from google.colab import userdata
import os

# Constants
base_model_name = "google/paligemma-3b-pt-224"
adapter_version = "paligemma-cnmc-ft"
adapter_model_name = f"dwb2023/{adapter_version}"
max_seq_length = 128
output_dir = adapter_version

# HF Token
HF_TOKEN = userdata.get('HF_TOKEN')
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
# Load Dataset and Processor
ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")
processor = AutoProcessor.from_pretrained(base_model_name)

# Filter records to only include those from fold 0
ds_fold_0 = ds.filter(lambda example: example['fold'] == 0)

# Define the percentage you want to retrieve (e.g., 10%)
percentage = 0.10

# Use train_test_split to get the subset
cnmc_ds = ds_fold_0.train_test_split(test_size=percentage)["test"]

# Columns to remove
cols_remove = ["subject_id", "image_number", "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

# Create train-test split
train_ds = cnmc_ds.train_test_split(test_size=0.2)

# Create test-val split
test_val_ds = train_ds["test"].train_test_split(test_size=0.5)

cnmc_ds_dict = DatasetDict({
    "train": train_ds["train"],
    "test": test_val_ds["test"],
    "validation": test_val_ds["train"]
})

# Cast image column AFTER splitting the dataset
ds_train = cnmc_ds_dict["train"].cast_column("image", Image(decode=True))
ds_eval = cnmc_ds_dict["validation"].cast_column("image", Image(decode=True))

In [4]:
# Collate Function (optimized)
device = "cuda"

def collate_fn(examples):
    texts = ["Are these cells healthy or cancerous?" for _ in range(len(examples))]
    labels = [example['label'] for example in examples]

    # Image conversion on CPU
    images = [example["image"].convert("RGB") for example in examples]

    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest", max_length=max_seq_length)

    # Convert to bfloat16 before moving to GPU
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens

In [5]:
from transformers import PaliGemmaForConditionalGeneration
import torch

model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, torch_dtype=torch.bfloat16).to(device)

# Freeze vision tower and multimodal projector parameters
for param in model.vision_tower.parameters():
    param.requires_grad = False

for param in model.multi_modal_projector.parameters():
    param.requires_grad = False


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_type=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, quantization_config=bnb_config, device_map={"":0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Unused kwargs: ['bnb_4bit_compute_type']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 11,298,816 || all params: 2,934,765,296 || trainable%: 0.3850


In [7]:
from transformers import TrainingArguments
args=TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=100,
            remove_unused_columns=False,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            learning_rate=2e-5,
            weight_decay=1e-6,
            adam_beta2=0.999,
            logging_steps=100,
            optim="adamw_hf",
            save_strategy="steps",
            save_steps=1000,
            push_to_hub=True,
            save_total_limit=1,
            bf16=True,
            report_to=["tensorboard"],
            dataloader_pin_memory=False
        )

In [8]:
from transformers import Trainer

trainer = Trainer(
        model=model,
        train_dataset=ds_train,
        eval_dataset=ds_eval,
        data_collator=collate_fn,
        args=args
        )

In [9]:
trainer.train()



Step,Training Loss
100,0.371


TrainOutput(global_step=170, training_loss=0.2944864441366757, metrics={'train_runtime': 327.0602, 'train_samples_per_second': 8.622, 'train_steps_per_second': 0.52, 'total_flos': 1.0421945291225664e+16, 'train_loss': 0.2944864441366757, 'epoch': 9.577464788732394})

In [10]:
trainer.push_to_hub(f"dwb2023/{output_dir}")

CommitInfo(commit_url='https://huggingface.co/dwb2023/paligemma-cnmc-ft/commit/d4f05eca05a0affc68e5e4dc55bf3f1cbeef0b7c', commit_message='dwb2023/paligemma-cnmc-ft', commit_description='', oid='d4f05eca05a0affc68e5e4dc55bf3f1cbeef0b7c', pr_url=None, pr_revision=None, pr_num=None)