## PaliGemma Fine-tuning

In this notebook, we fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-448) to describe image regions, using region descriptions built from COCO (RefCOCOg referring expressions and COCO object annotations). Let's get started by installing the necessary libraries.

In [19]:
%load_ext autoreload
%autoreload 2

# !pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate

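Accessing the gated model requires a Hugging Face access token. In this notebook the token is read from a `.env` file in the Training section rather than entered through `notebook_login()`. First, we load the RefCOCOg splits from local JSONL files.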

In [2]:
import pandas as pd

# https://stackoverflow.com/questions/50475635/loading-jsonl-file-as-json-objects
dataset = {
    split: pd.read_json(path_or_buf=f"./refcocog_{split}.jsonl", lines=True)
    for split in ['train', 'test', 'val']
}


# Cleaning

In [3]:
dataset['train']

|       | image | prefix | suffix |
|------:|-------|--------|--------|
| 0 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | Answer: two woman one in black eatting and the... |
| 1 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | Answer: woman in white shirt looking down at l... |
| 2 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | Answer: a tv with a woman being interviewed on it |
| 3 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | Answer: a woman with sunglasses on her head on... |
| 4 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0150><loc0110><loc0557><loc0552>` | Answer: a young boy doing a skateboard trick o... |
| ... | ... | ... | ... |
| 80507 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0049><loc0574><loc0346>` | Answer: the larger banana is above the small r... |
| 80508 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0230><loc0184><loc0768>` | Answer: a guy in black jacket and cowboy hat |
| 80509 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0226><loc0230><loc0184><loc0768>` | Answer: the man in the dark coat and pointier hat |
| 80510 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0084><loc0107><loc0939><loc0916>` | Answer: a person in red dress and he is seeing... |


In [4]:
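def clean_label(label):
    # Strip the "Answer: " prefix. removeprefix is a no-op when the prefix is absent,
    # so applying this to already-cleaned labels does not cut off 8 characters of text.
    label = label.removeprefix('Answer: ')

    # Remove leading articles so each label starts with its noun phrase.
    if label.startswith("a "):
        label = label[2:]
    if label.startswith("the "):
        label = label[4:]

    return label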
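The same cleaning rule can be written as one anchored regular expression, which makes it easy to test and to port to pandas' vectorized string methods. This is a sketch: `clean_label_regex` is a hypothetical name, and it assumes exactly the rule above (drop `Answer: `, then an optional leading `a `, then an optional leading `the `).

```python
import re

# Optional "a " followed by optional "the ", anchored at the start of the label.
LEADING_ARTICLES = re.compile(r"^(?:a )?(?:the )?")


def clean_label_regex(label: str) -> str:
    """Hypothetical regex-based equivalent of clean_label."""
    label = label.removeprefix("Answer: ")
    return LEADING_ARTICLES.sub("", label, count=1)


print(clean_label_regex("Answer: a tv with a woman being interviewed on it"))
print(clean_label_regex("Answer: the man in the dark coat and pointier hat"))
```

With pandas 1.4 or later, a column-wise version would be `series.str.removeprefix("Answer: ").str.replace(r"^(?:a )?(?:the )?", "", regex=True)`.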

In [5]:
dataset['train']['suffix'] = dataset['train']['suffix'].apply(clean_label)

In [6]:
dataset['train'].head()

|       | image | prefix | suffix |
|------:|-------|--------|--------|
| 0 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | two woman one in black eatting and the other h... |
| 1 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0000><loc0098><loc0382><loc0871>` | woman in white shirt looking down at laptop co... |
| 2 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | tv with a woman being interviewed on it |
| 3 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0357><loc0763><loc0321><loc0223>` | woman with sunglasses on her head on the telev... |
| 4 | /scratch/gsk6me/WORLDMODELS/train2014/COCO_tra... | `Describe <loc0150><loc0110><loc0557><loc0552>` | young boy doing a skateboard trick on a blue b... |


In [7]:
dataset['train']['suffix'].iloc[5:10].tolist()

['man jumping with a skateboard',
 'long - horn , long - haired brown cow looking at the camera',
 'brown bull in front of feeding tub',
 'woman in black dress',
 'lady in a black dress cuts a wedding cake with her new husband']

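Many labels are ungrammatical (for example, "two woman one in black eatting"). We can try using a small language model, Phi-2, to rewrite each label as a grammatical phrase while preserving its details.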

In [8]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

phi2_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True, device_map="cuda")
phi2_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.29it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [27]:

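# Decoder-only models should be left-padded for batched generation so that every
# prompt ends at the same position; reload the tokenizer with padding_side='left'.
phi2_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True, padding_side='left')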

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
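Why left padding matters: a decoder-only model continues from the last position of each row, so every prompt in a batch must end at the same column. The sketch below uses plain lists of made-up token ids and a hypothetical `left_pad` helper (not a `transformers` API) to show that, with left padding, each row's new tokens start at the padded prompt length. That is what makes slicing generated output at `input_ids.shape[1]` valid.

```python
from typing import List, Sequence, Tuple


def left_pad(batch: Sequence[Sequence[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id lists to a common width; return (input_ids, attention_mask)."""
    width = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = width - len(seq)
        input_ids.append([pad_id] * n_pad + list(seq))
        # Padding positions get mask 0 so the model ignores them.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


PAD = 0
prompts = [[11, 12, 13], [21, 22, 23, 24, 25]]
ids, mask = left_pad(prompts, PAD)
prompt_len = len(ids[0])  # identical for every row, like input_ids.shape[1]

# Pretend generation appended these tokens to the end of each row.
generated = [[91, 92], [81, 82]]
rows = [row + new for row, new in zip(ids, generated)]

# Every row's continuation starts at the same index.
print([row[prompt_len:] for row in rows])
```

With right padding, the first row would instead read `[11, 12, 13, 0, 0]`, and the model would be asked to continue after pad tokens.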


In [9]:
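# Every prompt begins with this same few-shot prefix, so its key/value (KV) cache
# can be computed once here and, in principle, reused to speed up inference.
# Note: `correct_sentences` below does not currently pass this cache to `generate`.

prefix_prompt = """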
Convert the input into a grammatically-correct phrase, preserving all details.

Sentence: two woman one in black eatting and the other has a white shirt at the desk

Extraction: two women, one in black eating, and the other with a white shirt at the desk

Sentence: young boy doing a skateboard trick on a blue board

Extraction: young boy doing a skateboard trick on a blue board

Sentence: lady in a black dress cuts a wedding cake with her new husband

Extraction: lady in a black dress cutting a wedding cake with her new husband

Sentence: woman in motion

Extraction: woman in motion

Sentence:"""

inputs = phi2_tokenizer(prefix_prompt, return_tensors="pt")
inputs = inputs.to("cuda")
with torch.no_grad():
    output = phi2_model(**inputs, use_cache=True)
past_key_values = output.past_key_values
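Every prompt sent to Phi-2 starts with the same instruction and examples; only the final sentence differs. That shared prefix is what makes caching its key/value tensors worthwhile. A minimal sketch of the layout, with hypothetical helper names (`shared_prefix`, `build_prompt`) and a trimmed example list:

```python
INSTRUCTION = "Convert the input into a grammatically-correct phrase, preserving all details."
EXAMPLES = [
    ("two woman one in black eatting and the other has a white shirt at the desk",
     "two women, one in black eating, and the other with a white shirt at the desk"),
    ("woman in motion", "woman in motion"),
]


def shared_prefix(instruction: str = INSTRUCTION, examples=EXAMPLES) -> str:
    """Instruction plus few-shot pairs, ending right where the new sentence goes."""
    parts = [instruction]
    for source, target in examples:
        parts.append(f"Sentence: {source}")
        parts.append(f"Extraction: {target}")
    return "\n\n".join(parts) + "\n\nSentence: "


def build_prompt(sentence: str) -> str:
    """Full prompt: the shared prefix, the new sentence, and the answer cue."""
    return shared_prefix() + sentence + "\n\nExtraction:"


prompt_a = build_prompt("man jumping with a skateboard")
prompt_b = build_prompt("brown bull in front of feeding tub")
# Both prompts begin with the identical prefix, so its KV cache could be shared.
print(prompt_a.startswith(shared_prefix()) and prompt_b.startswith(shared_prefix()))
```

Reusing a cache requires the cached positions to match the prompt's token ids exactly; with left-padded batches, the pad tokens come before the prefix, so the cache computed from the unpadded prefix would not line up without extra bookkeeping.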

In [62]:
from transformers import StoppingCriteria

# A "stopping criterion" that keeps generation from running past the answer:
# it stops once every sequence in the batch has produced `target_sequence`.
# https://huggingface.co/PygmalionAI/pygmalion-6b/discussions/25
class MyStoppingCriteria(StoppingCriteria):
    def __init__(self, target_sequence, n_prompt_tokens, tokenizer):
        self.target_sequence = target_sequence
        self.n_prompt_tokens = n_prompt_tokens
        self.tokenizer = tokenizer

    def __call__(self, token_ids, scores, **kwargs):
        # Decode only the newly generated tokens (everything after the padded prompt).
        output_ids = token_ids[:, self.n_prompt_tokens:]
        texts = self.tokenizer.batch_decode(output_ids)
        return all(self.target_sequence in s for s in texts)

    # __len__ and __iter__ let this single criterion be passed directly as
    # `stopping_criteria`, which otherwise expects a StoppingCriteriaList.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self
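The criterion above stops the whole batch only once every row contains a newline, so rows that finish early keep generating extra text. The toy simulation below (plain Python, made-up token strings, no model; `decode_until_all_stop` is a hypothetical name) mirrors that rule and shows why each output is then cut at its first newline.

```python
from typing import List, Sequence


def decode_until_all_stop(streams: Sequence[Sequence[str]], target: str = "\n",
                          max_new_tokens: int = 30) -> List[str]:
    """Append one token per row per step; stop once every row contains `target`."""
    outputs = ["" for _ in streams]
    for step in range(max_new_tokens):
        for i, stream in enumerate(streams):
            if step < len(stream):
                outputs[i] += stream[step]
        # Same condition as MyStoppingCriteria.__call__.
        if all(target in text for text in outputs):
            break
    return outputs


streams = [["tv", "\n", "Sentence"], ["young", " boy", "\n", "Sentence"]]
raw = decode_until_all_stop(streams)
print(raw)  # the first row ran past its newline while waiting for the second
print([text.split("\n")[0].strip() for text in raw])
```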

In [63]:
def correct_sentences(sentences):
    prompts = [f"""{prefix_prompt} {sentence}

Extraction:
""".strip() for sentence in sentences]

    # Enable padding for batching (Phi-2 has no dedicated pad token).
    phi2_tokenizer.pad_token = phi2_tokenizer.eos_token
    inputs = phi2_tokenizer(prompts, return_tensors="pt", padding=True)
    inputs = inputs.to("cuda")

    # With left padding, all prompts end at the same column, so generated
    # tokens start at index n_prompt_tokens in every row.
    n_prompt_tokens = inputs.input_ids.shape[1]
    stopper = MyStoppingCriteria('\n', n_prompt_tokens, phi2_tokenizer)

    with torch.no_grad():
        outputs = phi2_model.generate(
            **inputs,
            max_new_tokens=30,
            stopping_criteria=stopper,
            pad_token_id=phi2_tokenizer.eos_token_id,
            # KV cache reuse (not enabled):
            # past_key_values=past_key_values.expand(len(sentences), *(-1 for _ in range(len(past_key_values
        )

    # Keep only the generated text up to the first newline;
    # skip_special_tokens drops the EOS tokens used as padding.
    texts = phi2_tokenizer.batch_decode(outputs[:, n_prompt_tokens:], skip_special_tokens=True)
    return [text.split("\n")[0].strip() for text in texts]

In [65]:
import tqdm

batch_size = 8
corrected = []
suffixes = dataset['train']['suffix'].tolist()

with tqdm.tqdm(desc='correcting labels...', total=len(suffixes)) as pbar:
    for start in range(0, len(suffixes), batch_size):
        batch = suffixes[start:start + batch_size]
        corrected.extend(correct_sentences(batch))
        # Advance the bar by the size of the batch just processed.
        pbar.update(len(batch))


correcting labels...: 100%|█████████▉| 80504/80512 [1:15:25<00:00, 17.79it/s]
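A small batching helper keeps the slice bounds and the progress count in one place. `batched` here is a hypothetical helper; Python 3.12 and later also ship `itertools.batched`, which yields tuples instead of lists.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `batch_size` items; the last may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


labels = [f"label {k}" for k in range(19)]
batches = list(batched(labels, 8))
print([len(b) for b in batches])
# Counting progress by len(batch) sums exactly to the total.
print(sum(len(b) for b in batches) == len(labels))
```

In the labelling loop, each iteration would call `correct_sentences(batch)` and then `pbar.update(len(batch))`.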


In [66]:
import json
with open("train_corrected.json", "w") as f:
    json.dump(corrected, f)
    
with open("train_original.json", "w") as f:
    json.dump(suffixes, f)

In [70]:
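# `suffixes` holds the cleaned original labels; assign `corrected` instead to use the Phi-2 rewrites.
dataset['train']['suffix'] = suffixes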

# Training

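Load the processor and the model. The model is initialized from a local checkpoint of an earlier run (`paligemma_object_classifier/checkpoint-37500`), and the vision tower is frozen so that only the language model and the multimodal projector are trained.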

In [47]:
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import os
import dotenv
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dotenv.load_dotenv("/scratch/gsk6me/WORLDMODELS/crl-online-policies/.env")

token = os.environ['HUGGINGFACE_ACCESS_TOKEN']
model_id = "google/paligemma-3b-mix-224"
processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)

model = PaliGemmaForConditionalGeneration.from_pretrained("paligemma_object_classifier/checkpoint-37500", torch_dtype=torch.bfloat16, token=token, device_map=device)
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=token, device_map=device)

for param in model.vision_tower.parameters():
    param.requires_grad = False

# for param in model.multi_modal_projector.parameters():
#     param.requires_grad = False


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.79s/it]


In [37]:
import PIL.Image
import os

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")

coco_root = '/scratch/gsk6me/WORLDMODELS'

def collate_fn(examples):
    # Image paths in the JSONL records are relative to coco_root.
    images = [PIL.Image.open(os.path.join(coco_root, example['image'])).convert("RGB") for example in examples]
    texts = [example['prefix'] for example in examples]   # prompt: task word + location tokens
    labels = [example['suffix'] for example in examples]  # target description

    # Passing `suffix` makes the processor append the targets and return `labels`.
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest",
                       tokenize_newline_separately=False)

    # Cast floating-point tensors (pixel_values) to bfloat16; token ids stay integers.
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens
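When `suffix=` is given, the processor also returns `labels`, so the loss is computed only on the answer. As we understand the PaliGemma processor, it does this by setting the prompt and padding positions to -100, the index PyTorch's cross-entropy ignores by default. The stdlib sketch below illustrates that masking with made-up token ids; `build_example` is a hypothetical helper, not the processor's code, and it omits the image and special tokens the real processor adds.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(prefix_ids: Sequence[int], suffix_ids: Sequence[int], eos_id: int,
                  width: int, pad_id: int) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and answer, pad on the right, and mask non-answer labels."""
    input_ids = list(prefix_ids) + list(suffix_ids) + [eos_id]
    labels = [IGNORE_INDEX] * len(prefix_ids) + list(suffix_ids) + [eos_id]
    n_pad = width - len(input_ids)
    if n_pad < 0:
        raise ValueError("width is smaller than the example")
    return input_ids + [pad_id] * n_pad, labels + [IGNORE_INDEX] * n_pad


# Made-up ids: four prompt tokens, two answer tokens, EOS = 1, PAD = 0.
ids, labels = build_example([5, 6, 7, 8], [40, 41], eos_id=1, width=9, pad_id=0)
print(ids)
print(labels)
```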


Alternatively, for LoRA or QLoRA fine-tuning, run the cells below to attach LoRA adapters to the model, loaded either in full precision or 4-bit quantized.

In [38]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [23]:
from peft import get_peft_model, LoraConfig

quantize = False
use_lora = False

if quantize:
    from transformers import BitsAndBytesConfig
    # 4-bit NF4 quantization (QLoRA) with bfloat16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=token)

if use_lora:
    lora_config = LoraConfig(
        r=32,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
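LoRA adds two low-rank matrices, A (r × d_in) and B (d_out × r), to each targeted linear layer, i.e. r·(d_in + d_out) trainable parameters per layer. The sketch below counts them for PaliGemma-3B under assumed dimensions: a Gemma-2B decoder with 18 layers, hidden size 2048, MLP size 16384 and a single 256-dim key/value head, plus a SigLIP encoder with 27 layers and hidden size 1152 whose `q_proj`/`k_proj`/`v_proj` also match `target_modules`. Under these assumptions r=8 gives exactly the 11,298,816 in the comment above, while the r=32 set in this cell would give 45,195,264.

```python
# Assumed PaliGemma-3B dimensions (check them against the model config).
GEMMA = dict(layers=18, hidden=2048, mlp=16384, kv_dim=256)
SIGLIP = dict(layers=27, hidden=1152)


def lora_params(in_features: int, out_features: int, r: int) -> int:
    # A: r x in_features, B: out_features x r
    return r * (in_features + out_features)


def paligemma_lora_params(r: int) -> int:
    h, m, kv = GEMMA["hidden"], GEMMA["mlp"], GEMMA["kv_dim"]
    decoder_layer = (
        lora_params(h, h, r)      # q_proj
        + lora_params(h, kv, r)   # k_proj (single key/value head)
        + lora_params(h, kv, r)   # v_proj
        + lora_params(h, h, r)    # o_proj
        + lora_params(h, m, r)    # gate_proj
        + lora_params(h, m, r)    # up_proj
        + lora_params(m, h, r)    # down_proj
    )
    v = SIGLIP["hidden"]
    # SigLIP attention also names its projections q_proj/k_proj/v_proj.
    vision_layer = 3 * lora_params(v, v, r)
    return GEMMA["layers"] * decoder_layer + SIGLIP["layers"] * vision_layer


print(f"{paligemma_lora_params(8):,}")
print(f"{paligemma_lora_params(32):,}")
```

Because the count is linear in r, quadrupling the rank quadruples the adapter size.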


We will now initialize the `TrainingArguments`.

In [48]:
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=3,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=128,
    optim="adamw_hf",
    save_strategy="epoch",
    # save_strategy="steps",
    # save_steps=500,
    push_to_hub=False,
    save_total_limit=1,
    output_dir="paligemma_object_captioner_pretrained",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False
)
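With `per_device_train_batch_size=4` and `gradient_accumulation_steps=4`, each optimizer step sees 4 × 4 = 16 examples per device. The sketch below estimates optimizer steps on a single GPU; `optimizer_steps` is a hypothetical helper, and the handling of a final partial accumulation window varies across `transformers` versions, so treat it as approximate. As a consistency check, the 969 steps in the training run logged later equal 3 × 323, i.e. roughly 323 × 16 ≈ 5,168 examples per epoch (the dataset size used below is illustrative).

```python
import math


def optimizer_steps(n_examples: int, per_device_batch: int, grad_accum: int,
                    epochs: int, n_devices: int = 1) -> int:
    """Approximate optimizer steps: batches per epoch // accumulation, times epochs."""
    batches_per_epoch = math.ceil(n_examples / (per_device_batch * n_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


effective_batch = 4 * 4  # per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch)
print(optimizer_steps(5168, per_device_batch=4, grad_accum=4, epochs=3))
```

With `logging_steps=128`, a 969-step run logs seven losses, at steps 128 through 896.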


Next, we prepare the training data and start training.

## Pretrain on COCO object classification

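Here the model predicts a box's COCO category name from its `<loc>` tokens. We expect this pretraining to improve the model's ability to ground its output in the specified region.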


In [3]:
from pycocotools.coco import COCO
coco_train2014 = COCO('/scratch/gsk6me/WORLDMODELS/coco_annotations/instances_train2014.json')

loading annotations into memory...
Done (t=9.49s)
creating index...
index created!


In [14]:
import os
import json

# Build a COCO classification set: each annotated box becomes a prompt with four
# <locXXXX> tokens, and the target is the box's COCO category name.
# Boxes are written as x_min, y_min, x_max, y_max to match the other datasets in
# this notebook (PaliGemma's own detection output uses y_min, x_min, y_max, x_max).
if not os.path.exists("coco_train2014_clsify.jsonl"):
    coco_train2014_ds = []

    def to_loc_bin(coord, size):
        # Quantize a pixel coordinate into one of 1024 bins; clamp so a box touching
        # the right/bottom edge maps to <loc1023> rather than a nonexistent <loc1024>.
        return min(int((coord / size) * 1024), 1023)

    for image_id in sorted(coco_train2014.imgs.keys()):
        image_path = f"train2014/COCO_train2014_{image_id:012d}.jpg"
        # Read the image size from the annotation metadata instead of opening the JPEG.
        image_info = coco_train2014.imgs[image_id]
        image_width = image_info['width']
        image_height = image_info['height']
        for annot in coco_train2014.imgToAnns[image_id]:
            category_name = coco_train2014.cats[annot['category_id']]['name']
            x, y, w, h = annot['bbox']  # COCO boxes are [x, y, width, height] in pixels

            x0_quant = to_loc_bin(x, image_width)
            y0_quant = to_loc_bin(y, image_height)
            x1_quant = to_loc_bin(x + w, image_width)
            y1_quant = to_loc_bin(y + h, image_height)

            coco_train2014_ds.append({
                'image': image_path,
                'prefix': f'Describe <loc{x0_quant:04d}><loc{y0_quant:04d}><loc{x1_quant:04d}><loc{y1_quant:04d}>',
                'suffix': category_name,
                'bbox_x': x,
                'bbox_y': y,
                'bbox_w': w,
                'bbox_h': h,
                'image_width': image_width,
                'image_height': image_height,
            })

    with open("coco_train2014_clsify.jsonl", "w") as f:
        for record in coco_train2014_ds:
            json.dump(record, f)
            f.write("\n")
else:
    # Load the cached file; it must be opened for reading, since mode "w" would truncate it.
    with open("coco_train2014_clsify.jsonl") as f:
        coco_train2014_ds = [json.loads(line) for line in f if line.strip()]
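The dataset-building code maps each COCO box `[x, y, width, height]` to four 1024-bin location tokens. A stdlib sketch of the round trip, assuming the x_min, y_min, x_max, y_max token order used in this notebook plus a clamp to `<loc1023>` for boxes touching the image edge; `encode_box` and `decode_box` are hypothetical helpers, and `decode_box` is one way to turn a generated prefix back into approximate pixel coordinates. The expected strings below come from the two records printed by `train_ds[6], train_ds[7]`.

```python
import re
from typing import Sequence, Tuple

N_BINS = 1024  # location tokens <loc0000> ... <loc1023>


def to_bin(coord: float, size: float) -> int:
    return min(int(coord / size * N_BINS), N_BINS - 1)


def encode_box(bbox: Sequence[float], width: float, height: float) -> str:
    """COCO [x, y, w, h] in pixels -> four <loc> tokens (x_min, y_min, x_max, y_max)."""
    x, y, w, h = bbox
    bins = (to_bin(x, width), to_bin(y, height), to_bin(x + w, width), to_bin(y + h, height))
    return "".join(f"<loc{b:04d}>" for b in bins)


def decode_box(text: str, width: float, height: float) -> Tuple[float, float, float, float]:
    """Approximate inverse (within one bin): first four <loc> tokens -> pixel corners."""
    bins = [int(b) for b in re.findall(r"<loc(\d{4})>", text)][:4]
    if len(bins) != 4:
        raise ValueError(f"expected 4 location tokens, got {len(bins)}")
    x0, y0, x1, y1 = bins
    return (x0 / N_BINS * width, y0 / N_BINS * height,
            x1 / N_BINS * width, y1 / N_BINS * height)


print(encode_box([225.76, 111.06, 67.77, 90.7], 480, 640))
print(encode_box([45.83, 402.96, 421.92, 233.88], 480, 640))
print(decode_box("Describe <loc0481><loc0177><loc0626><loc0322>", 480, 640))
```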


In [42]:
import json

# Load the object-caption datasets (JSON Lines: one record per line).
# Alternative: convert the RefCOCOg dataframe to a list of dictionaries instead.
# train_ds = dataset['train'].to_dict('records')
with open("train_object_captions_1k.jsonl") as f:
    train_ds = [json.loads(line) for line in f if line.strip()]
with open("train_object_captions_4k.jsonl") as f:
    train_ds.extend(json.loads(line) for line in f if line.strip())

In [43]:
train_ds[6], train_ds[7]

({'image': 'train2014/COCO_train2014_000000480023.jpg',
  'prefix': 'Describe <loc0481><loc0177><loc0626><loc0322>',
  'suffix': 'A blurred figure with a dark silhouette, possibly a person with a round head and no visible hair, wearing a dark garment.',
  'bbox_x': 225.76,
  'bbox_y': 111.06,
  'bbox_w': 67.77,
  'bbox_h': 90.7,
  'image_width': 480,
  'image_height': 640},
 {'image': 'train2014/COCO_train2014_000000480023.jpg',
  'prefix': 'Describe <loc0097><loc0644><loc0997><loc1018>',
  'suffix': "A person's hand holding a hot dog with mustard and onions.",
  'bbox_x': 45.83,
  'bbox_y': 402.96,
  'bbox_w': 421.92,
  'bbox_h': 233.88,
  'image_width': 480,
  'image_height': 640})

In [49]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    data_collator=collate_fn,
    args=args
)


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [50]:
trainer.train()



| Step | Training Loss |
|-----:|--------------:|
| 128 | 1.3228 |
| 256 | 0.973 |
| 384 | 0.8429 |
| 512 | 0.7361 |
| 640 | 0.7173 |
| 768 | 0.6228 |
| 896 | 0.6283 |


TrainOutput(global_step=969, training_loss=0.8174074289107347, metrics={'train_runtime': 935.346, 'train_samples_per_second': 16.576, 'train_steps_per_second': 1.036, 'total_flos': 6.42906476522016e+16, 'train_loss': 0.8174074289107347, 'epoch': 3.0})

In [14]:
trainer.save_model('paligemma_object_captioner_more_v2/checkpoint-544')

In [51]:
trainer.push_to_hub()

HfHubHTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-665c7d2e-0b0ff6624a38369f27b0acc2;4cce9852-4661-4431-a287-7058fa59c4bb)

Invalid username or password.

The inference steps are in [this Colab notebook](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing).