## Import Essential Libraries and Get the Data Ready

In [None]:
import os
import re
import random
from PIL import Image
import torch
from tqdm.notebook import tqdm
from transformers import CLIPModel,CLIPProcessor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from peft import get_peft_model, LoraConfig, TaskType


In [30]:
def load_split_file(file_path):
    """Read a split file into a list of ``(image_path, label)`` pairs.

    Each non-blank line must contain an image path followed by an integer
    class label, separated by whitespace. The label is taken from the last
    whitespace-separated token, so image paths that contain spaces survive.

    Args:
        file_path: Path to the split file (e.g. ``train.txt``).

    Returns:
        A list of ``(path, label)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line has no label or the label is not an
            integer.
    """
    samples = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing.
                continue
            # rsplit from the right keeps paths with embedded spaces intact.
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<path> <label>', got {line!r}"
                )
            path, label = parts
            samples.append((path, int(label)))
    return samples

def load_prompt(file_path):
    """Parse a prompt file that maps integer class ids to text descriptions.

    Each non-blank line starts with an integer id, followed by whitespace and
    the description. Only the leading id is removed: digits inside the
    description itself (e.g. "6 legs") are preserved.

    Args:
        file_path: Path to the prompt/description text file.

    Returns:
        A ``dict[int, str]`` from class id to description, in file order.
        A line holding only an id maps to an empty string.

    Raises:
        ValueError: If the first token of a non-blank line is not an integer.
    """
    id_to_name = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split(maxsplit=1)
            if not parts:
                # Skip blank lines rather than failing with IndexError.
                continue
            idx = int(parts[0])
            name = parts[1].strip() if len(parts) > 1 else ""
            id_to_name[idx] = name
    return id_to_name

In [6]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [31]:
# Dataset directory (relative to the notebook's working directory); the split
# files live at its top level next to an "images" folder.
dataset_root = '../ip102_v1.1/'
images_root = os.path.join(dataset_root, "images")

train_txt = os.path.join(dataset_root, "train.txt")
val_txt = os.path.join(dataset_root, "val.txt")
test_txt = os.path.join(dataset_root, "test.txt")

# Parse each split into (relative image path, integer label) pairs,
# reading train, then test, then val.
train_data, test_data, val_data = (
    load_split_file(split_path) for split_path in (train_txt, test_txt, val_txt)
)

In [49]:
class CustomImageDataset(Dataset):
    """Dataset yielding ``(image, prompt)`` pairs for CLIP-style training.

    Args:
        data_list: Sequence of ``(relative_image_path, label)`` tuples.
        root_dir: Directory that the relative image paths are joined onto.
        prompts: Indexable mapping from integer label to prompt text
            (a list or a dict both work).
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, data_list, root_dir, prompts, transform=None):
        self.data_list = data_list
        self.root_dir = root_dir
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, prompt)``.

        The image is returned as an RGB PIL image, or as whatever
        ``transform`` produces from it when a transform is set.
        """
        path, label = self.data_list[idx]
        image_path = os.path.join(self.root_dir, path)
        # Close the file deterministically; convert() returns an independent
        # in-memory copy, so the image stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        prompt = self.prompts[label]

        if self.transform is not None:
            image = self.transform(image)

        return image, prompt


In [78]:
# Resize every image to 224x224 and convert it to a float tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Order the descriptions by class id, not by line order in the file, so that
# list position i always corresponds to the i-th smallest id.
# NOTE(review): assumes the sorted ids line up one-to-one with the 0-based
# labels in the split files (e.g. ids 0..101 or 1..102) — confirm.
id_to_prompt = load_prompt('../large-multi-modal/caption generation/gemini-short.txt')
prompts = [id_to_prompt[class_id] for class_id in sorted(id_to_prompt)]

# Fail fast if any split contains a label that has no description.
max_label = max(
    (label for split in (train_data, val_data, test_data) for _, label in split),
    default=-1,
)
if max_label >= len(prompts):
    raise ValueError(
        f"Found label {max_label} but only {len(prompts)} prompts were loaded"
    )

