## PaliGemma Fine-tuning

In this notebook, we will fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-448) at 448 px resolution. I removed a lot of the code that was in the other (non-448-px) notebook.


In [1]:
%load_ext autoreload
%autoreload 2

# Training

Load the processor to preprocess the dataset.

In [2]:
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import os
import dotenv
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dotenv.load_dotenv("/scratch/gsk6me/WORLDMODELS/crl-online-policies/.env")

token = os.environ['HUGGINGFACE_ACCESS_TOKEN']
model_id = "google/paligemma-3b-mix-448"
processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)

# model = PaliGemmaForConditionalGeneration.from_pretrained("paligemma_object_classifier/checkpoint-37500", torch_dtype=torch.bfloat16, token=token, device_map=device)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=token, device_map=device)

for param in model.vision_tower.parameters():
    param.requires_grad = False

# for param in model.multi_modal_projector.parameters():
#     param.requires_grad = False


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.81s/it]
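
To check that freezing the vision tower took effect, it can help to compare trainable and total parameter counts. The sketch below is not part of the original notebook: `count_parameters` is a hypothetical helper, and it only assumes that the model exposes `parameters()` the way a `torch.nn.Module` does.

```python
def count_parameters(model):
    """Return (trainable, total) parameter counts for a torch-style module.

    Hypothetical helper: it relies only on `model.parameters()` yielding
    objects with `.numel()` and `.requires_grad`, as torch parameters do.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total
```

In the notebook this would be called as `trainable, total = count_parameters(model)` after the freezing loop; the trainable count should then be lower than the total by the size of the vision tower.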


In [3]:
import PIL.Image
import os

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")

def collate_fn(image_root, examples):
    images = [PIL.Image.open(os.path.join(image_root, example['image'])).convert("RGB") for example in examples]
    texts = [example['prefix'] for example in examples]
    labels = [example['suffix'] for example in examples]
    
    tokens = processor(text=texts, images=images, suffix=labels,
                    return_tensors="pt", padding="longest",
                    tokenize_newline_separately=False)

    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens


Alternatively, if you want to do LoRA or QLoRA fine-tuning, you can run the cells below to load the model with an adapter, either in full precision or quantized to 4 bits.

In [4]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [5]:
from peft import get_peft_model, LoraConfig

quantize = False
use_lora = False

if quantize:
    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

if use_lora:
    lora_config = LoraConfig(
        r=32,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
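
For a rough sense of how `r` drives the adapter size: LoRA adds two low-rank matrices to each targeted linear layer, with shapes `(r, d_in)` and `(d_out, r)`, so each layer contributes `r * (d_in + d_out)` trainable parameters. The sketch below is illustrative only. `lora_param_count` is a hypothetical helper, and the layer shapes are made-up placeholders, not PaliGemma's actual dimensions.

```python
def lora_param_count(r, layer_shapes):
    """Trainable parameters LoRA adds for a list of (d_in, d_out) linear layers.

    Each adapted layer gets a matrix A of shape (r, d_in) and a matrix B of
    shape (d_out, r), hence r * (d_in + d_out) parameters per layer.
    """
    return sum(r * (d_in + d_out) for d_in, d_out in layer_shapes)


# Placeholder shapes for illustration, not PaliGemma's real dimensions.
example_shapes = [(1024, 1024), (1024, 4096)]
print(lora_param_count(8, example_shapes))   # 8 * (2048 + 5120) = 57344
print(lora_param_count(32, example_shapes))  # 4x the rank, 4x the parameters: 229376
```

The count grows linearly in `r`, so the rank is the main knob trading adapter capacity against memory.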


We can now start training.

## Pretrain on COCO object classification

Each example asks the model to name the object inside a given bounding box. This should hopefully improve the model's ability to localize information.


In [6]:
import PIL.Image
import json
from pycocotools.coco import COCO
import tqdm
import os

image_size_cache = {}

def load_size(image_path):
    image = PIL.Image.open(image_path)
    return (image.width, image.height)

if not os.path.exists("coco_train2014_clsify.jsonl"):
    print("coco_train2014_clsify.jsonl does not exist. Creating!")
    
    coco_train2014 = COCO('/scratch/gsk6me/WORLDMODELS/coco_annotations/instances_train2014.json')
    
    coco_train2014_ds = []
    coco_image_root = "/scratch/gsk6me/WORLDMODELS"

    for image_id in tqdm.tqdm(sorted(coco_train2014.imgs.keys())):
        image_path = f"train2014/COCO_train2014_{image_id:012d}.jpg"
        # image = PIL.Image.open(f"{coco_image_root}/{image_path}")
        # image_width = image.width
        # image_height = image.height
        if image_path not in image_size_cache:
            image_size_cache[image_path] = load_size(os.path.join(coco_image_root, image_path))
        (image_width, image_height) = image_size_cache[image_path]
        
        annots = coco_train2014.imgToAnns[image_id]
        for annot in annots:
            category_id = annot['category_id']
            category_name = coco_train2014.cats[category_id]['name']
            bbox = annot['bbox']

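            # Quantize to the 1024 location bins (<loc0000>..<loc1023>); clamp to 1023
            # because a box touching the right or bottom edge would otherwise map to 1024.
            x0_quant = min(int((bbox[0] / image_width) * 1024), 1023)
            y0_quant = min(int((bbox[1] / image_height) * 1024), 1023)
            x1_quant = min(int(((bbox[0] + bbox[2]) / image_width) * 1024), 1023)
            y1_quant = min(int(((bbox[1] + bbox[3]) / image_height) * 1024), 1023)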

            coco_train2014_ds.append({
                'image': image_path,
                'prefix': f'Describe <loc{x0_quant:04d}><loc{y0_quant:04d}><loc{x1_quant:04d}><loc{y1_quant:04d}>',
                'suffix': category_name,
                'bbox_x': bbox[0],
                'bbox_y': bbox[1],
                'bbox_w': bbox[2],
                'bbox_h': bbox[3],
                'image_width': image_width,
                'image_height': image_height,
            })

    with open("coco_train2014_clsify.jsonl", "w") as f:
        for record in coco_train2014_ds:
            json.dump(record, f)
            f.write("\n")
else:
    with open("coco_train2014_clsify.jsonl") as f:
        coco_train2014_ds = [json.loads(record) for record in f.read().split("\n") if record]
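
The prefix construction can be isolated into a small helper, which makes it easier to reuse for other datasets and to check by hand. The sketch below follows the cell's convention (COCO `[x, y, w, h]` boxes, coordinates emitted in x0, y0, x1, y1 order, 1024 bins) and keeps every index within 0..1023. `bbox_to_loc_prefix` is a hypothetical name, not something the notebook defines.

```python
def bbox_to_loc_prefix(bbox, image_width, image_height, bins=1024):
    """Build a 'Describe <loc...>' prefix from a COCO [x, y, w, h] box."""
    x, y, w, h = bbox

    def quantize(value, size):
        # Scale to [0, bins) and clamp so a box touching the edge stays in range.
        return min(max(int(value / size * bins), 0), bins - 1)

    coords = [
        quantize(x, image_width),
        quantize(y, image_height),
        quantize(x + w, image_width),
        quantize(y + h, image_height),
    ]
    return "Describe " + "".join(f"<loc{c:04d}>" for c in coords)


print(bbox_to_loc_prefix([64, 48, 576, 432], 640, 480))
# Describe <loc0102><loc0102><loc1023><loc1023>
```

The example box reaches the right and bottom edges of a 640x480 image, so both far coordinates land in the last bin.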


# Train on OCID-Ref

In [27]:
import json

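def load_jsonl(paths):
    """Load records from one JSONL file or from a list of JSONL files."""
    if isinstance(paths, str):
        paths = [paths]
    ds = []
    for path in paths:
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines, e.g. a trailing newline
                    ds.append(json.loads(line))

    return ds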

train_ds = load_jsonl("ocid-ref-train.jsonl")
# train_ds = load_jsonl(['train_object_captions_1k.jsonl', 'train_object_captions_4k.jsonl'])
# train_ds = coco_train2014_ds

In [28]:
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=1,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=128,
    optim="adamw_hf",
    save_strategy="epoch",
    # save_strategy="steps",
    # save_steps=500,
    push_to_hub=False,
    save_total_limit=1,
    # NO pretraining!!
    output_dir="paligemma_object_captioner_448_ocidref-train",
    # output_dir="paligemma_object_captioner_pretrained",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False
)
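
With these settings, each optimizer step covers `per_device_train_batch_size * gradient_accumulation_steps` examples per device. The back-of-the-envelope sketch below assumes a single GPU (the notebook does not state the device count) and uses a made-up dataset size; `steps_per_epoch` is a hypothetical helper, and the `Trainer`'s exact step count may differ slightly depending on how it handles the final partial batch.

```python
import math


def steps_per_epoch(num_examples, per_device_batch_size, grad_accum_steps, num_devices=1):
    """Return (effective batch size, optimizer steps per epoch).

    Assumes the last partial batch still counts as a step.
    """
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    return effective_batch, math.ceil(num_examples / effective_batch)


# 20,000 examples is a made-up dataset size for illustration.
effective_batch, steps = steps_per_epoch(20_000, per_device_batch_size=4, grad_accum_steps=4)
print(effective_batch, steps)  # 16 1250
```

This is also a quick way to sanity-check `logging_steps` and `warmup_steps` against the length of a run.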


In [29]:
from transformers import Trainer
from functools import partial

coco_root = '/scratch/gsk6me/WORLDMODELS'
ocid_root = '/scratch/gsk6me/WORLDMODELS/OCID-dataset'

trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    data_collator=partial(collate_fn, ocid_root),
    args=args
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
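
The `Trainer` calls its data collator with a single argument, the list of examples in the batch, which is why `collate_fn` is wrapped in `partial` to bind `image_root` ahead of time. The stand-alone sketch below shows the same binding pattern; `describe_batch` and the `/data/images` root are hypothetical stand-ins for `collate_fn` and the real dataset path.

```python
from functools import partial


def describe_batch(image_root, examples):
    """Stand-in for collate_fn: resolve each example's image path under image_root."""
    return [f"{image_root}/{example['image']}" for example in examples]


# Bind the first argument; the result takes only `examples`, like a data collator.
collate = partial(describe_batch, "/data/images")
print(collate([{"image": "a.png"}, {"image": "b.png"}]))
# ['/data/images/a.png', '/data/images/b.png']
```

Switching datasets then only requires changing the bound root, for example from the OCID root to the COCO root.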


In [None]:
trainer.train()



| Step | Training Loss |
|------|---------------|
| 128  | 1.1178 |
| 256  | 0.8774 |
| 384  | 0.7752 |
| 512  | 0.7068 |
| 640  | 0.6795 |
| 768  | 0.6327 |
| 896  | 0.614  |
| 1024 | 0.5974 |
| 1152 | 0.5784 |
| 1280 | 0.5648 |


In [14]:
trainer.save_model('paligemma_object_captioner_pretrained_448/checkpoint-544')

In [None]:
trainer.push_to_hub()

You can find the steps for running inference [here](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing).