### Import Libraries

In [1]:
import os

import torch
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import LoraConfig
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (AdamW, AutoProcessor, get_scheduler)
from transformers import AutoProcessor, BitsAndBytesConfig
from transformers import EarlyStoppingCallback
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers import TrainingArguments, Trainer

### Dataset 

In [2]:
# Brain-scan image/caption pairs from the Hugging Face Hub (train and test splits).
dataset = load_dataset(path="gokulsabari/brain_scan")

In [3]:
# Show the available splits with their row counts and feature columns.
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'id', 'caption', 'ground_truth'],
        num_rows: 136960
    })
    test: Dataset({
        features: ['image', 'id', 'caption', 'ground_truth'],
        num_rows: 24170
    })
})

### Model Loading

#### The model was trained on an NVIDIA GeForce RTX 4090 (24 GB VRAM)

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Report the selected device. The GPU model name is only queried when CUDA is
# actually in use: torch.cuda.get_device_name() fails on a CPU-only machine, which
# would break the CPU fallback chosen for `device`.
print(device)
if device.type == "cuda":
    print(torch.cuda.get_device_name())

cuda
NVIDIA GeForce RTX 4090


### Base Model - Paligemma

In [6]:
# Base PaliGemma checkpoint on the Hugging Face Hub; reused when loading the model.
model_id = "google/paligemma-3b-pt-224"
# The processor bundles the image preprocessing and the tokenizer for this checkpoint.
processor = PaliGemmaProcessor.from_pretrained(pretrained_model_name_or_path=model_id)

### Dataset Building

#### Defining the collate function

In [38]:
# Token id of the special "<image>" placeholder in the PaliGemma tokenizer.
# NOTE(review): not referenced by collate_fn; the prompt below contains no "<image>"
# tokens, presumably because PaliGemmaProcessor inserts them itself — confirm for the
# installed transformers version.
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
# Fixed instruction prompt shared by every example in a batch. Its exact text
# (including trailing spaces) is what the model is trained on, so edit with care.
question = """Analyze the provided medical brain imaging scan. 
Describe the type of scan, key anatomical structures visible, and any abnormal findings.
Focus on: Identifying the imaging modality (e.g., CT, MRI). Describing the view and orientation of the scan
Noting any visible anatomical structures. Identifying and describing any abnormalities or areas of concern 
along with their approximate area as a % of the total image area
Specifying the location of abnormalities using anatomical terms. Suggesting possible clinical implications of the findings
Provide a concise yet comprehensive analysis in a professional medical tone."""

def collate_fn(examples):
    """Build one model-ready batch from a list of dataset rows.

    Every row is paired with the shared `question` prompt. The row's `caption` is
    passed to the processor as `suffix` (the target text, per the PaliGemmaProcessor
    API). The row's `image` is converted to 3-channel RGB first.

    Args:
        examples: rows from the `train`/`test` split, each with `image` and `caption`.

    Returns:
        The processor's batch, with floating-point tensors cast to bfloat16 and
        everything moved to `device`.
    """
    prompts, targets, pictures = [], [], []
    for row in examples:
        prompts.append(question)
        targets.append(row['caption'])
        pictures.append(row["image"].convert("RGB"))

    batch = processor(
        text=prompts,
        images=pictures,
        suffix=targets,
        return_tensors="pt",
        padding="longest",
        tokenize_newline_separately=False,
    )

    # BatchFeature.to(dtype) only casts floating-point tensors (e.g. pixel_values);
    # integer token ids are left as-is. Moving to `device` here is why the
    # DataLoader's pin_memory is disabled in the training arguments.
    return batch.to(torch.bfloat16).to(device)

In [9]:
# Load the base checkpoint directly in bfloat16, then place it on the selected device.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
model

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features

### We freeze the vision encoder and the text decoder, and keep the projector trainable.

In [10]:
# Freeze the vision tower and the language model. Parameters outside these two
# submodules keep their current requires_grad setting, so only they are trained.
for frozen_part in (model.vision_tower, model.language_model):
    frozen_part.requires_grad_(False)

In [8]:
# NOTE(review): unused alternative setup kept for reference only — every line is
# commented out. It sketches rank-8 LoRA adapters on the listed projection modules
# plus 8-bit loading through bitsandbytes. The training run uses the bf16 `model`
# loaded in the model-loading cell instead.
# NOTE(review): from_pretrained is called twice below. If re-enabled, the second
# call (with device_map={"":0}) overwrites the first, so keep only one of them.
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=8,
#     lora_dropout=0.1,
#     target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
#     init_lora_weights="gaussian"
# )
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,            
# )
# model = PaliGemmaForConditionalGeneration.from_pretrained(
#     model_id,
#     torch_dtype=torch.bfloat16,
#     quantization_config=bnb_config
# )
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
# model.add_adapter(lora_config)
# model.enable_adapters()


### Defining the training arguments

In [11]:
args = TrainingArguments(
    num_train_epochs=100,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",
    save_strategy="epoch",
    save_steps=1000,
    push_to_hub=True,
    save_total_limit=1,
    output_dir="paligemma-adapter",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
    evaluation_strategy="epoch",  
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Early stopping parameters
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

In [12]:
trainer = Trainer(
        model=model,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        data_collator=collate_fn,
        args=args
        )

In [13]:
# Run training. Evaluation, checkpointing and Hub pushes follow the TrainingArguments.
train_result = trainer.train()
train_result



Epoch,Training Loss,Validation Loss
1,1.0288,1.035887
2,1.0149,1.013301
3,1.0197,1.009638
4,1.0073,1.008637
5,1.0054,1.008433
6,1.0112,1.008285
7,1.0173,1.008306
8,1.0175,1.008329
9,1.0181,1.008236
10,1.0022,1.008252


HTTP Error 500 thrown while requesting PUT https://hf-hub-lfs-us-east-1.s3-accelerate.amazonaws.com/repos/76/2f/762fd2ac66a397735c383fc1e40352f14e1e313215b0dd85eb681b45342780e8/2d6c497917bc06d805d06c9dc46aed6fbaa2af04295afb867942daa6b08afdd4?<presigned-URL query string (credential, signature, upload id) redacted>
Retrying in 1s [Retry 1/5].
There were missing keys in the checkpoint model loaded: ['language_model.lm_head.weight'].


TrainOutput(global_step=85600, training_loss=1.02722936059827, metrics={'train_runtime': 101156.4145, 'train_samples_per_second': 13.539, 'train_steps_per_second': 0.846, 'total_flos': 9.987645287412793e+18, 'train_loss': 1.02722936059827, 'epoch': 10.0})

In [14]:
# Upload the final model and training artifacts to the `paligemma-adapter` Hub repo.
# "End of training" is the Trainer's default commit message, spelled out here.
trainer.push_to_hub(commit_message="End of training")



CommitInfo(commit_url='https://huggingface.co/gokulsabari/paligemma-adapter/commit/0c555ec77eda429e593aca62dfd188954f6270d1', commit_message='End of training', commit_description='', oid='0c555ec77eda429e593aca62dfd188954f6270d1', pr_url=None, pr_revision=None, pr_num=None)