In [1]:
!pip show torchvision

Name: torchvision
Version: 0.22.0
Summary: image and video datasets and models for torch deep learning
Home-page: https://github.com/pytorch/vision
Author: PyTorch Core Team
Author-email: soumith@pytorch.org
License: BSD
Location: /home/bboulbarss/.local/lib/python3.11/site-packages
Requires: numpy, pillow, torch
Required-by: 


In [2]:
!pip show peft

Name: peft
Version: 0.15.2
Summary: Parameter-Efficient Fine-Tuning (PEFT)
Home-page: https://github.com/huggingface/peft
Author: The HuggingFace team
Author-email: benjamin@huggingface.co
License: Apache
Location: /home/bboulbarss/.local/lib/python3.11/site-packages
Requires: accelerate, huggingface_hub, numpy, packaging, psutil, pyyaml, safetensors, torch, tqdm, transformers
Required-by: 


In [3]:
import os
import random
from PIL import Image
import tqdm
import numpy as np
import pandas as pd
import csv
import time
from datetime import datetime
import pytz

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

from peft import LoraConfig, get_peft_model
from transformers import CLIPProcessor, CLIPModel

In [4]:
import warnings

from transformers.utils import logging

# Environment flag read by downstream libraries; presumably switches off the
# Weights & Biases integration -- confirm against the tooling in use.
os.environ["WANDB_DISABLED"] = "true"

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Restrict transformers' own logger to error-level messages.
logging.set_verbosity_error()

In [5]:
def create_distractors_single_object(true_label):
    """Pick four random wrong captions for a single-object image.

    Args:
        true_label: The correct caption, of the form
            "A photo of a <color> <shape>".

    Returns:
        A list of four distinct captions drawn from every color/shape
        combination except ``true_label``.

    Raises:
        ValueError: If ``true_label`` is not one of the known color/shape
            captions.
    """
    shape_names = ('cube', 'sphere', 'cone', 'cylinder')
    color_names = ('blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow')

    # Enumerate shape-major / color-minor so the candidate order is stable.
    candidates = []
    for shape_name in shape_names:
        for color_name in color_names:
            candidates.append(f"A photo of a {color_name} {shape_name}")

    # Drop the correct caption; list.remove raises ValueError when it is unknown.
    candidates.remove(true_label)

    return random.sample(candidates, k=4)

In [6]:
def create_distractors_two_object(true_labels):
    """Build four distractor captions for an image containing two objects.

    Two "hard" distractors are formed by swapping attributes between the two
    objects (color of one with shape of the other). The remaining slots are
    filled with random color/shape captions.

    When the two objects share a color or a shape, an attribute swap
    reproduces one of the true captions. Such candidates are dropped (and
    replaced by additional random captions) so that a caption describing an
    object that is actually in the image is never returned as a distractor.

    Args:
        true_labels: Sequence (tuple or list) whose first two items are the
            captions of the two objects, each of the form
            "A photo of a <color> <shape>".

    Returns:
        A list of four distinct captions, hard distractors first, none of
        which appears in ``true_labels``.

    Raises:
        ValueError: If one of the first two captions does not split into
            exactly six whitespace-separated tokens.
    """
    shapes = ['cube', 'sphere', 'cone', 'cylinder']
    colors = ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']

    _, _, _, _, color1, shape1 = true_labels[0].split()
    _, _, _, _, color2, shape2 = true_labels[1].split()

    # set() accepts any iterable, so lists work as well as tuples.
    exclude = set(true_labels)

    # Keep an attribute-swapped caption only if it is genuinely wrong and new.
    hard_distractors = []
    for candidate in (f"A photo of a {color1} {shape2}", f"A photo of a {color2} {shape1}"):
        if candidate not in exclude and candidate not in hard_distractors:
            hard_distractors.append(candidate)

    exclude.update(hard_distractors)

    # Shape-major / color-minor enumeration, minus everything already used.
    all_labels = [f"A photo of a {color} {shape}" for shape in shapes for color in colors]
    remaining = [label for label in all_labels if label not in exclude]

    # Top up with random captions so exactly four distractors are returned.
    random_labels = random.sample(remaining, k=4 - len(hard_distractors))

    return hard_distractors + random_labels

