In [1]:
# Set up automatic reloading
%load_ext autoreload
%autoreload 2


In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import transformers
from transformers import CLIPModel
import datasets
from datasets import load_dataset
from dataset import FlickrDataset
from decoder import Decoder
from tqdm.notebook import tqdm

from utils import DEVICE


In [5]:
# dataset = load_dataset("nlphuji/flickr30k", cache_dir="./data")
# train_data = dataset.filter(lambda x: x["split"] == "train")
# test_data = dataset.filter(lambda x: x["split"] == "test")
# valid_data = dataset.filter(lambda x: x["split"] == "val")
# relevant_columns = ["image", "caption"]
# train_dataset = train_data.select_columns(relevant_columns)
# test_dataset = test_data.select_columns(relevant_columns)
# valid_dataset = valid_data.select_columns(relevant_columns)



In [16]:
# Load the dataset & dataloader
dataset = FlickrDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [27]:
# Load the model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Freeze the CLIP model
for param in clip_model.vision_model.parameters():
    param.requires_grad = False
for param in clip_model.text_model.parameters():
    param.requires_grad = False

# Initialize the decoder
decoder = Decoder(num_layers=1, embedding_dim=64, num_heads=4, ff_dim=128).to(DEVICE)

# Initialize the vocabulary projection
vocab_projection = nn.Linear(64, dataset.vocab_size).to(DEVICE)

# Optimizer
optimizer = optim.Adam(list(decoder.parameters()) + list(vocab_projection.parameters()), lr=0.001)

# Schedule learning rate
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

In [36]:

def train_loop(dataloader, clip_model, decoder, vocab_projection, optimizer, scheduler, num_epochs: int= 1):
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=True)
        for batch_idx, batch in enumerate(progress_bar):
            # batch = next(iter(dataloader))
            images, captions = batch
        images = images.to(DEVICE)
        captions = captions.to(DEVICE)

        clip_model.eval()
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values=images)
            text_features = clip_model.text_model(
                input_ids=captions,
                output_hidden_states=True
            )
            last_hidden_state = text_features.last_hidden_state  # shape: (batch size, sequence length, hidden size)
            # print(last_hidden_state.shape, image_features.shape)

            # Zero the gradients
            optimizer.zero_grad()

            # Concatenate the image features and the last hidden state
            decoder_input = torch.cat((image_features.unsqueeze(1), last_hidden_state), dim=1)

            # Pass the concatenated features to the decoder
            logits = decoder(decoder_input)

            # Project the logits to the vocabulary size
            logits = vocab_projection(logits)

            # Remove the CLS token
            logits = logits[:, 1:, :].contiguous()

            # Flatten the logits and targets
            logits = logits.view(-1, logits.size(-1))
            targets = captions.view(-1)

            # Compute the loss
            loss = nn.CrossEntropyLoss()(logits, targets)
            epoch_loss += loss.item()

            # Backpropagate the loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            if batch_idx % 25 == 0:
                tqdm.write(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}")

    epoch_loss /= len(dataloader)
    tqdm.write(f"Epoch {epoch + 1} completed, Epoch loss: {epoch_loss:.4f}")





In [37]:
num_epochs = 1

train_loop(dataloader, clip_model, decoder, vocab_projection, optimizer, scheduler, num_epochs)

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [58]:
234*7/60







27.3