In [None]:
!pip install peft
!pip install datasets transformers pillow
!pip install -U bitsandbytes

Visit the Hugging Face page for PaliGemma and accept the terms of use there; otherwise the model weights cannot be downloaded in this notebook.

In [None]:
# Authenticate with a Hugging Face access token so the PaliGemma weights (which require
# accepting the terms mentioned above) can be downloaded.
!huggingface-cli login

In [None]:
import torch
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# (useful when re-running the notebook in a kernel that already holds a model).
torch.cuda.empty_cache()

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from PIL import Image
from peft import get_peft_model, LoraConfig


In [None]:
# Replace the dataset here with any VQA dataset. If it has the same format as VQAv2, nothing else
# needs to change; otherwise adjust `cols_remove` below and the question/answer columns read in
# `collate_fn` accordingly.
dataset = load_dataset("pminervini/VQAv2")
train_ds = dataset['train']

# Drop metadata columns the collator does not use. Only columns that actually exist are
# removed, so datasets lacking some of them need no exception handling.
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
existing_cols = [col for col in cols_remove if col in train_ds.column_names]
if existing_cols:
    train_ds = train_ds.remove_columns(existing_cols)


In [None]:
import torch
from transformers import PaliGemmaProcessor
from torchvision import transforms
from PIL import Image

# Initialize the processor for the base model that will be fine-tuned.
model_id = "google/paligemma-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(model_id)


def collate_fn(examples):
    """Build a PaliGemma training batch from a list of VQA examples.

    Each example must provide an ``"image"`` (PIL image), a ``"question"`` string and an
    answer under ``"answer"`` or ``"multiple_choice_answer"`` (missing -> empty string).

    The answer is passed to the processor as ``suffix``. The processor then appends it
    after the prompt and builds ``labels`` in which the image tokens, the prompt and the
    padding are set to -100. As a result, the loss covers only the answer tokens, at the
    same positions they occupy in ``input_ids``. Tokenizing the answers separately would
    start them at position 0 and misalign them with the inputs.

    Returns:
        dict of tensors produced by the processor (including ``labels`` and
        ``token_type_ids``). Floating-point tensors (``pixel_values``) are cast to
        bfloat16; integer tensors are returned unchanged. Tensors stay on the CPU;
        ``Trainer`` moves each batch to the training device.
    """
    # The processor expands "<image>" into the image tokens and inserts <bos> right after
    # them itself, so an explicit "<bos>" in the prompt would produce a duplicate BOS.
    texts = [f"<image>answer {example['question']}" for example in examples]
    labels = [str(example.get('answer', example.get('multiple_choice_answer', ''))) for example in examples]
    # convert("RGB") normalizes grayscale, palette and RGBA images to 3 channels; the
    # processor takes care of resizing, rescaling and normalization.
    images = [example["image"].convert("RGB") for example in examples]

    model_inputs = processor(
        text=texts,
        images=images,
        suffix=labels,
        padding="max_length",
        truncation=True,
        max_length=256,      # Adjust max_length as needed
        return_tensors="pt",
    )

    # Cast only floating-point tensors (pixel values) to bfloat16 for bf16 training;
    # ids, masks, token types and labels must stay integer tensors.
    return {
        k: v.to(torch.bfloat16) if torch.is_floating_point(v) else v
        for k, v in model_inputs.items()
    }


In [None]:
# Configure 4-bit NF4 quantization of the base model.
# The compute-dtype parameter is `bnb_4bit_compute_dtype`. A misspelled keyword such as
# `bnb_4bit_compute_type` is not a BitsAndBytesConfig field: it is ignored, and the compute
# dtype silently stays at its float32 default.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Configure LoRA: rank-8 adapters injected into the listed projection modules.
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj", "o_proj", "k_proj", "v_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    task_type="CAUSAL_LM",
)

# Load the quantized base model onto GPU 0. Non-quantized modules are kept in bfloat16 to
# match the compute dtype and the bf16 training arguments.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
# Wrap the model so that only the LoRA adapter weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


