## PaliGemma Fine-tuning

In this notebook, we fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-448), using the 224-pixel checkpoint `google/paligemma-3b-pt-224`, on RefCOCOg referring-expression data stored in local JSON Lines files. Let's get started by installing the necessary libraries.

In [1]:
%load_ext autoreload
%autoreload 2

# !pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate peft bitsandbytes python-dotenv pandas pillow tensorboard

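We load the RefCOCOg train, test, and validation splits from local JSON Lines files. Each record has an `image` path, a `prefix` (the prompt, including four `<locNNNN>` location tokens), and a `suffix` (the target text).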

In [2]:
import pandas as pd

# https://stackoverflow.com/questions/50475635/loading-jsonl-file-as-json-objects
dataset = {
    split: pd.read_json(path_or_buf=f"./refcocog_{split}.jsonl", lines=True)
    for split in ['train', 'test', 'val']
}


In [3]:
dataset['train']

|       | image | prefix | suffix |
|---|---|---|---|
| 0 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | Answer: two woman one in black eatting and the... |
| 1 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | Answer: woman in white shirt looking down at l... |
| 2 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | Answer: a tv with a woman being interviewed on it |
| 3 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | Answer: a woman with sunglasses on her head on... |
| 4 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0150><loc0110><loc0557><loc0552>` | Answer: a young boy doing a skateboard trick o... |
| ... | ... | ... | ... |
| 80507 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0049><loc0574><loc0346>` | Answer: the larger banana is above the small r... |
| 80508 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0230><loc0184><loc0768>` | Answer: a guy in black jacket and cowboy hat |
| 80509 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0230><loc0184><loc0768>` | Answer: the man in the dark coat and pointier hat |
| 80510 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0084><loc0107><loc0939><loc0916>` | Answer: a person in red dress and he is seeing... |
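Each `prefix` encodes the referred region with four `<locNNNN>` tokens. PaliGemma's own detection format uses 1024 location bins ordered as y_min, x_min, y_max, x_max. The helpers below are a hypothetical sketch (not part of transformers) that extract the bin indices and check whether a box is consistent with that ordering. In the preview above, rows 2 and 80508 have a third value smaller than the first (357 vs 321, 226 vs 184), so these files may use a different convention (for example offset plus size); it is worth checking before relying on the boxes.

```python
import re

# Matches PaliGemma-style location tokens such as <loc0098>; captures the 4 digits
LOC_PATTERN = re.compile(r"<loc(\d{4})>")


def parse_loc_tokens(text):
    """Return the integer bin indices of all <locNNNN> tokens in `text`, in order."""
    return [int(digits) for digits in LOC_PATTERN.findall(text)]


def is_ordered_yxyx(bins):
    """True if four bins are consistent with (y_min, x_min, y_max, x_max)."""
    if len(bins) != 4:
        raise ValueError(f"expected 4 location bins, got {len(bins)}")
    y_min, x_min, y_max, x_max = bins
    return y_min <= y_max and x_min <= x_max
```

For example, `dataset['train']['prefix'].map(lambda p: is_ordered_yxyx(parse_loc_tokens(p))).mean()` gives the fraction of training prompts whose boxes fit the y_min, x_min, y_max, x_max reading.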


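Next, we read a Hugging Face access token from a `.env` file (PaliGemma is a gated model, so downloads require authentication) and load the processor that preprocesses the dataset.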

In [4]:
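from transformers import PaliGemmaProcessor
from huggingface_hub import login
import os
import dotenv
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PaliGemma is gated: read the Hugging Face access token from a local .env file
dotenv.load_dotenv("/scratch/gsk6me/WORLDMODELS/crl-online-policies/.env")
token = os.environ['HUGGINGFACE_ACCESS_TOKEN']
# Log in for the whole session so later Hub calls (checkpoint saving, push_to_hub) are authenticated too
login(token=token)

model_id = "google/paligemma-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)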



We now write a collate function that preprocesses each batch. It loads the images, takes the `prefix` column as the prompt and the `suffix` column as the target, and passes both to the processor. When `suffix` is given, `PaliGemmaProcessor` appends it to the prompt and returns a `labels` tensor in which the image tokens, prompt tokens, and padding are set to -100, so the model is trained to generate only the target text.

In [22]:
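import PIL.Image

def collate_fn(examples):
    # Load each image from its path and convert it to 3-channel RGB
    images = [PIL.Image.open(example['image']).convert("RGB") for example in examples]
    prompts = [example['prefix'] for example in examples]
    suffixes = [example['suffix'] for example in examples]

    # Passing `suffix` makes the processor append the target text and build `labels`
    tokens = processor(text=prompts, images=images, suffix=suffixes,
                       return_tensors="pt", padding="longest",
                       tokenize_newline_separately=False)

    # BatchFeature.to(dtype) casts only floating-point tensors (pixel_values);
    # the batch is moved to the GPU here, hence dataloader_pin_memory=False in TrainingArguments
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens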
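The label masking that the processor performs can be illustrated on plain lists. As far as we know, the processor derives `labels` from `token_type_ids` (0 for image, prompt, and padding positions, 1 for suffix positions) and does the same thing on tensors; the function name and toy token ids below are hypothetical.

```python
# PyTorch's cross-entropy loss ignores targets equal to -100 by default
IGNORE_INDEX = -100


def mask_labels(input_ids, token_type_ids):
    """Copy input_ids, replacing every non-suffix position with IGNORE_INDEX."""
    return [
        token if type_id == 1 else IGNORE_INDEX
        for token, type_id in zip(input_ids, token_type_ids)
    ]


# Toy sequence: 2 image tokens, 1 prompt token, 3 suffix tokens, 1 padding token
print(mask_labels([7, 7, 2, 10, 11, 1, 0], [0, 0, 0, 1, 1, 1, 0]))
# [-100, -100, -100, 10, 11, 1, -100]
```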


Our data (COCO images paired with referring expressions) is similar to the data PaliGemma was pretrained on, so we freeze the image encoder and the multimodal projector and fine-tune only the text decoder.

In [5]:
from transformers import PaliGemmaForConditionalGeneration
import torch

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=token, device_map=device)

for param in model.vision_tower.parameters():
    param.requires_grad = False

for param in model.multi_modal_projector.parameters():
    param.requires_grad = False


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.26s/it]
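To confirm the freezing, a small helper can tally trainable and total parameters per top-level submodule. `trainable_summary` is a hypothetical name, not a transformers API; it only relies on `named_parameters()`, `numel()`, and `requires_grad` from PyTorch. Depending on the transformers version, the submodules may sit under an extra `model.` prefix, in which case pass `depth=2`.

```python
from collections import defaultdict


def trainable_summary(model, depth=1):
    """Return {name_prefix: (trainable, total)} parameter counts.

    Parameters are grouped by the first `depth` dot-separated components
    of their names, e.g. "vision_tower" or "language_model".
    """
    counts = defaultdict(lambda: [0, 0])
    for name, param in model.named_parameters():
        prefix = ".".join(name.split(".")[:depth])
        n = param.numel()
        counts[prefix][1] += n          # every parameter counts toward the total
        if param.requires_grad:
            counts[prefix][0] += n      # only unfrozen ones count as trainable
    return {prefix: (trainable, total) for prefix, (trainable, total) in counts.items()}
```

Run after the cell above, `trainable_summary(model)` should report zero trainable parameters for the vision tower and the multimodal projector.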


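Alternatively, for LoRA or QLoRA fine-tuning, run the cell below, which adds LoRA adapters either to the full-precision model or to a 4-bit quantized reload of it. In this run the cell was executed with `quantize = False`, so only the LoRA adapters are trained.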

In [6]:
from peft import get_peft_model, LoraConfig

quantize = False

if quantize:
    from transformers import BitsAndBytesConfig
    # QLoRA: store base weights in 4-bit NF4 and run computation in bfloat16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # This reloads the model; free the full-precision copy first if GPU memory is tight
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, quantization_config=bnb_config, device_map={"": 0}, token=token
    )

# Rank-32 adapters on every module whose name matches one of these projection names
lora_config = LoraConfig(
    r=32,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output of an earlier run with a different configuration:
# trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


trainable params: 45,195,264 || all params: 2,968,661,744 || trainable%: 1.5224
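The trainable-parameter count above can be reproduced by hand. For a frozen d_out × d_in weight, LoRA trains two factors B (d_out × r) and A (r × d_in), so each adapted layer adds r·(d_in + d_out) parameters. The sketch below assumes the published Gemma-2B decoder dimensions (18 layers, hidden size 2048, 8 query heads of size 256, one key/value head, MLP width 16384) and the SigLIP So400m encoder dimensions (27 layers, hidden size 1152); these numbers are not stated in this notebook. Under those assumptions, the total matches the printed 45,195,264 only if adapters are also attached to the vision encoder's `q_proj`, `k_proj`, and `v_proj`, because PEFT matches list entries in `target_modules` against the end of each module name. In other words, with LoRA the vision tower is not completely frozen, unlike the decoder-only plan described earlier.

```python
def lora_params(d_in, d_out, r):
    """Trainable parameters LoRA adds to one bias-free linear layer."""
    return r * (d_in + d_out)


# ASSUMED Gemma-2B decoder layer: (d_in, d_out) per targeted projection
GEMMA_LAYERS = 18
GEMMA_LAYER = {
    "q_proj": (2048, 2048),      # 8 heads x 256
    "k_proj": (2048, 256),       # 1 key/value head x 256
    "v_proj": (2048, 256),
    "o_proj": (2048, 2048),
    "gate_proj": (2048, 16384),
    "up_proj": (2048, 16384),
    "down_proj": (16384, 2048),
}

# ASSUMED SigLIP encoder layer: its attention also names modules q_proj/k_proj/v_proj,
# while its output projection is "out_proj", which does not match "o_proj"
SIGLIP_LAYERS = 27
SIGLIP_LAYER = {"q_proj": (1152, 1152), "k_proj": (1152, 1152), "v_proj": (1152, 1152)}


def total_lora_params(r):
    """Return (decoder, vision) LoRA parameter counts for rank r."""
    decoder = GEMMA_LAYERS * sum(lora_params(i, o, r) for i, o in GEMMA_LAYER.values())
    vision = SIGLIP_LAYERS * sum(lora_params(i, o, r) for i, o in SIGLIP_LAYER.values())
    return decoder, vision


decoder, vision = total_lora_params(32)
print(f"decoder: {decoder:,}  vision: {vision:,}  total: {decoder + vision:,}")
# decoder: 39,223,296  vision: 5,971,968  total: 45,195,264
```

Because the count is linear in r, rank 8 gives exactly a quarter of it (11,298,816). If adapters on the vision encoder are unintended, PEFT also accepts a single regular-expression string for `target_modules`, which can be written to match only modules under the language model.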


We will now initialize the `TrainingArguments`.

In [7]:
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=2,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",
    save_strategy="steps",
    save_steps=1000,
    push_to_hub=False,
    save_total_limit=1,
    output_dir="paligemma_refcocog",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False  # collate_fn already moves batches to the GPU, and only CPU tensors can be pinned
)
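With these settings each optimizer step sees 4 × 4 = 16 examples on a single GPU. The hypothetical helper below estimates how many optimizer steps the run takes, assuming one GPU and the 80,511 training rows in the preview (indices 0 to 80510). Transformers versions differ in how they round a final partial accumulation window, but here 20,128 batches divide evenly by 4, so the result is the same either way.

```python
import math


def optimizer_steps(num_examples, per_device_batch, grad_accum, epochs, num_devices=1):
    """Estimate (steps_per_epoch, total_steps) for Trainer on a map-style dataset."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch, math.ceil(epochs * steps_per_epoch)


effective_batch = 4 * 4  # per_device_train_batch_size * gradient_accumulation_steps
steps_per_epoch, total_steps = optimizer_steps(80_511, 4, 4, 2)
print(effective_batch, steps_per_epoch, total_steps)  # 16 5032 10064
```

At roughly 10,064 steps in total, the loss log below (steps 100 to 1000) covers about the first tenth of training, and with `save_steps=1000` and `save_total_limit=1` only the most recent checkpoint is kept on disk.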


We can now start training.

In [23]:
from transformers import Trainer

# Converts pandas dataframe to a list of dictionaries
train_ds = dataset['train'].to_dict('records')
trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    data_collator=collate_fn,
    args=args
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
trainer.train()



| Step | Training Loss |
|---|---|
| 100 | 3.3246 |
| 200 | 2.1091 |
| 300 | 1.992 |
| 400 | 2.0022 |
| 500 | 1.947 |
| 600 | 1.897 |
| 700 | 1.9098 |
| 800 | 1.88 |
| 900 | 1.9337 |
| 1000 | 1.8758 |



Cannot access gated repo for url https://huggingface.co/google/paligemma-3b-pt-224/resolve/main/config.json.
Access to model google/paligemma-3b-pt-224 is restricted. You must be authenticated to access it. - silently ignoring the lookup for the file config.json in google/paligemma-3b-pt-224.


In [None]:
trainer.push_to_hub()

The inference steps are in [this Colab notebook](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing).
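A minimal inference sketch, assuming `model` and `processor` are the fine-tuned objects from this notebook still in memory. `generate_answer` and `strip_answer_prefix` are hypothetical helpers, greedy decoding and `max_new_tokens=64` are arbitrary choices, and the linked Colab may do this differently.

```python
def strip_answer_prefix(text):
    """Drop the leading "Answer:" that every training suffix starts with."""
    return text.removeprefix("Answer:").strip()


def generate_answer(model, processor, image_path, prefix, max_new_tokens=64):
    """Greedy-decode the target text for one image/prompt pair."""
    # Imported here so the text helper above works without torch installed
    import torch
    from PIL import Image

    image = Image.open(image_path).convert("RGB")
    inputs = processor(text=prefix, images=image, return_tensors="pt")
    # Match training: bfloat16 pixel values on the same device as the model
    device = next(model.parameters()).device
    inputs = inputs.to(torch.bfloat16).to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt
    text = processor.decode(output[0][prompt_len:], skip_special_tokens=True)
    return strip_answer_prefix(text)
```

For example, `generate_answer(model, processor, dataset['val'].iloc[0]['image'], dataset['val'].iloc[0]['prefix'])` describes the region in the first validation prompt.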