In [28]:
from main import get_args, eval
from data.transformations import make_train_transforms, val_transforms
import clip
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch

In [29]:
from data.winoground import WinogroundDataset

In [30]:
train = WinogroundDataset(transform=make_train_transforms(None, None), split="train", ratio=0.5)
val = WinogroundDataset(transform=val_transforms, split="test", ratio=0.5)
test = WinogroundDataset(transform=val_transforms, split="test", ratio=1.0)

Found cached dataset winoground (/home/samuelyu/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/72585f4d9cd5a28790bb9bc2adbdd45633f36dfbf85df529e0756e114e134285)
100%|██████████| 1/1 [00:00<00:00, 1321.04it/s]
Found cached dataset winoground (/home/samuelyu/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/72585f4d9cd5a28790bb9bc2adbdd45633f36dfbf85df529e0756e114e134285)
100%|██████████| 1/1 [00:00<00:00, 1346.05it/s]
Found cached dataset winoground (/home/samuelyu/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/72585f4d9cd5a28790bb9bc2adbdd45633f36dfbf85df529e0756e114e134285)
100%|██████████| 1/1 [00:00<00:00, 1321.04it/s]


In [32]:
train_dataloader = DataLoader(train, batch_size=2, shuffle=False)
val_dataloader = DataLoader(val, batch_size=2, shuffle=False)
test_dataloader = DataLoader(test, batch_size=2, shuffle=False)

In [57]:
model, preprocess = clip.load("ViT-B/32", device="cuda")
model = model.float()

for name, param in model.named_parameters():
        # if "11" in name or "ln_post" in name or "ln_final" in name:
        if "position" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

In [58]:
loss_image = nn.CrossEntropyLoss()
loss_text = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-8)
epoch = 0

In [62]:
eval(val_dataloader, model, loss_image, loss_text, None)

(181.6078570218524, 48, 130, 400)

In [63]:
eval(test_dataloader, model, loss_image, loss_text, None)

(376.8845534691354, 84, 244, 800)

In [53]:
train_loss = 0
train_correct_image = 0
train_correct_text = 0
for i, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    image = batch['image'].cuda()
    text = batch['text'].cuda().squeeze(1)
    num_image = len(batch['image'])

    image_features, text_features = model(image, text)
    logits_per_image = 100.0 * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    loss_i = loss_image(logits_per_image, torch.arange(len(batch['image'])).cuda())
    loss_t = loss_text(logits_per_text, torch.arange(len(batch['image'])).cuda())

    total_loss = (loss_i + loss_t) / 2
    total_loss.backward()
    train_loss += total_loss.item()

    train_correct_text += ((logits_per_image.argmax(dim=1) == torch.arange(len(batch['image'])).cuda()).sum().item() == len(batch['image']))*len(batch['image'])
    train_correct_image += ((logits_per_text.argmax(dim=1) == torch.arange(len(batch['image'])).cuda()).sum().item() == len(batch['image']))*len(batch['image'])

    optimizer.step()

In [61]:
train_loss/400, train_correct_image/400, train_correct_text/400

(0.4999073280068114, 0.52, 0.605)