# Setup

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/My\ Drive/CLIPProj

Mounted at /content/drive
/content/drive/My Drive/CLIPProj


In [2]:
# Install dependencies
# !pip freeze | grep tqdm  # check existing libraries
!pip install -r requirements.txt

Collecting lightning==2.5.1 (from -r requirements.txt (line 1))
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting python-dotenv==1.1.0 (from -r requirements.txt (line 2))
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting wandb==0.19.9 (from -r requirements.txt (line 3))
  Downloading wandb-0.19.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning==2.5.1->-r requirements.txt (line 1))
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning==2.5.1->-r requirements.txt (line 1))
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning==2.5.1->-r requirements.txt (line 1))
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting docker-pycreds>=0.4.0 (from wandb==0.19.9->-r requirements.txt (line 3))
  Dow

In [23]:
# Library Imports
import os
from dotenv import load_dotenv

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torch_xla.core.xla_model as xm
import torchvision
from torchvision import transforms

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from tqdm import tqdm

from transformers import CLIPProcessor, CLIPVisionModel

In [4]:
# General Setup
RAND_SEED = 42
load_dotenv()
L.seed_everything(RAND_SEED)

# Use Colab TPU
device = xm.xla_device()
device

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


device(type='xla', index=0)

# Data

## Downloading Data

In [12]:
class ColorFix(object):
    def __call__(self, img):
        if img.mode == 'L':
            img = img.convert("RGB")
        return img

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        ColorFix(),
        transforms.ToTensor(),
    ]
)

In [13]:
# Caltech 101 Data
data = torchvision.datasets.Caltech101('data', transform=transform, download=True)

Downloading...
From (original): https://drive.google.com/uc?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
From (redirected): https://drive.usercontent.google.com/download?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp&confirm=t&uuid=56fc2aa1-7238-48f6-ac95-d296c3a4cbb1
To: /content/drive/MyDrive/CLIPProj/data/caltech101/101_ObjectCategories.tar.gz
100%|██████████| 132M/132M [00:00<00:00, 168MB/s]
Downloading...
From: https://drive.google.com/uc?id=175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m
To: /content/drive/MyDrive/CLIPProj/data/caltech101/Annotations.tar
100%|██████████| 14.0M/14.0M [00:00<00:00, 151MB/s]


## CLIP Embeddings

In [14]:
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

In [16]:
dataloader = DataLoader(data, batch_size=64, shuffle=False, num_workers=4, drop_last=True)
embeddings = []
labels = []

for imgs, labs in tqdm(dataloader):
    inputs = clip_processor(images=imgs, return_tensors="pt", padding=True, do_rescale=False).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs).pooler_output
        # curr = outputs.cpu()
    embeddings.append(outputs)
    labels.append(labs)

embeddings = torch.cat(embeddings)
labels = torch.cat(labels)

100%|██████████| 135/135 [01:08<00:00,  1.98it/s]


In [19]:
# Custom Dataset to Hold Embeddings
class EmbeddingData(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return self.embeddings[i], self.labels[i]

In [20]:
data = EmbeddingData(embeddings, labels)

## Splitting Data

In [21]:
TRAIN_RATIO, VAL_RATIO = 0.7, 0.15

data_n = len(embeddings)
train_n = int(TRAIN_RATIO * data_n)
val_n = int(VAL_RATIO * data_n)
test_n = data_n - train_n - val_n

In [22]:
train_data, val_data, test_data = random_split(
    data, [train_n, val_n, test_n], generator=torch.Generator().manual_seed(RAND_SEED)
)

In [24]:
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Modeling

# Results

In [None]:
from lightning.pytorch.loggers import WandbLogger
import wandb