In [52]:
# Set up automatic reloading
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import transformers
from transformers import CLIPModel
import datasets
from datasets import load_dataset
from dataset import FlickrDataset
from decoder import Decoder

from utils import DEVICE


In [54]:
# dataset = load_dataset("nlphuji/flickr30k", cache_dir="./data")
# train_data = dataset.filter(lambda x: x["split"] == "train")
# test_data = dataset.filter(lambda x: x["split"] == "test")
# valid_data = dataset.filter(lambda x: x["split"] == "val")
# relevant_columns = ["image", "caption"]
# train_dataset = train_data.select_columns(relevant_columns)
# test_dataset = test_data.select_columns(relevant_columns)
# valid_dataset = valid_data.select_columns(relevant_columns)



In [60]:
# Load the dataset & dataloader
dataset = FlickrDataset()
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Load the model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Freeze the CLIP model
for param in clip_model.vision_model.parameters():
    param.requires_grad = False
for param in clip_model.text_model.parameters():
    param.requires_grad = False

# Initialize the decoder
decoder = Decoder(num_layers=1, embedding_dim=512, num_heads=4, ff_dim=1024).to(DEVICE)

# Optimizer
optimizer = optim.Adam(decoder.parameters(), lr=0.001)


In [56]:
dataset[0][1].shape

torch.Size([77])

In [61]:
num_epochs = 1

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
    # batch = next(iter(dataloader))
        images, captions = batch
        images = images.to(DEVICE)
        captions = captions.to(DEVICE)

        clip_model.eval()
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values=images)
            text_features = clip_model.text_model(
                input_ids=captions,
                output_hidden_states=True
            )
            last_hidden_state = text_features.last_hidden_state  # shape: (batch size, sequence length, hidden size)
            # print(last_hidden_state.shape, image_features.shape)

        # Zero the gradients
        optimizer.zero_grad()

        # Concatenate the image features and the last hidden state
        decoder_input = torch.cat((image_features.unsqueeze(1), last_hidden_state), dim=1)

        # Pass the concatenated features to the decoder
        logits = decoder(decoder_input)

        # Project the decoder output to the vocabulary size
        vocab_projection = nn.Linear(512, dataset.vocab_size).to(DEVICE)

        # Project the logits to the vocabulary size
        logits = vocab_projection(logits)

        # Remove the CLS token
        logits = logits[:, 1:, :].contiguous()

        # Flatten the logits and targets
        logits = logits.view(-1, logits.size(-1))
        targets = captions.view(-1)

        # Compute the loss
        loss = nn.CrossEntropyLoss()(logits, targets)
        epoch_loss += loss.item()

        # Backpropagate the loss
        loss.backward()
        optimizer.step()

        if batch_idx % 25 == 0:
            print(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}")
    epoch_loss /= len(dataloader)
    print(f"Epoch {epoch + 1} completed, Epoch loss: {epoch_loss:.4f}")


        # decoder_input = decoder_input.view(decoder_input.size(0), -1, decoder_input.size(1))
        # decoder_input.shape




Epoch 1, Batch 1/243, Loss: 10.6508
Epoch 1, Batch 26/243, Loss: 11.7485
Epoch 1, Batch 51/243, Loss: 11.1041
Epoch 1, Batch 76/243, Loss: 11.2058
Epoch 1, Batch 101/243, Loss: 11.0673
Epoch 1, Batch 126/243, Loss: 10.5601
Epoch 1, Batch 151/243, Loss: 10.4117
Epoch 1, Batch 176/243, Loss: 11.1671
Epoch 1, Batch 201/243, Loss: 10.9967
Epoch 1, Batch 226/243, Loss: 10.6263


KeyboardInterrupt: 

In [58]:
234*7/60







27.3