In [7]:
def create_distractors_relational(true_label):
    """Build four distractor captions for a left/right relational caption.

    Two "hard" distractors are always included:
      1. shape-swapped    -- same relation, the two shapes exchanged;
      2. relation-swapped -- same shape order, opposite relation.
    Two further distractors are sampled at random from all remaining
    "<shape> <relation> of a <other shape>" captions.

    The inverse caption (shapes exchanged AND relation flipped) describes the
    same scene as ``true_label``, so it is excluded from the random pool.

    Args:
        true_label: Caption of the form
            "A photo of a <shape1> <left|right> of a <shape2>".
            NOTE(review): the two shapes are assumed to differ; with identical
            shapes the shape-swapped distractor would equal ``true_label``.

    Returns:
        A list of four distinct captions: the two hard distractors followed by
        two random ones.

    Raises:
        ValueError: If ``true_label`` does not split into exactly nine tokens.
        KeyError: If the relation token is neither 'left' nor 'right'.
    """
    shapes = ['cube', 'sphere', 'cone', 'cylinder']
    relations = {'right': 'left', 'left': 'right'}

    # e.g. 'A', 'photo', 'of', 'a', 'sphere', 'right', 'of', 'a', 'cube'
    _, _, _, _, true_shape1, true_relation, _, _, true_shape2 = true_label.split()
    opposite_relation = relations[true_relation]

    # Hard distractors: each changes exactly one aspect of the true caption.
    shape_swapped = f"A photo of a {true_shape2} {true_relation} of a {true_shape1}"
    relation_swapped = f"A photo of a {true_shape1} {opposite_relation} of a {true_shape2}"
    hard_distractors = [shape_swapped, relation_swapped]

    # Equivalent restatement of the true caption; it is also correct and must
    # never be offered as a distractor. (No leading space, so it matches the
    # generated captions exactly.)
    inverse_label = f"A photo of a {true_shape2} {opposite_relation} of a {true_shape1}"

    # Every caption relating two different shapes.
    all_labels = [f"A photo of a {shape} {rel} of a {other_shape}"
                  for shape in shapes
                  for rel in relations
                  for other_shape in shapes if other_shape != shape]

    # Remove both correct captions and the hard distractors already chosen.
    exclude = {true_label, inverse_label, shape_swapped, relation_swapped}
    candidates = [label for label in all_labels if label not in exclude]

    random_labels = random.sample(candidates, k=4 - len(hard_distractors))

    return hard_distractors + random_labels

