In [1]:
from architectures.clip_detector.detector import CLIPDetector
import torch
from PIL import Image
from io import BytesIO
from dataset import RefCocoBatch, RefCocoConfig, RefCocoDataset
from pathlib import Path
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
cfg = RefCocoConfig({
    "path": "../refcocog"
})

train_dataset = RefCocoDataset(config = cfg, phase = "train")
test_dataset = RefCocoDataset(config = cfg, phase = "test")

In [3]:
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=RefCocoDataset.batchify)

In [4]:
import clip
import torch.optim as optim

device = "cuda"
model = CLIPDetector(device=device) 
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

In [5]:
for epoch in range(200):
    for batch in train_dataloader:
        batch_loss = []
        for i, image in tqdm(enumerate(batch.images)):

            image = image.unsqueeze(0).to(device)
            captions_embeddings = model.encode_captions(batch.sentences)
            bbox, img_embedding = model.encode_image_caption(image, batch.sentences[i])

            similarity = (img_embedding @ captions_embeddings.T)
            
            # Gradient accumulation to avoid SGD
            loss = model.contrastive_loss(similarity, 0, i, tau = 1) / batch_size
            loss.backward()

            batch_loss.append(loss)
            
        optimizer.step()
        optimizer.zero_grad()
        print(f"Batch loss: \t{torch.stack(batch_loss).mean()}\n")

16it [00:08,  1.84it/s]


Batch loss: 	0.44282224774360657



16it [00:06,  2.29it/s]


Batch loss: 	0.17085066437721252



16it [00:08,  1.79it/s]


Batch loss: 	0.19278323650360107



1it [00:02,  2.35s/it]


RuntimeError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 8.00 GiB total capacity; 6.85 GiB already allocated; 0 bytes free; 7.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF