## Downloads

In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!gdown 1xijq32XfEm6FPhUb7RsZYWHc2UuwVkiq
!tar -xf /content/refcocog.tar.gz
!pip install -qr https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt

In [None]:
!gdown 1j-MGd-pbkppPiYYEYOyZ6vioTEoF-v_g
!mv /content/adapters.py /usr/local/lib/python3.10/dist-packages/clip/
!gdown 1C-h4h7pAkXR9-MhbBjXKLXvh9OwhXtOu
!mv /content/model.py /usr/local/lib/python3.10/dist-packages/clip/
!gdown 18kdAcm8P3GVgDp7GQfIaBg0L3ZeijK3v

In [3]:
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

from RefcocogDataset import RefcocogDataset
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loss function

In [4]:
def visualize_sample(sample, bbox, idx=0):
    print(f"Sentence: {sample['sentences'][idx]}")
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(sample['image'][idx].permute(1, 2, 0))
    axes[1].imshow(bbox['gt'][idx])
    plt.tight_layout()
    plt.show()


def visualize_loss(map, bbox, idx, loss_map):
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(map)
    axs[1].imshow(bbox['gt'][idx])
    axs[2].imshow(loss_map.reshape(14, 14))
    plt.tight_layout()
    plt.show()

class BatchLossFunction(nn.Module):
    def __init__(self, gamma=3.4, average=True):
        super(BatchLossFunction, self).__init__()
        self.gamma = gamma
        self.average = average

    def forward(self, patch_tokens, out_text, gt):
        loss = torch.zeros(1, requires_grad=True)
        for idx in range(patch_tokens.shape[0]):
            pt = patch_tokens[idx, 1:]
            ot = out_text[idx, :].unsqueeze(0)
            map = torch.zeros(196)

            for i, token in enumerate(pt):
                map[i] = 1 - torch.cosine_similarity(token, ot).item() # 1 - ... temporary fix

            vector = torch.sigmoid(map)

            gt_map = gt[idx]/255
            gt_vector = gt_map.reshape(-1)

            abs = torch.abs(vector - gt_vector)
            log = -torch.log(1-abs)

            # amplify the error of pixels that should belong to the object
            log = log*(gt_vector*self.gamma+1)
            loss = loss + torch.sum(log)

        if self.average:
            return (loss / patch_tokens.shape[0])
        else:
            return loss

## Training loop

In [5]:
from datetime import datetime
from tqdm import tqdm
import os

def train_one_epoch(epoch_index, train_loader, model, criterion, optimizer, loop):
    epoch_losses = []
    for i, (samples, bbox) in enumerate(train_loader):
        loop.set_postfix_str(f'Batch {i+1}/{len(train_loader)}')

        optimizer.zero_grad()

        images = samples['image'].to(device)
        sentences = clip.tokenize(samples['sentences']).to(device)

        out_image, out_text, patch_tokens, text_tokens = model.encode(images, sentences)

        batch_loss = criterion(patch_tokens, out_text, bbox['gt'])

        batch_loss.backward()
        optimizer.step()

        epoch_losses.append(batch_loss)

    return torch.mean(torch.tensor(epoch_losses)).item()

def train_loop(num_epochs, train_loader, model, criterion, optimizer, scheduler, eval_loader = None):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_path = 'runs/run_{}'.format(timestamp)
    cmd = f'mkdir runs; mkdir runs/run_{timestamp}'
    os.system(cmd)

    best_eval_loss = float('inf')

    loop = tqdm(range(num_epochs), desc="Training locator", leave=True)
    for epoch in loop:
        model.train()
        epoch_loss = train_one_epoch(epoch, train_loader, model, criterion, optimizer, loop)
        scheduler.step()

        if eval_loader is not None:
            model.eval()
            eval_losses = []
            with torch.no_grad():
                for samples, bbox in eval_loader:
                    images = samples['image'].to(device)
                    sentences = clip.tokenize(samples['sentences']).to(device)
                    out_image, out_text, patch_tokens, text_tokens = model.encode(images, sentences)
                    batch_loss = criterion(patch_tokens, out_text, bbox['gt'])
                    eval_losses.append(batch_loss)

                eval_loss = torch.mean(torch.tensor(eval_losses)).item()
                torch.save(model.state_dict(), run_path + "/last.pt")

                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    torch.save(model.state_dict(), run_path + "/best.pt")


## Main

In [6]:
model, preprocess = clip.load("ViT-B/16") # only works with ViT-B/16
model.init_adapters() # needed because state dict of clip does not contain adapters, goes before moving to gpu
# model.load_parameters() # for when we have state dict of adapters trained, goes after adapters init
model = model.to(device)

model.freeze_for_training() # freezes all clip by putting requires_grad=False and then unfreezes adapters

100%|███████████████████████████████████████| 335M/335M [00:05<00:00, 65.5MiB/s]


In [7]:
batch_size = 32 # 32 should be possible
train_dataset = RefcocogDataset("./refcocog", split="train", transform=preprocess)
val_dataset = RefcocogDataset("./refcocog", split="val", transform=preprocess)
test_dataset = RefcocogDataset("./refcocog", split="test", transform=preprocess)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) # switch to shuffle True
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
learning_rate = 5e-5
weight_decay = 5e-3
num_epochs = 60 # 60

criterion = BatchLossFunction(gamma=3.4, average=True) # keep 3.4 for now
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=num_epochs)


train_loop(num_epochs, val_loader, model, criterion, optimizer, scheduler, val_loader) # switch first val_loader to train_loader

Training locator:   3%|▎         | 2/60 [19:13<9:17:36, 576.83s/it, Batch 141/153]


KeyboardInterrupt: ignored