In [8]:
class CustomDataset(ImageFolder):
    """ImageFolder variant that yields an image plus a multiple-choice caption task.

    Each item is ``(image, labels_list, correct_index)``: the loaded image, a
    shuffled list holding the correct caption and four distractors, and the
    position of the correct caption inside that list.

    Class-folder names are split on underscores to build the captions:
      * 'single_object': "<color>_<shape>"
      * 'two_object':    "<color>_<shape>_<color>_<shape>" (first pair is the
                         target object, second pair the other object)
      * 'relational':    "<shape>_<relation>_<shape>"
    """

    SUPPORTED_DATASETS = ('single_object', 'two_object', 'relational')

    def __init__(self, root, dataset_name, transform=None):
        """Index ``root`` and remember which caption scheme to use.

        Args:
            root: Directory containing one sub-folder per class.
            dataset_name: One of ``SUPPORTED_DATASETS``.
            transform: Optional callable applied to each loaded image.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        # Fail fast instead of surfacing an UnboundLocalError on first access.
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"Unknown dataset_name {dataset_name!r}; expected one of {self.SUPPORTED_DATASETS}"
            )
        super().__init__(root, transform=transform)
        self.dataset_name = dataset_name

    def find_classes(self, directory):
        """Return the sorted class folders of ``directory``, ignoring hidden ones.

        Sub-directories whose name starts with '.' are skipped.

        Returns:
            ``(classes, class_to_idx)`` where ``classes`` is the sorted list of
            folder names and ``class_to_idx`` maps each name to its index.

        Raises:
            FileNotFoundError: If no non-hidden sub-directory exists.
        """
        classes = [d.name for d in os.scandir(directory) if d.is_dir() and not d.name.startswith('.')]
        classes.sort()
        if not classes:
            raise FileNotFoundError(f"Couldn't find any valid class folders in {directory}")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """Load one sample and build its shuffled caption choices.

        Returns:
            ``(image, labels_list, correct_index)``; ``labels_list`` holds five
            captions and ``labels_list[correct_index]`` is the correct one.

        Raises:
            ValueError: If ``self.dataset_name`` is not supported.
        """
        path, target = self.samples[index]
        image = self.loader(path)  # a PIL image with torchvision's default loader
        # Honour the transform accepted by the constructor (no-op when None).
        if self.transform is not None:
            image = self.transform(image)

        parts = self.classes[target].split('_')

        if self.dataset_name == 'single_object':
            correct_label = f'A photo of a {parts[0]} {parts[1]}'
            labels_list = [correct_label] + create_distractors_single_object(correct_label)
        elif self.dataset_name == 'two_object':
            correct_label = f"A photo of a {parts[0]} {parts[1]}"
            filler_label = f"A photo of a {parts[2]} {parts[3]}"
            labels_list = [correct_label] + create_distractors_two_object((correct_label, filler_label))
        elif self.dataset_name == 'relational':
            correct_label = f"A photo of a {parts[0]} {parts[1]} of a {parts[2]}"
            labels_list = [correct_label] + create_distractors_relational(correct_label)
        else:
            raise ValueError(
                f"Unknown dataset_name {self.dataset_name!r}; expected one of {self.SUPPORTED_DATASETS}"
            )

        # Shuffle so the correct caption's position carries no information.
        random.shuffle(labels_list)
        correct_index = labels_list.index(correct_label)

        return image, labels_list, correct_index

In [9]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor
from peft import LoraConfig, get_peft_model
from datetime import datetime
import pytz
import matplotlib.pyplot as plt
from itertools import product
import uuid

def _replace_spaces_in_dir_names(root):
    """Rename every entry directly under ``root`` whose name contains a space,
    replacing each space with an underscore."""
    for dir_name in os.listdir(root):
        if ' ' in dir_name:
            os.rename(
                os.path.join(root, dir_name),
                os.path.join(root, dir_name.replace(' ', '_'))
            )


def _candidate_logits(model, processor, batch, device):
    """Score every image of ``batch`` against its own candidate captions.

    ``batch`` is a list of ``(image, captions, correct_index)`` triples (the
    un-collated DataLoader output). All captions are encoded in one call, so
    the model returns an image-by-all-captions logit matrix; the columns
    belonging to each image's own captions are then gathered.

    Returns:
        ``(relevant_logits, targets)`` where ``relevant_logits`` has one row
        per image and one column per candidate caption, and ``targets`` holds
        the index of the correct caption for each image.
    """
    images, texts_lists, correct_indices = zip(*batch)
    texts = [text for texts_list in texts_lists for text in texts_list]

    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)

    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image

    B = len(images)
    # Captions were flattened image by image, so image i owns columns
    # [i * num_candidates, (i + 1) * num_candidates).
    num_candidates = len(texts_lists[0])
    col_indices = (
        (torch.arange(B, device=device) * num_candidates).unsqueeze(1)
        + torch.arange(num_candidates, device=device)
    )
    relevant_logits = logits_per_image.gather(1, col_indices)
    targets = torch.tensor(correct_indices, device=device)
    return relevant_logits, targets


def train_and_evaluate(dataset_name, base_path='/home/bboulbarss/large_dataset', seed=42,
                       model_save_dir='/home/bboulbarss/finetuned_models/clip'):
    """Grid-search LoRA fine-tuning of CLIP on one dataset and keep the best adapter.

    For every (LoRA rank/alpha, learning rate) combination a fresh
    ``openai/clip-vit-base-patch32`` model is wrapped with LoRA and trained for
    up to 15 epochs, with early stopping after 3 epochs without improvement on
    the OOD validation split. Whenever an epoch beats the best validation
    accuracy seen so far across the whole grid, the adapter and processor are
    saved under ``model_save_dir``.

    Side effects:
        * Renames entries under ``<base_path>/<dataset_name>/train`` and
          ``.../ood_val`` that contain spaces (spaces -> underscores).
        * Writes adapter/processor checkpoints and a training-loss PNG to
          ``model_save_dir`` (created if missing).
        * Prints progress and a final summary.

    Args:
        dataset_name: Dataset sub-folder name; also selects the caption scheme
            used by ``CustomDataset``.
        base_path: Directory containing the dataset folders.
        seed: Seed for ``random``, NumPy and PyTorch, set once at the start.
        model_save_dir: Output directory for checkpoints and the loss plot.

    Returns:
        None.
    """
    # Set seeds for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Define paths for the train and OOD-validation splits
    train_root = os.path.join(base_path, dataset_name, 'train')
    val_ood_root = os.path.join(base_path, dataset_name, 'ood_val')

    # CustomDataset splits class-folder names on '_', so spaces must go first.
    _replace_spaces_in_dir_names(train_root)
    _replace_spaces_in_dir_names(val_ood_root)

    # Create datasets
    train_dataset = CustomDataset(root=train_root, dataset_name=dataset_name)
    val_ood_dataset = CustomDataset(root=val_ood_root, dataset_name=dataset_name)

    # Define hyperparameter grid
    batch_size = 32
    lora_rs = [(8, 16), (16, 32)]  # (r, lora_alpha) pairs
    learning_rates = [1e-6, 1e-5]

    # Initialize variables to track the best model
    best_accuracy = 0.0
    best_model_path = None
    best_processor_path = None
    best_hyperparams = None
    best_train_losses = []
    timestamp = datetime.now(pytz.timezone('Europe/Amsterdam')).strftime('%Y%m%d_%H%M%S')

    # Create the output directory up front: the plot below is written even if
    # no checkpoint was ever saved, so the path must always be defined.
    os.makedirs(model_save_dir, exist_ok=True)
    loss_curve_path = os.path.join(
        model_save_dir, f'training_loss_curve_{dataset_name}_{seed}_{timestamp}.png'
    )

    # The loss module is stateless; build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Grid search over hyperparameters
    for lora_r, lr in product(lora_rs, learning_rates):
        print(f"\nTesting hyperparameters: lora_r={lora_r}, lr={lr}")

        # Create data loaders; the identity collate_fn keeps each batch as a
        # list of (image, captions, correct_index) triples.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lambda x: x,
            num_workers=0,
            pin_memory=True
        )
        val_ood_loader = DataLoader(
            val_ood_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: x,
            num_workers=0,
            pin_memory=True
        )

        # Load a fresh base CLIP model and processor for this configuration
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # Apply LoRA
        lora_config = LoraConfig(
            r=lora_r[0],
            lora_alpha=lora_r[1],
            lora_dropout=0.5,
            bias="none",
            target_modules=[
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.out_proj",
                "mlp.fc1",
                "mlp.fc2",
                "visual_projection",
                "text_projection"
            ]
        )
        peft_model = get_peft_model(model, lora_config)
        # Train only parameters whose name contains "lora".
        for name, param in peft_model.named_parameters():
            if "lora" not in name.lower():
                param.requires_grad = False
        model = peft_model
        model.to(device)

        # Set up optimizer over the trainable parameters only
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr,
            weight_decay=0.1
        )

        num_epochs = 15
        patience = 3
        epochs_no_improve = 0
        current_best_accuracy = 0.0  # best accuracy of THIS configuration
        train_losses = []

        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            for i, batch in enumerate(train_loader):
                relevant_logits, targets = _candidate_logits(model, processor, batch, device)

                loss = criterion(relevant_logits, targets)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % 10 == 0:
                    print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(train_loader)
            train_losses.append(avg_loss)
            print(f"Epoch {epoch+1} completed, Average Training Loss: {avg_loss:.4f}")

            # Validation phase
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                print("Start OOD Validation Phase")
                for batch in val_ood_loader:
                    relevant_logits, targets = _candidate_logits(model, processor, batch, device)

                    preds = relevant_logits.argmax(dim=1)
                    correct += (preds == targets).sum().item()
                    total += targets.numel()

                accuracy_ood = correct / total
                print(f"Validation (OOD) Accuracy for {dataset_name}: {accuracy_ood:.4f}")

                # Early-stopping bookkeeping for the current configuration
                if accuracy_ood > current_best_accuracy:
                    current_best_accuracy = accuracy_ood
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

                # Update global best model if current model is better
                if accuracy_ood > best_accuracy:
                    best_accuracy = accuracy_ood
                    best_hyperparams = {'batch_size': batch_size, 'lora_r': lora_r, 'lr': lr}
                    best_train_losses = train_losses.copy()
                    best_model_path = os.path.join(
                        model_save_dir,
                        f'clip_lora_best_{dataset_name}_{seed}_{timestamp}_{lora_r}_{batch_size}_{lr}'
                    )
                    best_processor_path = os.path.join(
                        model_save_dir,
                        f'clip_processor_best_{dataset_name}_{seed}_{timestamp}_{lora_r}_{batch_size}_{lr}'
                    )
                    model.save_pretrained(best_model_path)
                    processor.save_pretrained(best_processor_path)

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Plot training loss curve for the best model (losses up to its best epoch)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(best_train_losses) + 1), best_train_losses, marker='o', linestyle='-', color='b')
    ax.set_title(f'Training Loss Curve for Best Model\nDataset: {dataset_name}, '
                 f'Hyperparameters: {best_hyperparams}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Training Loss')
    ax.grid(True)
    fig.savefig(loss_curve_path)
    plt.close(fig)

    print(f"\nBest Model Hyperparameters: {best_hyperparams}")
    print(f"Best Validation (OOD) Accuracy: {best_accuracy:.4f}")
    print(f"Best LoRA adapter saved to: {best_model_path}")
    print(f"Processor saved to: {best_processor_path}")
    print(f"Training loss curve saved to: {loss_curve_path}")

In [10]:
# Dataset variants to fine-tune on.
# Available options: 'single_object', 'relational', 'two_object'
datasets = ['single_object']

for dataset in datasets:
    print(f"\nTraining on {dataset}")
    train_and_evaluate(dataset_name=dataset, seed=42)


Training on single_object

Testing hyperparameters: lora_r=(8, 16), lr=1e-06
Epoch 1, Batch 0, Loss: 0.6441
Epoch 1, Batch 10, Loss: 0.5829
Epoch 1, Batch 20, Loss: 0.6299
Epoch 1, Batch 30, Loss: 0.5867
Epoch 1, Batch 40, Loss: 0.6083
Epoch 1 completed, Average Training Loss: 0.6178
Start OOD Validation Phase
Validation (OOD) Accuracy for single_object: 0.9475
Epoch 2, Batch 0, Loss: 0.5915
Epoch 2, Batch 10, Loss: 0.5898
Epoch 2, Batch 20, Loss: 0.6178
Epoch 2, Batch 30, Loss: 0.5938
Epoch 2, Batch 40, Loss: 0.5909
Epoch 2 completed, Average Training Loss: 0.5988
Start OOD Validation Phase
Validation (OOD) Accuracy for single_object: 0.9500
Epoch 3, Batch 0, Loss: 0.6021
Epoch 3, Batch 10, Loss: 0.6383
Epoch 3, Batch 20, Loss: 0.6393
Epoch 3, Batch 30, Loss: 0.5927
Epoch 3, Batch 40, Loss: 0.5849
Epoch 3 completed, Average Training Loss: 0.5814
Start OOD Validation Phase
Validation (OOD) Accuracy for single_object: 0.9500
Epoch 4, Batch 0, Loss: 0.6933
Epoch 4, Batch 10, Loss: 0.596