In [1]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-pke97o_a
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-pke97o_a
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import os
import random
from google.colab import drive
import clip
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
%cd '/content/gdrive/My Drive/data'

/content/gdrive/My Drive/data


In [5]:
def set_random_seed(seed=28):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(0)

# Load CLIP model and preprocessing pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [6]:
class OneShotDataset(Dataset):
    def __init__(self, root_dir, transform=None, one_shot=True):
        """
        Args:
            root_dir (str): Directory containing image folders for each class.
            transform (callable, optional): A function/transform to apply to images.
            one_shot (bool): Whether to select only one training sample per class.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.one_shot = one_shot
        self.data = []
        self.classes = []  # List of class names
        self.class_to_idx = {}
        self._prepare_data()

    def _prepare_data(self):
        """Prepare the dataset by iterating through class directories."""
        self.classes = sorted(os.listdir(self.root_dir))  # Get all class folders
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        for cls_name in self.classes:
            class_path = os.path.join(self.root_dir, cls_name)
            images = sorted(os.listdir(class_path))  # All images in the class folder

            if self.one_shot:
                # Randomly select only one image per class
                images = [random.choice(images)]

            for img_name in images:
                self.data.append((os.path.join(class_path, img_name), self.class_to_idx[cls_name]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

In [7]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Random crop and resize
    transforms.RandomHorizontalFlip(),  # Horizontal flip
    #transforms.RandomRotation(10),
    #transforms.ColorJitter(brightness=0.1, contrast=0.1),# Random rotation
    preprocess  # Apply CLIP preprocessing
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Random crop and resize
    transforms.RandomHorizontalFlip(),  # Horizontal flip
    #transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1), # Random rotation
    preprocess  # Apply CLIP preprocessing
])

# Training dataset: one image per class
train_dataset = OneShotDataset(root_dir="/content/gdrive/My Drive/data/train", transform=train_transform, one_shot=True)

# Testing dataset: all images
test_dataset = OneShotDataset(root_dir="/content/gdrive/My Drive/data/test", transform=test_transform, one_shot=False)

In [8]:
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [9]:
class CLIPClassifier(nn.Module):
    def __init__(self, clip_model, num_classes):
        super(CLIPClassifier, self).__init__()
        self.clip_model = clip_model
        self.dropout = nn.Dropout(0.4)  # Drop 20% neurons
        self.fc = nn.Linear(clip_model.visual.output_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.clip_model.encode_image(x)

        features = features.float()
        features = self.dropout(features)  # Apply dropout
        return self.fc(features)

In [10]:
def train_clip_classifier(model, train_loader, test_loader, epochs=20):
    # Define optimizer with different learning rates for classifier head and CLIP backbone
    optimizer = torch.optim.Adam([
        {'params': model.fc.parameters(), 'lr': 1e-3},  # Classifier head
        {'params': model.clip_model.visual.parameters(), 'lr': 1e-5}  # CLIP backbone
    ], weight_decay=1e-4)  # Add weight decay for regularization

    # Define scheduler to reduce learning rate dynamically
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.7)

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct, total = 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track metrics
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {total_loss / len(train_loader):.4f}, Accuracy: {train_accuracy:.2f}%")

        # Update learning rate scheduler
        scheduler.step()

        # Evaluate on test set after each epoch
        evaluate_clip_classifier(model, test_loader)

def evaluate_clip_classifier(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")


In [11]:
# Initialize model
num_classes = len(train_dataset.classes)
model = CLIPClassifier(clip_model, num_classes).to(device)

# Unfreeze CLIP backbone for fine-tuning
for param in model.clip_model.visual.parameters():
    param.requires_grad = True  # Allow gradient updates for the backbone

# Train the model
train_clip_classifier(model, train_loader, test_loader, epochs=20)

Epoch [1/20] - Loss: 2.9983, Accuracy: 0.00%
Test Accuracy: 12.75%
Epoch [2/20] - Loss: 2.6522, Accuracy: 18.75%
Test Accuracy: 21.50%
Epoch [3/20] - Loss: 2.2469, Accuracy: 50.00%
Test Accuracy: 36.00%
Epoch [4/20] - Loss: 1.9493, Accuracy: 68.75%
Test Accuracy: 49.00%
Epoch [5/20] - Loss: 1.6321, Accuracy: 93.75%
Test Accuracy: 56.25%
Epoch [6/20] - Loss: 1.3895, Accuracy: 100.00%
Test Accuracy: 63.25%
Epoch [7/20] - Loss: 1.1174, Accuracy: 100.00%
Test Accuracy: 65.00%
Epoch [8/20] - Loss: 0.9654, Accuracy: 100.00%
Test Accuracy: 66.75%
Epoch [9/20] - Loss: 0.8616, Accuracy: 100.00%
Test Accuracy: 66.25%
Epoch [10/20] - Loss: 0.7080, Accuracy: 100.00%
Test Accuracy: 67.75%
Epoch [11/20] - Loss: 0.7051, Accuracy: 100.00%
Test Accuracy: 68.00%
Epoch [12/20] - Loss: 0.5766, Accuracy: 100.00%
Test Accuracy: 67.75%
Epoch [13/20] - Loss: 0.5033, Accuracy: 100.00%
Test Accuracy: 67.75%
Epoch [14/20] - Loss: 0.4867, Accuracy: 100.00%
Test Accuracy: 69.25%
Epoch [15/20] - Loss: 0.4041, Accur

In [12]:
torch.save(model.state_dict(), 'OneShotwithCLIP_model_v2.0.pth')