In [None]:
# Configure training arguments.
# Effective batch size per device = per_device_train_batch_size * gradient_accumulation_steps = 8.
training_args = TrainingArguments(
    num_train_epochs=2,
    remove_unused_columns=False,  # the collator needs the raw "image"/"question"/answer columns
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    # "adamw_hf" selects the deprecated transformers AdamW implementation; "adamw_torch"
    # is the maintained torch.optim.AdamW equivalent.
    optim="adamw_torch",
    save_strategy="steps",
    save_steps=10,
    push_to_hub=False,
    save_total_limit=1,  # only the most recent checkpoint is kept on disk
    output_dir="/content/out",
    bf16=True,
    report_to=["tensorboard"],
    # Memory pinning only applies to CPU tensors, so keep it off in case the collator
    # returns batches that are already on the GPU.
    dataloader_pin_memory=False,
)


In [None]:
# Assemble the Trainer from the arguments, PEFT model, collator and dataset defined above,
# then run fine-tuning (the returned TrainOutput is displayed as the cell output).
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collate_fn,
    train_dataset=train_ds,
)
trainer.train()


In [None]:
# 4-bit NF4 quantization config with bfloat16 compute.
# The compute-dtype parameter is `bnb_4bit_compute_dtype`; an unknown keyword such as
# `bnb_4bit_compute_type` would be ignored, leaving the float32 default.
# NOTE(review): this config is not used by the inference cells below, which load the base
# model unquantized; pass `quantization_config=bnb_config` to `from_pretrained` to use it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
from safetensors.torch import load_file
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import torch

In [None]:

# Path to the LoRA adapter weights written by the Trainer under its output_dir ("/content/out").
# NOTE(review): the step number "checkpoint-250" is hard-coded; with save_total_limit=1 only the
# latest checkpoint is kept, so adjust it to a checkpoint directory that actually exists.
model_checkpoint_path = "/content/out/checkpoint-250/adapter_model.safetensors"  # Model path


In [None]:
# Base model the fine-tuned LoRA adapter is applied to. It must be the checkpoint that was
# fine-tuned ("google/paligemma-3b-pt-224" in the training cells); an adapter trained on the
# "pt" weights does not belong on the "mix" checkpoint.
model_name = "google/paligemma-3b-pt-224"  # Base model name
processor = AutoProcessor.from_pretrained(model_name)

In [None]:
# Load the unquantized base model; device_map="auto" lets accelerate place it on the available device(s).
model = AutoModelForVision2Seq.from_pretrained(model_name, device_map="auto")

In [None]:
import os

from peft import PeftModel

# The Trainer saved a LoRA adapter (adapter_config.json + adapter_model.safetensors), whose
# keys refer to LoRA modules (e.g. "...q_proj.lora_A.weight") that do not exist in the plain
# base model. `model.load_state_dict(load_file(...), strict=False)` therefore matches none of
# them and silently leaves the base model unchanged. PeftModel.from_pretrained injects the
# LoRA layers described by adapter_config.json and loads the trained weights into them.
adapter_dir = os.path.dirname(model_checkpoint_path)
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()

In [None]:
image_path = "./output.jpg"  # Image file path
# Same prompt format as fine-tuning: the processor expands "<image>" into the image tokens and
# inserts <bos> after them itself, so no explicit "<bos>" is written here.
text_prompt = "<image>answer What is shown in this image ?"

# Open the file in a context manager so the handle is closed once the RGB copy is made.
with Image.open(image_path) as img:
    image = img.convert("RGB")
# Downsample to 32x32, then upsample back to 224x224.
# NOTE(review): this deliberately degrades the image (presumably to test low-resolution
# inputs); drop the first resize if that is not intended.
image = image.resize((32, 32), Image.BICUBIC)
image = image.resize((224, 224), Image.BICUBIC)

# The model was placed by device_map="auto", so move the inputs to the same device.
inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(model.device)
prompt_length = inputs["input_ids"].shape[-1]

with torch.no_grad():
    # max_new_tokens bounds only the generated answer; max_length would also count the
    # image and prompt tokens.
    outputs = model.generate(**inputs, max_new_tokens=128)  # Adjust max_new_tokens as needed

# generate() returns prompt + continuation; decode only the newly generated tokens.
response = processor.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(f"Model's Response: {response}")
