In [1]:
import os
from pathlib import Path
import time
import math

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    CLIPModel,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    set_seed,
)
from peft import LoraConfig, get_peft_model, PeftModel

from torchvision.transforms import (
    Resize, CenterCrop, ToTensor, Normalize, Compose, InterpolationMode
)
from PIL import Image

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


In [2]:

# EDIT THESE 
DATASET_NAME = "arampacha/rsicd"        # None if using local files
LOCAL_TRAIN_FILE = None                # "train.json" or "train.csv" if not using HF dataset
LOCAL_VALID_FILE = None                # optional
IMAGE_COLUMN = "image"                 # adjust to dataset's column name (e.g., "image" or "image_path")
CAPTION_COLUMN = "captions"            # adjust to dataset's caption column name
MODEL_NAME = "openai/clip-vit-base-patch32"
OUTPUT_DIR = "./clip-lora-output"
SEED = 42

# Training hyperparams
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 5e-5
NUM_TRAIN_EPOCHS = 3
MAX_SEQ_LENGTH = 77   # CLIP default
SAVE_STEPS = 500
LOGGING_STEPS = 100

# LoRA hyperparams (sane defaults)
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj"]  # typical for HF CLIP

In [3]:

def list_module_names(model, prefix=""):
    names = []
    for name, module in model.named_modules():
        names.append(name)
    return names


# model_tmp = CLIPModel.from_pretrained(MODEL_NAME)
# print("\n".join([n for n in list_module_names(model_tmp) if "attn" in n or "q_proj" in n or "proj" in n][:200]))

In [4]:
set_seed(SEED)

if DATASET_NAME is not None:
    dataset = load_dataset(DATASET_NAME)
    print("Loaded HF dataset:", dataset)
else:
    # Expect CSV or JSON files
    data_files = {}
    if LOCAL_TRAIN_FILE:
        data_files["train"] = LOCAL_TRAIN_FILE
    if LOCAL_VALID_FILE:
        data_files["validation"] = LOCAL_VALID_FILE
    assert data_files, "Either set DATASET_NAME or provide local train/validation files."
    ext = LOCAL_TRAIN_FILE.split(".")[-1]
    dataset = load_dataset(ext, data_files=data_files)
    print("Loaded local dataset:", dataset)

if "valid" in dataset and "validation" not in dataset:
    print("im changing it")
    dataset["validation"] = dataset["valid"]
    print(dataset["validation"])


# if "train" in dataset:
#     dataset["train"] = dataset["train"].select(range(min(16, len(dataset["train"]))))
    

# if "validation" in dataset:
#     dataset["validation"] = dataset["validation"].select(range(min(16, len(dataset["validation"]))))
   
# def preprocess(example):
#     # image → pixel_values
#     image = example["image"]
#     example["pixel_values"] = image_processor(image, return_tensors="pt")["pixel_values"][0]

#     # caption → input_ids + attention_mask
#     text = example["captions"][0] if isinstance(example["captions"], list) else example["captions"]
#     text_tokens = tokenizer(text, truncation=True, padding="max_length", max_length=77)
#     example["input_ids"] = text_tokens["input_ids"]
#     example["attention_mask"] = text_tokens["attention_mask"]
#         return example

print("Loaded local dataset:", dataset)


Generating train split: 100%|████████████████████████████████████████████████████| 8734/8734 [00:00<00:00, 10262.29 examples/s]
Generating test split: 100%|█████████████████████████████████████████████████████| 1093/1093 [00:00<00:00, 10044.31 examples/s]
Generating valid split: 100%|█████████████████████████████████████████████████████| 1094/1094 [00:00<00:00, 6104.61 examples/s]


Loaded HF dataset: DatasetDict({
    train: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
})
im changing it
Dataset({
    features: ['filename', 'captions', 'image'],
    num_rows: 1094
})
Loaded local dataset: DatasetDict({
    train: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
    validation: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
})


In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# normalization values from processor to match pretrained CLIP
image_size = image_processor.size["shortest_edge"] if isinstance(image_processor.size, dict) else getattr(image_processor, "image_mean", 224)
image_mean = image_processor.image_mean
image_std = image_processor.image_std

