In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# modules imported by this notebook (dataset, decoder, utils) are picked up
# on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import transformers
from transformers import CLIPModel
import datasets
from datasets import load_dataset
from dataset import FlickrDataset
from decoder import Decoder
from tqdm import tqdm

from utils import DEVICE


In [5]:
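# NOTE: disabled code kept for reference only. It loads Flickr30k directly via
# `datasets.load_dataset`; the notebook uses `FlickrDataset()` in a later cell instead.
# dataset = load_dataset("nlphuji/flickr30k", cache_dir="./data")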
# train_data = dataset.filter(lambda x: x["split"] == "train")
# test_data = dataset.filter(lambda x: x["split"] == "test")
# valid_data = dataset.filter(lambda x: x["split"] == "val")
# relevant_columns = ["image", "caption"]
# train_dataset = train_data.select_columns(relevant_columns)
# test_dataset = test_data.select_columns(relevant_columns)
# valid_dataset = valid_data.select_columns(relevant_columns)



In [16]:
# Load the dataset & dataloader
BATCH_SIZE = 32  # number of (image, caption) samples per training batch

dataset = FlickrDataset()
# Shuffle so that every pass over the data visits the samples in a new order.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [17]:
# Hyperparameters shared by the decoder, the vocabulary projection and the
# learning-rate schedule. Kept in one place so they cannot drift apart.
EMBEDDING_DIM = 64
LEARNING_RATE = 1e-3
# Fraction of one pass over `dataloader` spent linearly warming up the LR.
# NOTE(review): the original call passed `warmup=0.1`; reading it as a fraction
# of an epoch is an assumption -- confirm the intended warmup length.
WARMUP_FRACTION = 0.1

# Load the model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Freeze the CLIP encoders: they are used only as fixed feature extractors.
clip_model.vision_model.requires_grad_(False)
clip_model.text_model.requires_grad_(False)

# Initialize the decoder
decoder = Decoder(num_layers=1, embedding_dim=EMBEDDING_DIM, num_heads=4, ff_dim=128).to(DEVICE)

# Initialize the vocabulary projection (decoder width -> vocabulary logits)
vocab_projection = nn.Linear(EMBEDDING_DIM, dataset.vocab_size).to(DEVICE)

# Optimizer: only the decoder and the projection head are trained.
optimizer = optim.Adam(
    list(decoder.parameters()) + list(vocab_projection.parameters()),
    lr=LEARNING_RATE,
)

# Schedule learning rate.
# `torch.optim` does not provide a `NoamLR` class, so the Noam shape (linear
# warmup followed by inverse-square-root decay) is built with `LambdaLR`.
# The multiplier peaks at 1.0, i.e. LEARNING_RATE is the peak learning rate.
warmup_steps = max(1, int(WARMUP_FRACTION * len(dataloader)))


def noam_lr_multiplier(step: int) -> float:
    """Return the factor applied to LEARNING_RATE at scheduler step `step`.

    Rises linearly to 1.0 over `warmup_steps` steps, then decays
    proportionally to 1 / sqrt(step).
    """
    current = step + 1  # LambdaLR counts from 0; avoid a zero learning rate
    return min(current / warmup_steps, (warmup_steps / current) ** 0.5)


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lr_multiplier)

In [22]:
# Train the decoder and vocabulary projection on top of frozen CLIP features.
num_epochs = 1
log_every = 25  # number of batches between loss printouts

# CLIP is only a feature extractor here: put it in eval mode once, up front,
# instead of on every batch.
clip_model.eval()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=True)
    # Iterate over the tqdm wrapper (not the bare dataloader) so the bar advances.
    for batch_idx, (images, captions) in enumerate(progress_bar):
        images = images.to(DEVICE)
        captions = captions.to(DEVICE)

        # Extract CLIP features without tracking gradients.
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values=images)
            text_outputs = clip_model.text_model(input_ids=captions)
            last_hidden_state = text_outputs.last_hidden_state  # shape: (batch size, sequence length, hidden size)

        # Zero the gradients
        optimizer.zero_grad()

        # Prepend the image feature as one extra sequence position in front of
        # the per-token text features.
        decoder_input = torch.cat((image_features.unsqueeze(1), last_hidden_state), dim=1)

        # Run the decoder, then project its outputs to vocabulary logits.
        decoder_output = decoder(decoder_input)
        logits = vocab_projection(decoder_output)

        # Drop the output at position 0 (the prepended image feature) so the
        # logits have the same sequence length as `captions`.
        # NOTE(review): after this slice, logit i comes from the position whose
        # input is derived from caption token i and is scored against caption
        # token i. Next-token training normally pairs `logits[:, :-1]` with
        # `captions`; confirm against Decoder's attention masking.
        logits = logits[:, 1:, :]

        # Flatten to (batch * sequence, vocab) vs (batch * sequence,) and score
        # with the shared loss function defined alongside the model.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), captions.reshape(-1))
        batch_loss = loss.item()
        epoch_loss += batch_loss

        # Backpropagate, update the weights, then advance the LR schedule.
        loss.backward()
        optimizer.step()
        scheduler.step()

        if batch_idx % log_every == 0:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {batch_loss:.4f}")
    epoch_loss /= len(dataloader)
    print(f"Epoch {epoch + 1} completed, Epoch loss: {epoch_loss:.4f}")





Epoch 1/1:   0%|          | 0/970 [00:00<?, ?it/s]

Epoch 1, Batch 1/970, Loss: 10.8284
Epoch 1, Batch 26/970, Loss: 4.9279
Epoch 1, Batch 51/970, Loss: 2.5473
Epoch 1, Batch 76/970, Loss: 1.7243
Epoch 1, Batch 101/970, Loss: 1.3994
Epoch 1, Batch 126/970, Loss: 1.6091
Epoch 1, Batch 151/970, Loss: 1.4810
Epoch 1, Batch 176/970, Loss: 1.2351
Epoch 1, Batch 201/970, Loss: 1.1416
Epoch 1, Batch 226/970, Loss: 1.1932
Epoch 1, Batch 251/970, Loss: 1.0030
Epoch 1, Batch 276/970, Loss: 1.0143
Epoch 1, Batch 301/970, Loss: 0.9610
Epoch 1, Batch 326/970, Loss: 0.7091
Epoch 1, Batch 351/970, Loss: 0.8895
Epoch 1, Batch 376/970, Loss: 0.7167
Epoch 1, Batch 401/970, Loss: 0.8291
Epoch 1, Batch 426/970, Loss: 0.8426
Epoch 1, Batch 451/970, Loss: 0.7285
Epoch 1, Batch 476/970, Loss: 0.6392
Epoch 1, Batch 501/970, Loss: 0.5643
Epoch 1, Batch 526/970, Loss: 0.5267
Epoch 1, Batch 551/970, Loss: 0.6886
Epoch 1, Batch 576/970, Loss: 0.5561
Epoch 1, Batch 601/970, Loss: 0.5511
Epoch 1, Batch 626/970, Loss: 0.4790
Epoch 1, Batch 651/970, Loss: 0.5710
Epoch

In [58]:
# Scratch arithmetic left over from experimentation; evaluates to 27.3.
# NOTE(review): the meaning of 234, 7 and 60 is not stated anywhere in this
# notebook -- label the quantities and units, or delete this cell.
234*7/60







27.3