In [2]:
# Import necessary libraries
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPTokenizer, CLIPModel
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn



In [3]:
# load dataset and create train and test sets
raw_dataset = load_dataset("nlphuji/flickr30k", split='test[:5000]')

# Split the dataset into training and testing sets
train_test_split = raw_dataset.train_test_split(test_size=0.3)
train = train_test_split['train']
test = train_test_split['test']

print(test)
print(train)

Dataset({
    features: ['image', 'caption', 'sentids', 'split', 'img_id', 'filename'],
    num_rows: 1500
})
Dataset({
    features: ['image', 'caption', 'sentids', 'split', 'img_id', 'filename'],
    num_rows: 3500
})


# Initialize CLIP processor and tokenizer
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Current device:', device)

Current device: cuda


In [6]:
class CaptionDataset(Dataset):
    def __init__(self, dataset, clip_model_name="openai/clip-vit-base-patch32", device=device):
        self.image = dataset['image']
        self.caption_list = dataset['caption']
        
        self.device = device

        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).eval().to(self.device)

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        image = self.image[idx]
        caption_list = self.caption_list[idx]
        
        # ---- Encode image with CLIP ----
        img_tensor = self.processor(images=image, return_tensors="pt").to(self.device)
        #print('IMG TENSOR SHAPE', img_tensor.shape) # channels, height, width
        
        # ---- Tokenize input caption ----
        caption = caption_list[0] # get the first caption in the list
        #print('caption len:', len(caption))
        tokens = self.tokenizer(caption, padding="max_length", max_length=32, return_tensors="pt", truncation=True)

        input_ids_full = tokens["input_ids"].to(self.device)  # [1, seq_len]
        #print('text_input_ids_full shape:', input_ids_full.shape)
        mask = tokens["attention_mask"].to(self.device) # get the mask out

        with torch.no_grad():
            # Use only embedding layer from CLIP
            text_embeddings = self.clip_model.text_model.embeddings(input_ids_full).squeeze(0).to(self.device)

            # Get the CLIP encoded image embeddings
            patch_embeddings = self.clip_model.vision_model(**img_tensor).last_hidden_state[:, 1:, :].squeeze(0).to(self.device) # shape: [1, num_patches, hidden_dim]
            #print('Patch embeddings shape:', patch_embeddings.shape)           
            
            
        target_ids = input_ids_full.squeeze(0).to(self.device)

        #print('IMG EMBEDDINGS SHAPE', patch_embeddings.shape)
        #print('TEXT EMBEDDINGS SHAPE', text_embeddings.shape)
        #print('TARGET IDS SHAPE', target_ids.shape)
        #print('MASK SHAPE', mask.shape)

        return patch_embeddings, text_embeddings, target_ids, mask
            


In [7]:


test_caption = CaptionDataset(test)
train_caption = CaptionDataset(train)

# Create a DataLoader
test_dataloader = DataLoader(test_caption, batch_size=10, shuffle=False)  # Adjust batch_size as needed
train_dataloader = DataLoader(train_caption, batch_size=10, shuffle=False)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# Iterate over the DataLoader and print the outputs
for batch_idx, (patch_embeddings, text_embeddings, target_ids, mask) in enumerate(test_dataloader):
    print('\nImage Embeddings shape:', patch_embeddings.shape)
    print('Text Embeddings shape:', text_embeddings.shape)

    # Move all to device
    patch_embeddings = patch_embeddings.to(device)
    text_embeddings = text_embeddings.to(device)
    target_ids = target_ids.to(device)
    mask = mask.to(device)

    image_input_dim = patch_embeddings.shape[2]# get dimension from patch embeddings
    text_embed_dim = text_embeddings.shape[2]

    # Move projection layer to device
    image_projection_layer = nn.Linear(image_input_dim, text_embed_dim).to(device)
    img_features = image_projection_layer(patch_embeddings)
    print('\nProjected img shape:', img_features.shape)
    print('Projected text shape:', text_embeddings.shape)

    print('Text tensor type:', text_embeddings.dtype)
    print('Image tensor type:', img_features.dtype)
    break  # Remove this break to iterate over the entire subset