print("image_size:", image_size, "mean/std:", image_mean, image_std)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


image_size: 224 mean/std: [0.48145466, 0.4578275, 0.40821073] [0.26862954, 0.26130258, 0.27577711]


In [6]:
from torchvision.transforms import RandomResizedCrop

train_transform = Compose([
    Resize(int(image_size * 1.15), interpolation=InterpolationMode.BICUBIC),
    RandomResizedCrop(int(image_size), scale=(0.8, 1.0), interpolation=InterpolationMode.BICUBIC),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])

eval_transform = Compose([
    Resize(int(image_size * 1.15), interpolation=InterpolationMode.BICUBIC),
    CenterCrop(int(image_size)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])

In [7]:
def tokenize_captions(examples):
    # captions might be list-of-lists depending on dataset; unify to string
    caps = [c[0] if isinstance(c, list) else c for c in examples[CAPTION_COLUMN]]
    tokenized = tokenizer(caps, max_length=MAX_SEQ_LENGTH, padding="max_length", truncation=True)
    examples["input_ids"] = tokenized["input_ids"]
    examples["attention_mask"] = tokenized["attention_mask"]
    return examples


def transform_images(examples, is_train=True):
    imgs = examples[IMAGE_COLUMN]
    out = []
    transform = train_transform if is_train else eval_transform
    for im in imgs:
        if isinstance(im, str):
            im = Image.open(im).convert("RGB")
        elif isinstance(im, dict) and "path" in im: 
            im = Image.open(im["path"]).convert("RGB")
        # else assume PIL.Image
        out.append(transform(im))
    examples["pixel_values"] = out
    return examples

In [8]:
print(dataset)
if "train" in dataset:
    dataset["train"] = dataset["train"].map(tokenize_captions, batched=True, remove_columns=[c for c in dataset["train"].column_names if c not in [IMAGE_COLUMN, CAPTION_COLUMN]])
    dataset["train"].set_transform(lambda x: transform_images(x, is_train=True))
if "validation" in dataset or "valid" in dataset:
    if "valid" in dataset and "validation" not in dataset:
        dataset["validation"] = dataset.pop("valid")
    dataset["validation"] = dataset["validation"].map(tokenize_captions, batched=True, remove_columns=[c for c in dataset["validation"].column_names if c not in [IMAGE_COLUMN, CAPTION_COLUMN]])
    dataset["validation"].set_transform(lambda x: transform_images(x, is_train=False))

print("Train/Validation ready.")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
    validation: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
})


