In [9]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from open_clip import tokenizer
from torch.utils.data import Subset
import tqdm
import random

# OpenCLIP imports
import open_clip

# Path to train images and annotations
train_image_dir = './coco/images/train2017/'  # Path to train2017 images
train_annotation_file = './coco/annotations/captions_train2017.json'  # Path to train2017 captions

# Path to test (val) images and annotations
test_image_dir = './coco/images/val2017/'  # Path to val2017 images
test_annotation_file = './coco/annotations/captions_val2017.json'  # Path to val2017 captions

# Define the transform to be applied to the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # or whatever size your model expects
    transforms.ToTensor()
])

# Create the training dataset
train_coco = dset.CocoCaptions(
    root=train_image_dir,
    annFile=train_annotation_file,
    transform=transform
)

# Create the test dataset
test_coco = dset.CocoCaptions(
    root=test_image_dir,
    annFile=test_annotation_file,
    transform=transform
)

num_samples = 1000
subset_indices = list(range(num_samples))
train_coco = Subset(train_coco, subset_indices)


# Print dataset details
print('Number of samples:', len(train_coco)) # 118287 images

# Access a specific sample (4th sample here)
img, target = train_coco[3]  # Load the 4th sample (index 3)

# Display information about the sample
print("Image Size:", img.size())  # Torch tensor size
#plt.imshow(img.permute(1, 2, 0))  # Display the image
print("Captions:", target)  # Captions for the image

# Every image has 5 captions at max, we need to sample one of them
# Create collate function to sample one caption per image
def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    sel_captions = []
    for list_captions in captions:
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    return images, sel_captions

# Create DataLoader
batch_size = 128
train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
test_loader = DataLoader(test_coco, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)





loading annotations into memory...
Done (t=0.62s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
Number of samples: 1000
Image Size: torch.Size([3, 224, 224])
Captions: ['A zebra grazing on lush green grass in a field.', 'Zebra reaching its head down to ground where grass is. ', 'The zebra is eating grass in the sun.', 'A lone zebra grazing in some green grass.', 'a Zebra grazing on grass in a green open field.']


In [6]:
%doctest_mode

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)



Exception reporting mode: Plain
Doctest mode is: ON


<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

Image batch size: 128 Shape: torch.Size([128, 3, 224, 224])
Captions list length: 128
Captions list: ['A man in an office chair looking at a laptop next to a glass of wine.', 'Hot dogs, buns, and croissants on grill with red wall', 'A boat motors past a man playing Frisbee on the beach ', 'An individual is capture in the stillness of the picture.\n', 'Children on beright sunny day playing soccer who appear to be about 5 years old. ', 'Man standing in an auditorium filled with chairs.', 'a living room with a big colorful rug on the floor ', 'A woman kneeling down on top of a baseball field.', 'a small kid looks up at a kite ', 'Gazelles, zebras and giraffes roaming around the plains.', 'Assorted fruit are laying on a cutting board.', 'There is a little toy hanging on the key chain. ', 'A man smiles wearing beads and a neck tie.', 'Fresh produce has been harvested to take to market.', 'Three people standing outsidde a mall in the fall.', 'a very tall giraffe walking in a bush with his ne

In [10]:
# %%prun to profile and see where the time is spent

model_name = "ViT-B-32"        # Example architecture
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"

# Create model & transforms from scratch (no pretrained weights)
model, preprocess, _ = open_clip.create_model_and_transforms(
    model_name,
    pretrained=None,
    device=device
)

# Put the model into training mode
model.train()

# If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
for param in model.parameters():
    param.requires_grad = True


def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    # Normalize embeddings (optional, but typical in CLIP-like models)
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds  = F.normalize(text_embeds, dim=-1)
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2




# Example config
lr = 1e-4
epochs = 1
temperature = 0.07

# Move the model to multiple GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

optimizer = optim.AdamW(model.parameters(), lr=lr)

current_batch = 0

for epoch in range(epochs):
    for images, captions_list in tqdm.tqdm(train_loader):
        
        current_batch += 1
        
        # Move data to the primary device
        images = images.to(device)
        captions = captions_list

        # Tokenize text
        text_tokens = tokenizer.tokenize(captions)
        text_tokens = text_tokens.to(device)

        # Encode image and text
        image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
        text_embeds = model.module.encode_text(text_tokens)

        # Compute the contrastive loss
        loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")


100%|██████████| 7/7 [00:08<00:00,  1.23s/it]


[Epoch 1/1]  Loss: 4.8527