Image Embeddings shape: torch.Size([10, 49, 768])
Text Embeddings shape: torch.Size([10, 32, 512])

Projected img shape: torch.Size([10, 49, 512])
Projected text shape: torch.Size([10, 32, 512])
Text tensor type: torch.float32
Image tensor type: torch.float32


# create the image captioning model

In [8]:
import torch.nn as nn

image_input_dim = patch_embeddings.squeeze(1).shape[2] # get dimension from patch embeddings
text_embed_dim = text_embeddings.squeeze(1).shape[2]

image_projection_layer = nn.Linear(image_input_dim, text_embed_dim)
img_features = image_projection_layer(patch_embeddings)
print('\nProjected img shape:', img_features.shape)
print('Projected text shape:', text_embeddings.shape)

NameError: name 'patch_embeddings' is not defined

⬆️ The code above projects the text and image embeddings into the same shape so now we can send it! 

In [43]:
# Create Masked Self Attention Head
class MaskedAttentionHead(nn.Module):
    def __init__(self, embedding_dim, head_dim):
        super(MaskedAttentionHead, self).__init__()
        self.head_dim = head_dim

        # Linear projections for query, key, value
        self.weight_q = nn.Linear(embedding_dim, head_dim)
        self.weight_k = nn.Linear(embedding_dim, head_dim)
        self.weight_v = nn.Linear(embedding_dim, head_dim)

        self.linear_projection = nn.Linear(head_dim, embedding_dim)

    def forward(self, decoder_sequence):
        # embedded decoder sequence shape: [batch_size, seq_length, embedding_dim]

        # Project to head dimension
        Q = self.weight_q(decoder_sequence)
        K = self.weight_k(decoder_sequence)
        V = self.weight_v(decoder_sequence)

        # Make the mask
        seq_len = decoder_sequence.shape[1]
        mask = torch.triu(torch.ones(seq_len, seq_len, device=decoder_sequence.device), diagonal=1)
        mask = mask.masked_fill(mask==1, float('-inf'))

        # Calculate attention scores (scaled dot product)
        A = torch.einsum('bid,bjd->bij', Q, K)
        A = A / (self.head_dim ** 0.5) 

        A = A + mask
        # Apply softmax
        A = torch.softmax(A, dim=-1)

    
        #  Apply attention weights to values
        H = torch.einsum('bij,bjd->bid', A, V)
        
        # Add projection layer for output to return back to the original embedding dimension
        #output = self.linear_projection(H)

        return H

In [44]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

     
        self.heads = nn.ModuleList(
            [MaskedAttentionHead(embedding_dim, self.head_dim) for _ in range(num_heads)]
            )
        
        # The output of the CrossAttention Head and MaskedAttentionHead still needs to be projected
        # Back to the embedding dimensions of the head_dim x vocab_size
        
        self.output_projection = nn.Linear(num_heads * self.head_dim, embedding_dim)

        

    def forward(self, decoder_sequence):
        # decoder_sequence: [batch_size, seq_length, embedding_dim]
        # encoder_output: [batch_size, num_patches, embedding_dim] (only used in cross-attention)
        # mask: [batch_size, seq_length, seq_length] (only used in self-attention)

        # Process each head
        head_outputs = []
        for head in self.heads:
                # For masked self-attention, we only need decoder sequence and mask
                head_output = head(decoder_sequence)
                #print("\nmasked attention head output shape: ", head_output.shape)
                
                head_outputs.append(head_output)

        # Concatenate head outputs
        concat_heads = torch.cat(head_outputs, dim=-1)
        
        # Project back to embedding dimension
        output = self.output_projection(concat_heads)
        #print("Multihead attention output shape: ", output.shape)
        
        return output



