<a href="https://colab.research.google.com/github/donbcolab/AIE3/blob/main/paligemma_cnmc_finetune_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Paligemma Fine Tuning using CNMC dataset

### Setting Up

In [1]:
base_model_name = "google/paligemma-3b-pt-224"
adapter_version = "paligemma-cnmc-ft"
adapter_model_name = f"dwb2023/{adapter_version}"

In [2]:
!pip install -qU bitsandbytes datasets accelerate loralib peft transformers trl

In [3]:
import io
import os
from PIL import Image
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import(
    PaliGemmaForConditionalGeneration, PaliGemmaProcessor, TrainingArguments,
    Trainer, BitsAndBytesConfig, BitsAndBytesConfig
)
from datasets import load_dataset, DatasetDict

In [4]:
# verify gpu is available
torch.cuda.is_available()

True

In [5]:
from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')

In [6]:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

## Load Dataset

In [7]:
# Load CNMC Dataset
ds = load_dataset("dwb2023/cnmc-leukemia-2019", split="train")

In [8]:
# Define the percentage you want to retrieve (e.g., 10%)
percentage = 0.10

# Use train_test_split to get the subset
cnmc_ds = ds.train_test_split(test_size=percentage)["test"]

cols_remove = ["subject_id", "image_number", "cell_count", "class_label",
               "fold", "original_image_name", "relative_file_path"]
cnmc_ds = cnmc_ds.remove_columns(cols_remove)

In [9]:
# create train test split with test_size=0.2
train_ds = cnmc_ds.train_test_split(test_size=0.2)

# create test val split
test_val_ds = train_ds["test"].train_test_split(test_size=0.5)

cnmc_ds_dict = DatasetDict({
    "train" : train_ds["train"],
    "test" : test_val_ds["test"],
    "validation" : test_val_ds["train"]
})

cnmc_ds_dict

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 853
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 107
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 107
    })
})

## Collate Data

In [10]:
from PIL import Image, UnidentifiedImageError
import io

class DatasetPreprocessor:
    def __init__(self, processor, max_seq_length):
        self.processor = processor
        self.max_seq_length = max_seq_length

    def preprocess_function(self, examples):
        texts = ["Are these cells healthy or cancerous?" for _ in range(len(examples['image']))]
        images = []
        valid_indices = []
        for i, img in enumerate(examples['image']):
            try:
                image = Image.open(io.BytesIO(img['bytes'])).convert("RGB")
                images.append(image)
                valid_indices.append(i)
            except UnidentifiedImageError:
                print(f"Unidentified image: {img['path']}")

        if not images:
            return {}  # Return an empty dict if no valid images

        valid_texts = [texts[i] for i in valid_indices]
        valid_labels = [examples['label'][i] for i in valid_indices]

        inputs = self.processor(
            text=valid_texts,
            images=images,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_seq_length,
            truncation=True
        )

        # Process labels separately
        label_inputs = self.processor.tokenizer(
            text=valid_labels,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_seq_length,
            truncation=True
        )
        inputs['labels'] = label_inputs.input_ids

        return inputs

    def preprocess_dataset(self, dataset):
        return dataset.map(
            self.preprocess_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Preprocessing dataset"
        ).filter(lambda x: len(x) > 0)  # Filter out empty results

In [11]:
processor = PaliGemmaProcessor.from_pretrained(base_model_name)

In [12]:
max_seq_length = 128
output_dir = adapter_version

In [13]:
from PIL import Image
import io

def collate_fn(batch):
    texts = ["Are these cells healthy or cancerous?" for _ in batch]
    labels = [example['label'] for example in batch]
    images = [Image.open(io.BytesIO(example['image']['bytes'])).convert("RGB") for example in batch]

    inputs = processor(text=texts, images=images, return_tensors="pt", padding="max_length", max_length=max_seq_length, truncation=True)
    inputs['labels'] = processor(text=labels, return_tensors="pt", padding="max_length", max_length=max_seq_length, truncation=True).input_ids

    return inputs

## Load and Quatize the base Model (bitsandbytes)

