<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/paligemma_cnmc_finetune_v9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Paligemma Fine Tuning using CNMC dataset

### Setting Up

In [1]:
# Constants
base_model_name = "google/paligemma-3b-pt-224"
adapter_version = "paligemma-cnmc-ft"
adapter_model_name = f"dwb2023/{adapter_version}"
max_seq_length = 128
output_dir = adapter_version

In [2]:
!pip install -q -U git+https://github.com/huggingface/transformers.git bitsandbytes datasets accelerate peft hf_transfer

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m69.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 16.1.0 which is incompatible.
google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.
ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 16.1.0 which is incompatible.[0m[31m
[0m

In [3]:
from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig
from transformers import AutoProcessor
import bitsandbytes as bnb
import torch

In [4]:
import os
from google.colab import userdata

HF_TOKEN = userdata.get('HF_TOKEN')
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

## Load Dataset

In [5]:
from datasets import load_dataset, DatasetDict, Image

# Load CNMC Dataset
ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")#.cast_column("image", Image(decode=True))

Downloading readme:   0%|          | 0.00/3.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/92.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/91.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/88.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/87.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/92.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/99.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/83.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/78.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/83.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/58.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/94.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/77.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10661 [00:00<?, ? examples/s]

In [6]:
# Filter records to only include those from fold 0
ds_fold_0 = ds.filter(lambda example: example['fold'] == 0)

# Define the percentage you want to retrieve (e.g., 10%)
percentage = 0.10

# Use train_test_split to get the subset
cnmc_ds = ds_fold_0.train_test_split(test_size=percentage)["test"]

# Columns to remove
cols_remove = ["subject_id", "image_number", "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

Filter:   0%|          | 0/10661 [00:00<?, ? examples/s]

In [7]:
# create train test split with test_size=0.2
train_ds = cnmc_ds.train_test_split(test_size=0.2)

# create test val split
test_val_ds = train_ds["test"].train_test_split(test_size=0.5)

cnmc_ds_dict = DatasetDict({
    "train" : train_ds["train"],
    "test" : test_val_ds["test"],
    "validation" : test_val_ds["train"]
})

cnmc_ds_dict

DatasetDict({
    train: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 282
    })
    test: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 36
    })
    validation: Dataset({
        features: ['cell_count', 'image', 'label', 'class_label'],
        num_rows: 35
    })
})

## Collate Data

In [8]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(base_model_name)

preprocessor_config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

In [9]:
import torch
device = "cuda"

def collate_fn(examples):
  texts = ["Are these cells healthy or cancerous?" for _ in range(len(examples))]
  labels = [example['label'] for example in examples]
  images = [example["image"].convert("RGB") for example in examples]
  tokens = processor(text=texts, images=images, suffix=labels,
                  return_tensors="pt", padding="longest", max_length=max_seq_length)

  tokens = tokens.to(torch.bfloat16).to(device)
  return tokens

In [10]:
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, torch_dtype=torch.bfloat16).to(device)
# Ensure all parameters require gradients
for param in model.vision_tower.parameters():
    param.requires_grad = False

for param in model.multi_modal_projector.parameters():
    param.requires_grad = True

# Disable cache for gradient checkpointing
model.config.use_cache = False

# Set compute dtype
model.bnb_4bit_compute_dtype = torch.bfloat16

# Enable gradient checkpointing if needed (currently commented out to avoid the error)
model.gradient_checkpointing_enable()

config.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.74G [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

## Load and Quatize the base Model (bitsandbytes)

In [11]:
import torch

from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_type=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, quantization_config=bnb_config, device_map={"":0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


Unused kwargs: ['bnb_4bit_compute_type']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 11,298,816 || all params: 2,934,765,296 || trainable%: 0.3850


## Train the Adapter Model (trl)

In [12]:
# model

In [13]:
# model.config

In [14]:
from transformers import TrainingArguments

max_seq_length = 128
output_dir = adapter_version

args=TrainingArguments(
  output_dir=output_dir,
  num_train_epochs=2,
  remove_unused_columns=False,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=4,
  warmup_steps=2,
  learning_rate=2e-5,
  weight_decay=1e-6,
  adam_beta2=0.999,
  logging_steps=100,
  optim="adamw_hf",
  save_strategy="steps",
  save_steps=1000,
  push_to_hub=True,
  save_total_limit=1,
  bf16=True,
  report_to=["tensorboard"],
  dataloader_pin_memory=False
)

In [15]:
ds_train = cnmc_ds_dict["train"].cast_column("image", Image(decode=True))
ds_eval = cnmc_ds_dict["validation"].cast_column("image", Image(decode=True))

In [16]:
# Clear CUDA cache to free up memory
# torch.cuda.empty_cache()

In [17]:
from transformers import Trainer

trainer = Trainer(
  model=model,
  train_dataset=ds_train,
  eval_dataset=ds_eval,
  data_collator=collate_fn,
  args=args
  )

In [18]:
# Start fine-tuning
trainer.train()



OutOfMemoryError: CUDA out of memory. Tried to allocate 268.00 MiB. GPU 

In [None]:
  trainer.push_to_hub(f"dwb2023/{output_dir}")