In [45]:
class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, mlp_dimension):
        super(DecoderBlock, self).__init__()
        
        # First layer norm
        self.ln1 = nn.LayerNorm(embedding_dim)
        
        # Masked multi-head attention for decoder sequence self-attention
        self.masked_mha = MultiHeadAttention(embedding_dim, num_heads)
        
        
        # Third layer norm
        self.ln2 = nn.LayerNorm(embedding_dim)
        
        # Feed forward network
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dimension),
            nn.ReLU(),
            nn.Linear(mlp_dimension, embedding_dim)
        )

    def forward(self, decoder_sequence):
        # decoder_sequence: the input sequence to decode (e.g., [START, 1, 2, 3])
        # encoder_output: the encoded image from the encoder
        # mask: causal mask to prevent attending to future tokens

        # First masked self-attention block with residual connection
        # This allows decoder sequence to attend to its own past tokens
        # First masked self-attention
        residual = decoder_sequence
        decoder_sequence = self.ln1(decoder_sequence)
        decoder_sequence = self.masked_mha(decoder_sequence)
        decoder_sequence = residual + decoder_sequence

        
        # # FFN block with residual connection
        residual = decoder_sequence
        decoder_sequence = self.ln2(decoder_sequence)
        decoder_sequence = self.ffn(decoder_sequence)
        decoder_sequence = residual + decoder_sequence
        
        return decoder_sequence



In [46]:
class Decoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, mlp_dimension, num_layers, input_sequence_length, vocab_size):
        super(Decoder, self).__init__()
        
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
    
        # Create positional embeddings for decoder sequence ONCE during initialization
        #self.positional_embeddings = nn.Parameter(
        #    torch.randn(1, input_sequence_length, embedding_dim),
        #    requires_grad=True
        #)
        
        # Create decoder blocks
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(embedding_dim, num_heads, mlp_dimension)
            for _ in range(num_layers)
        ])
        
        # Final layer norm
        self.final_ln = nn.LayerNorm(embedding_dim)

        # Output projection to vocabulary size
        # This converts decoder features to logits over possible next tokens
        self.output_projection = nn.Linear(embedding_dim, vocab_size)

    def forward(self, decoder_sequence, return_logits=True):
        # decoder_sequence: the input sequence to decode (e.g., [START, 1, 2, 3])
        # encoder_output: the encoded image from the encoder
        # return_logits: whether to return prediction logits or just decoder features
        # by default, we return logits

        #embedded_decoder_sequence = self.embedding_layer(decoder_sequence) # not neee because its already embedded

        # Add positional embeddings to decoder sequence
        #decoder_sequence = embedded_decoder_sequence + self.positional_embeddings
        
        # Pass through decoder blocks
        for block in self.decoder_blocks:
            decoder_sequence = block(decoder_sequence)
        
        # Apply final layer norm
        decoder_features = self.final_ln(decoder_sequence)
        
        if return_logits:
            # Convert features to logits for prediction
            # Shape: [batch_size, seq_length, vocab_size]
            logits = self.output_projection(decoder_features)
            return logits
        else:
            # Return decoder features if needed
            return decoder_features

### MODEL

In [47]:
import torch.optim as optim

# Parameters for decoder
embedding_dim = 512
num_heads = 8
mlp_dimension = 2048
num_layers = 2
input_sequence_length = 32
vocab_size = 49408