In [14]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_type=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_name, quantization_config=bnb_config, device_map={"":0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


Unused kwargs: ['bnb_4bit_compute_type']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 11,298,816 || all params: 2,934,765,296 || trainable%: 0.3850


## Train the Adapter Model (trl)

In [15]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PaliGemmaForConditionalGeneration(
      (vision_tower): SiglipVisionModel(
        (vision_model): SiglipVisionTransformer(
          (embeddings): SiglipVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
            (position_embedding): Embedding(256, 1152)
          )
          (encoder): SiglipEncoder(
            (layers): ModuleList(
              (0-26): 27 x SiglipEncoderLayer(
                (self_attn): SiglipAttention(
                  (k_proj): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=1152, out_features=1152, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Identity()
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1152, out_features=8, bias=False)
                    )
                    (lora_B): Mo

In [16]:
model.config

PaliGemmaConfig {
  "_name_or_path": "google/paligemma-3b-pt-224",
  "architectures": [
    "PaliGemmaForConditionalGeneration"
  ],
  "bos_token_id": 2,
  "eos_token_id": 1,
  "hidden_size": 2048,
  "ignore_index": -100,
  "image_token_index": 257152,
  "model_type": "paligemma",
  "pad_token_id": 0,
  "projection_dim": 2048,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtype": "float32",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": false,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": true,
    "load_in_8bit": false,
    "quant_method": "bitsandbytes"
  },
  "text_config": {
    "hidden_size": 2048,
    "intermediate_size": 16384,
    "model_type": "gemma",
    "num_attention_heads": 8,
    "num_hidden_layers": 18,
    "num_image_tokens": 256,

In [17]:
from trl import SFTConfig

args=SFTConfig(
    output_dir = output_dir,
    max_steps=50,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=5,
    learning_rate=2e-5,
    max_seq_length=max_seq_length,
    bf16=True,
    optim="adamw_hf",
    report_to=["tensorboard"],
)

In [18]:
from PIL import Image
import io

preprocessor = DatasetPreprocessor(processor, max_seq_length)

# Print some information about the original datasets
print("Train dataset size:", len(cnmc_ds_dict['train']))
print("Validation dataset size:", len(cnmc_ds_dict['validation']))

# Preprocess the datasets
train_dataset = preprocessor.preprocess_dataset(cnmc_ds_dict['train'])
eval_dataset = preprocessor.preprocess_dataset(cnmc_ds_dict['validation'])

# Print information about the preprocessed datasets
print("Preprocessed train dataset size:", len(train_dataset))
print("Preprocessed validation dataset size:", len(eval_dataset))

# Check the first item in the preprocessed dataset
first_item = next(iter(train_dataset))
print("Keys in the first item:", first_item.keys())
print("Shape of input_ids:", first_item['input_ids'].shape)
print("Shape of labels:", first_item['labels'].shape)

Train dataset size: 853
Validation dataset size: 107


Preprocessing dataset:   0%|          | 0/853 [00:00<?, ? examples/s]

Unidentified image: UID_67_14_7_all.bmp
Unidentified image: UID_67_4_2_all.bmp
Unidentified image: UID_67_23_8_all.bmp
Unidentified image: UID_67_33_7_all.bmp
Unidentified image: UID_67_20_11_all.bmp
Unidentified image: UID_67_30_1_all.bmp
Unidentified image: UID_68_10_3_all.bmp
Unidentified image: UID_67_10_2_all.bmp
Unidentified image: UID_67_4_6_all.bmp
Unidentified image: UID_67_17_5_all.bmp
Unidentified image: UID_67_16_8_all.bmp
Unidentified image: UID_67_12_4_all.bmp
Unidentified image: UID_67_31_12_all.bmp
Unidentified image: UID_67_27_6_all.bmp
Unidentified image: UID_67_20_3_all.bmp
Unidentified image: UID_67_1_1_all.bmp
Unidentified image: UID_67_30_2_all.bmp
Unidentified image: UID_49_5_1_all.bmp
Unidentified image: UID_67_16_6_all.bmp
Unidentified image: UID_67_30_5_all.bmp
Unidentified image: UID_67_11_1_all.bmp
Unidentified image: UID_67_3_7_all.bmp
Unidentified image: UID_49_7_4_all.bmp
Unidentified image: UID_67_30_8_all.bmp
Unidentified image: UID_67_34_2_all.bmp


Filter:   0%|          | 0/828 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|          | 0/107 [00:00<?, ? examples/s]

Unidentified image: UID_49_5_3_all.bmp


Filter:   0%|          | 0/106 [00:00<?, ? examples/s]

Preprocessed train dataset size: 828
Preprocessed validation dataset size: 106
Keys in the first item: dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])


AttributeError: 'list' object has no attribute 'shape'

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=args,
    tokenizer=processor,
    compute_metrics=None,
)

In [None]:
trainer.train()

In [None]:
trainer.push_to_hub(f"dwb2023/{output_dir}")

### Temporary code below to assist with debugging Image Key error

In [None]:
cnmc_ds_dict['train'][0].keys()