In [None]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from open_clip import tokenizer
import tqdm
import random

# OpenCLIP imports
import open_clip

# Where the COCO train2017 images and their caption annotations live on disk
image_dir = './coco/images/train2017/'
annotation_file = './coco/annotations/captions_train2017.json'

# Per-image preprocessing, applied in order:
#   1. resize to 224x224 (change this if the model expects another input size)
#   2. convert the image to a torch tensor
image_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(image_steps)

# Each dataset item is an (image, captions) pair; `transform` is applied to the image
cap = dset.CocoCaptions(root=image_dir, annFile=annotation_file, transform=transform)

# Quick sanity check on the dataset size
print('Number of samples:', len(cap)) # 118287 images

# Check that each of the first (up to) 5000 images has at least one caption.
# The annotation index is queried directly instead of calling `cap[i]`, which
# would decode and resize every image just to count its captions.
# `cap.ids[i]` is the COCO image id that `cap[i]` loads, so the result is the same.
# The range is bounded by len(cap) so smaller subsets do not raise IndexError.
for i in range(min(5000, len(cap))):
    if len(cap.coco.getAnnIds(imgIds=cap.ids[i])) == 0:
        print("No caption for image", i)


# Pull out a single example to inspect (index 3, i.e. the 4th sample)
img, target = cap[3]

# Report the tensor shape and the captions attached to this image
print("Image Size:", img.shape)  # Torch tensor size
print("Captions:", target)  # Captions for the image

# (1, 2, 0) moves the first axis (presumably channels) to the end before plotting
plt.imshow(img.permute(1, 2, 0))

# Create DataLoader
def collate_captions(batch):
    """Collate (image, captions) samples whose caption lists may differ in length.

    The DataLoader's default collate function requires every sample's caption
    list to have the same length and raises otherwise. This version truncates
    all caption lists to the shortest one in the batch and returns them in the
    same layout the default collate produces.

    Args:
        batch: list of (image_tensor, list_of_caption_strings) samples.

    Returns:
        (images, captions_by_rank) where `images` is the stacked image tensor
        and `captions_by_rank[k]` is a tuple holding the k-th caption of every
        sample, so `captions_by_rank[0]` is one caption per image.
        `captions_by_rank` is empty if any sample in the batch has no captions.
    """
    images = torch.stack([image for image, _ in batch])
    caption_lists = [captions for _, captions in batch]
    # Only ranks that exist for every sample can be transposed into full tuples
    shared = min(len(captions) for captions in caption_lists)
    captions_by_rank = [
        tuple(captions[k] for captions in caption_lists) for k in range(shared)
    ]
    return images, captions_by_rank


batch_size = 64  # example
loader = DataLoader(
    cap,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # incomplete final batch is skipped
    collate_fn=collate_captions,
)


In [None]:
model_name = "ViT-B-32"        # Example architecture

# Use the GPU when one is visible; to debug on CPU, leave `device` as "cpu"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the model and its transforms; pretrained=None means no pretrained
# weights are loaded, so training starts from scratch
model, preprocess, _ = open_clip.create_model_and_transforms(
    model_name, pretrained=None, device=device
)

# Training mode, with every parameter marked trainable so the whole model is updated
model.train()
model.requires_grad_(True)


def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    Symmetric cross-entropy loss over an image/text similarity matrix.

    Row i of `image_embeds` and row i of `text_embeds` are treated as the
    matching pair; every other row in the batch acts as a negative.

    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float; similarities are divided by it
    returns: scalar loss, the mean of the image->text and text->image terms
    """
    # Unit-length embeddings, so the matrix product below is cosine similarity
    img_unit = F.normalize(image_embeds, dim=-1)
    txt_unit = F.normalize(text_embeds, dim=-1)

    # (bs, bs) similarity scores, scaled by the temperature
    similarity = (img_unit @ txt_unit.t()) / temperature

    # The correct class for row i is column i (the diagonal)
    labels = torch.arange(img_unit.size(0), device=similarity.device)

    # Rows index images in `similarity` and texts in its transpose,
    # giving the image->text and text->image terms respectively
    image_to_text, text_to_image = (
        F.cross_entropy(scores, labels) for scores in (similarity, similarity.t())
    )

    return (image_to_text + text_to_image) / 2


# Example config
lr = 1e-4
epochs = 1
temperature = 0.07

optimizer = optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
    # Sum the loss over the epoch so the report is the epoch mean, not just
    # the last batch. The sum stays on-device to avoid a sync on every step.
    epoch_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for images, captions_list in tqdm.tqdm(loader):
        images = images.to(device)

        # For COCO, each item can have multiple captions.
        # captions_list[0] holds the first caption of every image in the batch.
        captions = list(captions_list[0])

        # Tokenize text (open_clip tokenizer produces tokenized batch)
        text_tokens = tokenizer.tokenize(captions).to(device)

        # Encode image and text
        image_embeds = model.encode_image(images)
        text_embeds = model.encode_text(text_tokens)

        # Compute the contrastive loss
        loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach()
        num_batches += 1

    # The loader can yield zero batches (e.g. drop_last=True with fewer samples
    # than one batch); `loss` would then be undefined, so skip the report.
    if num_batches == 0:
        print(f"[Epoch {epoch+1}/{epochs}]  No batches were produced by the loader")
        continue

    print(f"[Epoch {epoch+1}/{epochs}]  Loss: {epoch_loss_sum.item() / num_batches:.4f}")


In [None]:
# One-batch smoke test: check that batch sizes and embedding shapes line up.
for images, captions_list in loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list[k] holds the k-th caption of every image in the batch,
    # so captions_list[0] is one caption per image (length N)

    # Take the captions from THIS batch. Previously `captions` was never
    # assigned in this cell and leaked in from the training cell.
    captions = list(captions_list[0])

    print("Image batch size:", "COMPLETE:" ,images.shape[0], images.shape)
    print("Captions list length:", len(captions_list[0]))

    print("Captions list:", captions)

    print("Number of chosen captions:", len(captions))

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode; no gradients are needed for a shape check
    with torch.no_grad():
        image_embeds = model.encode_image(images.to(device))
        text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    print("Image embeds shape:", image_embeds.shape)
    print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch


# TODO: ensure that every sample's caption list has the same length, otherwise the DataLoader's default collate function will fail (the error message points at `collate(samples, collate_fn_map=collate_fn_map)`)