train_dataset = CustomImageDataset(train_data, images_root, prompts, transform=transform)
val_dataset = CustomImageDataset(val_data, images_root, prompts, transform=transform)
test_dataset = CustomImageDataset(test_data, images_root, prompts, transform=transform)

In [None]:
# Hugging Face hub id of the CLIP checkpoint being fine-tuned.
model_name = "openai/clip-vit-base-patch16"

# The model is moved onto the selected device right after loading; the
# processor supplies the matching tokenizer and image preprocessing.
model = CLIPModel.from_pretrained(model_name)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(model_name)

In [None]:

# Apply LoRA only to the image encoder (ViT). CLIP's text and vision towers
# both contain attention `q_proj` / `v_proj` layers, so a plain
# ["q_proj", "v_proj"] list would also adapt the text encoder.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # A string target is treated by PEFT as a regex that must fully match a
    # module name; this selects only attention q/v projections under
    # `vision_model`.
    target_modules=r".*vision_model.*\.(q_proj|v_proj)",
    lora_dropout=0.1,
    bias="none",
    # No task_type: the FEATURE_EXTRACTION wrapper's forward passes
    # `inputs_embeds`, which CLIPModel.forward does not accept. The generic
    # PeftModel forwards keyword arguments to CLIPModel unchanged.
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


In [None]:
def collate_fn(batch):
    """Turn a list of ``(image, prompt)`` samples into CLIP model inputs.

    Args:
        batch: Samples as returned by ``CustomImageDataset.__getitem__``.
            Images may be PIL images or tensors.

    Returns:
        A ``BatchEncoding`` with ``input_ids``, ``attention_mask`` and
        ``pixel_values`` tensors.
    """
    images, texts = zip(*batch)

    # Float tensors (e.g. from transforms.ToTensor) are already scaled to
    # [0, 1]; letting the image processor divide by 255 again would squash
    # the pixels to ~0 before normalization.
    already_scaled = all(
        isinstance(img, torch.Tensor) and img.is_floating_point() for img in images
    )

    # truncation=True: CLIP's text position embeddings have a fixed length,
    # so over-long descriptions must be cut to the tokenizer's max length.
    inputs = processor.tokenizer(
        list(texts), return_tensors="pt", padding=True, truncation=True
    )
    image_inputs = processor.image_processor(
        list(images), return_tensors="pt", do_rescale=not already_scaled
    )
    inputs["pixel_values"] = image_inputs["pixel_values"]
    return inputs


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

num_epochs = 1  # number of passes over train_loader

# Only parameters left trainable by get_peft_model need optimizer state.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
criterion = CrossEntropyLoss()


def contrastive_targets(input_ids):
    """Build soft contrastive targets from the batch's tokenized prompts.

    Samples of the same class share the same prompt, so an identity target
    (torch.arange) would push apart image/text pairs that actually match.
    Here every sample whose prompt is identical is a positive, and the
    probability mass of each row is split evenly across its positives.

    Args:
        input_ids: ``(B, L)`` token ids of the batch's prompts.

    Returns:
        A ``(B, B)`` float tensor whose rows sum to 1 (symmetric by
        construction, so it serves both image->text and text->image).
    """
    same_prompt = (input_ids[:, None, :] == input_ids[None, :, :]).all(dim=-1).float()
    return same_prompt / same_prompt.sum(dim=1, keepdim=True)


# from_pretrained leaves the model in eval mode; switch to train mode so
# LoRA dropout is active.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{num_epochs}")
    for step, batch in enumerate(progress, start=1):
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)

        targets = contrastive_targets(inputs["input_ids"])
        # Symmetric CLIP objective: image->text and text->image.
        loss = (
            criterion(outputs.logits_per_image, targets)
            + criterion(outputs.logits_per_text, targets)
        ) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix(loss=running_loss / step)
