In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from data import get_datasets, get_dataloaders
from random_texts import CLIPZeroShotClassifier

In [3]:
datasets, classnames = get_datasets(fraction=1e-3)
for name, dataset in datasets.items():
    print(len(dataset))

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

120
15
21
51
52
20


In [4]:
baseline_model = CLIPZeroShotClassifier(classnames)
dataloaders = get_dataloaders(datasets, baseline_model.preprocess, batch_size=128)

Zero-shot classifier: 100%|██████████| 345/345 [00:29<00:00, 11.62it/s]


In [6]:
from torch import nn
import torch
from tqdm import tqdm
from pprint import pprint


def evaluate(model: nn.Module) -> dict[str, float]:
    model.eval()
    results = {}
    with torch.inference_mode():
        for name, dataloader in dataloaders.items():
            correct = 0
            total = 0
            for batch in tqdm(dataloader, desc=f"Evaluating {name}"):
                images = batch["image"]
                labels = batch["label"]
                logits = model(images)
                correct += (logits.argmax(dim=-1) == labels).float().sum()
                total += len(labels)
            results[name] = correct / total
    return results

## Zero-shot model

In [6]:
baseline_results = evaluate(baseline_model)
pprint(baseline_results)

Evaluating ID: 100%|██████████| 1/1 [00:07<00:00,  7.69s/it]
Evaluating OOD_infograph: 100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
Evaluating OOD_painting: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
Evaluating OOD_quickdraw: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]
Evaluating OOD_real: 100%|██████████| 1/1 [00:03<00:00,  3.06s/it]
Evaluating OOD_clipart: 100%|██████████| 1/1 [00:01<00:00,  1.45s/it]

{'ID': tensor(0.8417),
 'OOD_clipart': tensor(0.6500),
 'OOD_infograph': tensor(0.4000),
 'OOD_painting': tensor(0.7619),
 'OOD_quickdraw': tensor(0.1373),
 'OOD_real': tensor(0.8269)}





## Full fine-tuning

In [7]:
ft_model = CLIPZeroShotClassifier(classnames)

100%|██████████| 345/345 [00:25<00:00, 13.67it/s]


In [8]:
optimizer = torch.optim.AdamW(ft_model.parameters(), lr=3e-5, weight_decay=0.1)
total_steps = len(dataloaders["ID"])
warmup_steps = min(500, total_steps // 2)
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    total_iters=warmup_steps,
)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)
pbar = tqdm(dataloaders["ID"], desc="Fine-tuning")
for batch in pbar:
    images = batch["image"]
    labels = batch["label"]
    logits = ft_model(images)
    loss = nn.functional.cross_entropy(logits, labels)
    pbar.set_postfix(loss=f"{loss.item():.4f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Fine-tuning: 100%|██████████| 1/1 [00:25<00:00, 25.70s/it, loss=0.6562]


In [None]:
with open("ft_model.pth", "wb") as f:
    torch.save(ft_model.state_dict(), f)

In [None]:
ft_results = evaluate(ft_model)
pprint(ft_results)
del ft_model

Evaluating ID: 100%|██████████| 1/1 [00:06<00:00,  6.12s/it]
Evaluating OOD_infograph: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
Evaluating OOD_painting: 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]
Evaluating OOD_quickdraw: 100%|██████████| 1/1 [00:02<00:00,  2.57s/it]
Evaluating OOD_real: 100%|██████████| 1/1 [00:02<00:00,  2.63s/it]
Evaluating OOD_clipart: 100%|██████████| 1/1 [00:01<00:00,  1.05s/it]

{'ID': tensor(0.9083),
 'OOD_clipart': tensor(0.6500),
 'OOD_infograph': tensor(0.4000),
 'OOD_painting': tensor(0.7619),
 'OOD_quickdraw': tensor(0.1373),
 'OOD_real': tensor(0.8269)}





## Lipsum-FT

In [7]:
lipsum_model = CLIPZeroShotClassifier(classnames)

Zero-shot classifier: 100%|██████████| 345/345 [00:23<00:00, 14.88it/s]


In [8]:
from clip.clip import _tokenizer
import numpy as np


def sample_random_tokens(n: int, L: int = 8):
    V = len(_tokenizer.encoder)
    return [
        "".join(_tokenizer.decode(np.random.randint(0, V, size=L))) for _ in range(n)
    ]


print(*sample_random_tokens(16), sep="\n")

🙏🏼 lgbti wr ➤ shedlight care surro
kamal ctive tob ii mags proj acea emed 
buzz cur bredyap straight nigel peregravail 
leash memory mach🇷🇺 happens domestdrome daz
pkwy axelkon scary fault dana minus ía 
pion issaseaworld hipsters accommodate processors grumpy lizards 
gzchand adic wec caterpillar charlesdren; 
reminding poppyyash snazzy maz pressing rahuundocumented 
speed sevxalbums umni enthreventoss 
photobomb millennium rift ruewifey icelandic iciff 
machinefiftycoke�satishomerdebates howe 
alliancetins edged nation pilooklahoma leur partner
ceiling dempsey standalone jdm sightings anchored caliber dder 
▪pent ⭐⭐earts wahgenerates tered saver 
frampton bookofbettersickest routines mast⁦⁦@ atoday 
berto saffron ffler sorrows miro 🤩bingashok 


In [10]:
optimizer = torch.optim.AdamW(lipsum_model.parameters(), lr=3e-5, weight_decay=0.1)
total_steps = len(dataloaders["ID"])
warmup_steps = min(500, total_steps // 2)
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    total_iters=warmup_steps,
)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)
lambda_lipsum = 0.1
pbar = tqdm(dataloaders["ID"], desc="Fine-tuning")
for batch in pbar:
    images = batch["image"]
    labels = batch["label"]
    texts = sample_random_tokens(len(images))
    logits = lipsum_model(images)
    cur_energy = lipsum_model.get_energy(images, texts)
    with torch.no_grad():
        old_energy = baseline_model.get_energy(images, texts)
    ce_loss = nn.functional.cross_entropy(logits, labels)
    gap_loss = nn.functional.mse_loss(cur_energy, old_energy)
    loss = ce_loss + lambda_lipsum * gap_loss
    pbar.set_postfix(ce_loss=f"{ce_loss.item():.4f}", gap_loss=f"{gap_loss.item():.4f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Fine-tuning: 100%|██████████| 1/1 [01:35<00:00, 95.54s/it, ce_loss=0.6431, gap_loss=0.0000]


In [11]:
with open("lipsum_model.pth", "wb") as f:
    torch.save(lipsum_model.state_dict(), f)