Map: 100%|███████████████████████████████████████████████████████████████████████| 8734/8734 [00:00<00:00, 11051.32 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████| 1094/1094 [00:00<00:00, 8223.85 examples/s]

Train/Validation ready.
DatasetDict({
    train: Dataset({
        features: ['captions', 'image', 'input_ids', 'attention_mask'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
    validation: Dataset({
        features: ['captions', 'image', 'input_ids', 'attention_mask'],
        num_rows: 1094
    })
})





In [9]:


import torch

def collate_fn(examples):
    pixel_values = torch.stack([ex["pixel_values"] for ex in examples])
    input_ids = torch.tensor([ex["input_ids"] for ex in examples], dtype=torch.long)
    attention_mask = torch.tensor([ex["attention_mask"] for ex in examples], dtype=torch.long)
    return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, "return_loss": True}

In [10]:

from peft import PeftConfig

# base CLIP
base_model = CLIPModel.from_pretrained(MODEL_NAME)
base_model.to(device)

# Freeze original params
for p in base_model.parameters():
    p.requires_grad = False


lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
#     task_type="FEATURE_EXTRACTION",
)


model = get_peft_model(base_model, lora_config)
model.to(device)


def print_trainable(m):
    t = sum(p.numel() for p in m.parameters() if p.requires_grad)
    total = sum(p.numel() for p in m.parameters())
    print(f"Trainable params: {t} ({t/total:.4%} of total)")

print_trainable(model)

Trainable params: 5898240 (3.7526% of total)


In [11]:

def preprocess(example):
    # image → pixel_values
    image = example["image"]
    # Keep tensor, remove [0]
    example["pixel_values"] = image_processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

    # caption → input_ids + attention_mask
    text = example["captions"][0] if isinstance(example["captions"], list) else example["captions"]
    text_tokens = tokenizer(text, truncation=True, padding="max_length", max_length=77)
    example["input_ids"] = torch.tensor(text_tokens["input_ids"], dtype=torch.long)
    example["attention_mask"] = torch.tensor(text_tokens["attention_mask"], dtype=torch.long)

    return example

    
print(dataset)
dataset = dataset.map(preprocess, batched=False, load_from_cache_file=False)
print(dataset["train"][0].keys())

# from transformers import Trainer

# class MyTrainer(Trainer):
#     def compute_loss(self, model, inputs, return_outputs=False):
#         # extract only the keys your CLIPModel accepts
#         pixel_values = inputs.pop("pixel_values")
#         input_ids = inputs.pop("input_ids")
#         attention_mask = inputs.pop("attention_mask")
#         outputs = model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             pixel_values=pixel_values,
#             return_dict=True
#         )
#         loss = outputs.loss
#         return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    eval_strategy="steps" if "validation" in dataset else "no",
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    report_to="none",
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"] if "train" in dataset else None,
    eval_dataset=dataset["validation"],
    data_collator=collate_fn,
)

print(dataset["train"])

DatasetDict({
    train: Dataset({
        features: ['captions', 'image', 'input_ids', 'attention_mask'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
    validation: Dataset({
        features: ['captions', 'image', 'input_ids', 'attention_mask'],
        num_rows: 1094
    })
})


Map: 100%|█████████████████████████████████████████████████████████████████████████| 8734/8734 [01:16<00:00, 114.70 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████| 1093/1093 [00:08<00:00, 124.92 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████| 1094/1094 [00:09<00:00, 120.08 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████| 1094/1094 [00:09<00:00, 109.40 examples/s]


dict_keys(['captions', 'image', 'input_ids', 'attention_mask', 'pixel_values'])
Dataset({
    features: ['captions', 'image', 'input_ids', 'attention_mask', 'pixel_values'],
    num_rows: 8734
})


In [12]:

train_result = None
if "train" in dataset:
    train_result = trainer.train()
    trainer.save_model(OUTPUT_DIR) 
    tokenizer.save_pretrained(OUTPUT_DIR)
    image_processor.save_pretrained(OUTPUT_DIR)

    print("Training finished. Metrics:")
    print(train_result.metrics)


if "validation" in dataset:
    eval_metrics = trainer.evaluate()
    print("Evaluation metrics:", eval_metrics)



Step,Training Loss,Validation Loss
100,0.7045,2.526394
200,0.5087,2.479175
300,0.4315,2.435039
400,0.4063,2.359408
500,0.3836,2.392176
600,0.3595,2.419033
700,0.4069,2.30982
800,0.3157,2.368917
900,0.3021,2.493088
1000,0.3183,2.426887




Training finished. Metrics:
{'train_runtime': 4268.1647, 'train_samples_per_second': 6.139, 'train_steps_per_second': 0.768, 'total_flos': 1595491786934892.0, 'train_loss': 0.28273216621343034, 'epoch': 3.0}




Evaluation metrics: {'eval_loss': 2.6055562496185303, 'eval_runtime': 63.5454, 'eval_samples_per_second': 17.216, 'eval_steps_per_second': 2.156, 'epoch': 3.0}


In [13]:
from pathlib import Path
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

model.save_pretrained(OUTPUT_DIR)
print("Saved LoRA adapters to:", OUTPUT_DIR)

Saved LoRA adapters to: ./clip-lora-output


In [None]:
from transformers import CLIPModel, CLIPProcessor
from peft import PeftModel

# load base
base = CLIPModel.from_pretrained(MODEL_NAME)
# load adapter onto base 
peft = PeftModel.from_pretrained(base, OUTPUT_DIR)
peft.to(device)
proc = CLIPProcessor.from_pretrained(MODEL_NAME)

img = Image.open("path_to_a_test_image.jpg").convert("RGB") 
inputs = proc(text=["a cat", "a dog"], images=img, return_tensors="pt", padding=True).to(device)
out = peft(**inputs)
print("Logits per image:", out.logits_per_image)