decoder = Decoder(embedding_dim, num_heads, mlp_dimension, num_layers, input_sequence_length, vocab_size)
#decoder = SimpleCaptionDecoder(embedding_dim, num_heads, mlp_dimension, num_layers, vocab_size)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(decoder.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):

    # Iterate over the DataLoader and print the outputs
    for patch_embeddings, text_embeddings, target_ids, mask in dataloader:

        
        print('\nImage Embeddings shape:', patch_embeddings.shape)
        print('Text Embeddings shape:', text_embeddings.shape)

        image_input_dim = patch_embeddings.shape[2] # get dimension from patch embeddings
        text_embed_dim = text_embeddings.shape[2]
        print('Target ids shape:', target_ids.shape)

        image_projection_layer = nn.Linear(image_input_dim, text_embed_dim)
        img_features = image_projection_layer(patch_embeddings)
        print('\nProjected img shape:', img_features.shape)
        print('Projected text shape:', text_embeddings.shape)

        # Move to device
        encoder_output = img_features.to(device)
        decoder_inputs = text_embeddings.to(device).float()
        targets = target_ids.to(device)
        
       
        print('Text embeddings type:', decoder_inputs.dtype)
        print('Image features type:', encoder_output.dtype)
        print('Target ids type:', targets.dtype)

        # Forward pass through the decoder
        logits = decoder(decoder_inputs)

        # Compute loss
        # Shift logits and targets to align for cross-entropy
        logits = logits[:, :-1].contiguous().view(-1, logits.size(-1))
        targets = target_ids[:, 1:].contiguous().view(-1)
        loss = criterion(logits, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")

## Validation
model.eval()
with torch.no_grad():
    for batch in val_loader:
        images, decoder_inputs, targets = batch
        outputs = model(images, decoder_inputs)
        loss = criterion(outputs, targets)
        print(f"Validation Loss: {loss.item()}")


Image Embeddings shape: torch.Size([10, 49, 768])
Text Embeddings shape: torch.Size([10, 32, 512])
Target ids shape: torch.Size([10, 32])

Projected img shape: torch.Size([10, 49, 512])
Projected text shape: torch.Size([10, 32, 512])
Text embeddings type: torch.float32
Image features type: torch.float32
Target ids type: torch.int64

Image Embeddings shape: torch.Size([10, 49, 768])
Text Embeddings shape: torch.Size([10, 32, 512])
Target ids shape: torch.Size([10, 32])

Projected img shape: torch.Size([10, 49, 512])
Projected text shape: torch.Size([10, 32, 512])
Text embeddings type: torch.float32
Image features type: torch.float32
Target ids type: torch.int64

Image Embeddings shape: torch.Size([10, 49, 768])
Text Embeddings shape: torch.Size([10, 32, 512])
Target ids shape: torch.Size([10, 32])

Projected img shape: torch.Size([10, 49, 512])
Projected text shape: torch.Size([10, 32, 512])
Text embeddings type: torch.float32
Image features type: torch.float32
Target ids type: torch.i

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os
from datetime import datetime
import wandb
import time


def train_transformer(
    model,
    train_loader,
    val_loader,
    num_epochs,
    learning_rate,
    device='cuda' if torch.cuda.is_available() else 'cpu'
):
    
    # Initialize wandb
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f"full-transformer-{timestamp}"
    wandb.init(project="full-transformer-mnist",
               name=run_name,
               group="experiment-1",
               config={
                   "learning_rate": learning_rate,
                   "num_epochs": num_epochs,
                   "batch_size": train_loader.batch_size,
                   "embedding_dim": model.embedding_dim,
                   "num_heads": model.num_heads,
                   "mlp_dimension": model.mlp_dimension,
                   "num_layers": model.num_layers,
                   "max_seq_length": model.num_patches,
                   "vocab_size": model.vocab_size
               })
    
    # Move model to device
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # TRAINING LOOP 
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        # Training phase
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', mininterval=1.0) # udpate every 1 second
        for batch_idx, (input_images, decoder_inputs, targets) in enumerate(train_loader):
            # Move to device
            input_images = input_images.to(device)
            decoder_inputs = decoder_inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_images, decoder_inputs)

            # Calculate loss
            # outputs.view(-1, outputs.size(-1)) is to flatten the output to a 1D tensor
            # targets.view(-1) is to flatten the targets to a 1D tensor
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update metrics
            total_loss += loss.item()
            _, predicted = outputs.max(-1)
            total_correct += (predicted == targets).sum().item()
            total_samples += targets.numel()
            
            # Calculate current accuracy
            current_acc = (total_correct / total_samples * 100)
            
            
            # Update progress bar
            train_pbar.set_postfix({
                'loss': total_loss / (batch_idx + 1),
                'acc': f'{current_acc:.2f}%'
            })

            train_pbar.update(1)
        
        train_pbar.close()

        
        # Calculate training metrics after all batches
        train_loss = total_loss / len(train_loader)
        train_acc = (total_correct / total_samples * 100)
        print(f"\nTraining completed for epoch {epoch+1}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

        # After training phase
        wandb.log({
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "epoch": epoch + 1
        })
            
        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_samples = 0
        
        # No need to compute gradients for validation
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', mininterval=1.0)
            for batch_idx, (images, decoder_input, targets) in enumerate(val_pbar):
                # Move to device
                images = images.to(device)
                decoder_input = decoder_input.to(device)
                targets = targets.to(device)
                
                # Forward pass
                outputs = model(images, decoder_input)
                
                # Calculate loss
                loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                val_loss += loss.item()

                # Update validation metrics
                _, predicted = outputs.max(-1)
                val_correct += (predicted == targets).sum().item()
                val_samples += targets.numel()

                # Calculate current validation accuracy
                current_val_acc = (val_correct / val_samples * 100)
                

                # Update progress bar
                val_pbar.set_postfix({
                    'loss': val_loss / (batch_idx + 1),
                    'acc': f'{current_val_acc:.2f}%'
                })

                val_pbar.update(1)

        val_pbar.close()
                
        # Calculate validation metrics
        val_loss = val_loss / len(val_loader)
        val_acc = (val_correct / val_samples * 100)

        # After validation phase
        wandb.log({
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "epoch": epoch + 1
        })
        

      # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs} Summary:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # After training loop completes
    # 1. Save model weights
    model_save_path = f"transformer_weights_{timestamp}_mnistscattered.pt"

    torch.save(model, model_save_path)

    print(f"\nModel weights saved to {model_save_path}")

    # Finish wandb run
    wandb.finish()
            
        
    

In [None]:
# captions produces a list of captions
#captions = [caption for caption_list[0] for caption_list in test['caption']]  

captions = []
for caption_list in test['caption']:
    for caption in caption_list[0]:
        captions.append(caption)



In [7]:
import torch.nn as nn
from tqdm import tqdm

class ImageCaptioningModel(nn.Module):
    def __init__(self, decoder, clip_model_name='openai/clip-vit-base-patch32'):
        super().__init__()
        self.decoder = decoder
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.clip.visual.requires_grad_(False)  # freeze vision encoder

    def forward(self, images, decoder_inputs, tgt_mask=None, tgt_key_padding_mask=None):
        # Get image features from CLIP vision encoder
        vision_outputs = self.clip.vision_model(pixel_values=images)
        img_features = vision_outputs.last_hidden_state.mean(dim=1).unsqueeze(1)  # (B, 1, D)

        # Decode using transformer decoder
        logits = self.decoder(
            memory=img_features,
            decoder_inputs=decoder_inputs,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        return logits


In [8]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch['images'].squeeze(1).to(device)  # shape (B, C, H, W)
        decoder_inputs = batch['decoder_inputs'].to(device)  # (B, T)
        decoder_targets = batch['decoder_targets'].to(device)  # (B, T)

        optimizer.zero_grad()
        logits = model(images, decoder_inputs)

        # Shift logits and targets to align for cross-entropy
        logits = logits[:, :-1].contiguous().view(-1, logits.size(-1))
        targets = decoder_targets[:, 1:].contiguous().view(-1)

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageCaptioningModel(decoder=Decoder()).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
