# CS5481 Data Engineering: Tutorial on Representation Learning and Multimodal Learning

Welcome everyone to Tutorial 10. Today we're focusing on two key concepts in deep learning: representation learning and multimodal learning.


## Part 1: Representation Learning

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

# Load and preprocess the MNIST dataset
# Pipeline: resize to 224x224 -> tensor -> single-channel normalisation ->
# replicate the channel three times so the image matches ResNet50's RGB stem.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match ResNet50 input size
    transforms.ToTensor(),
    # One mean/std pair (presumably MNIST's pixel statistics); applied before
    # the channel repeat so the same normalisation covers all three copies.
    # NOTE(review): the pretrained backbone was trained with different
    # (ImageNet) statistics — confirm this choice is intended.
    transforms.Normalize((0.1307,), (0.3081,)),
    # NOTE(review): a lambda cannot be pickled, so DataLoader(num_workers>0)
    # under the 'spawn' start method would fail — confirm before adding workers.
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Convert grayscale to RGB
])

# Load datasets (downloaded into ./data on first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders: shuffle only the training split; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Define the model using ResNet50 for feature extraction and a fully connected layer for classification
class ResNet50Classifier(nn.Module):
    """Frozen ImageNet-pretrained ResNet50 feature extractor with a trainable MLP head.

    Only the head (``self.fc``) is meant to be trained. The backbone is kept in
    eval mode permanently (see ``train``) so it behaves as a fixed feature
    extractor.

    Args:
        num_classes: number of logits produced by the classification head.
    """

    def __init__(self, num_classes=10):
        super(ResNet50Classifier, self).__init__()

        # Load pre-trained ResNet50. NOTE: `pretrained=True` is deprecated in
        # torchvision >= 0.13 in favour of `weights=...`; kept for compatibility
        # with older torchvision releases.
        self.resnet50 = models.resnet50(pretrained=True)
        # Freeze the parameters of ResNet50 so the optimizer cannot change them.
        for param in self.resnet50.parameters():
            param.requires_grad = False

        # Everything except the final fully connected layer; yields the pooled
        # 2048-dim features that the head flattens below. These are the same
        # module objects as in self.resnet50 (shared, not copied).
        self.features = nn.Sequential(*list(self.resnet50.children())[:-1])

        # Add a fully connected layer for classification
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

        # Start the frozen backbone in eval mode so BatchNorm uses its stored
        # statistics even if forward() is called before train()/eval().
        self.resnet50.eval()

    def train(self, mode=True):
        """Set the head's training mode while keeping the frozen backbone in eval mode.

        ``requires_grad=False`` only stops gradient updates. BatchNorm layers left
        in train mode would still normalise with per-batch statistics and overwrite
        their running mean/var on every forward pass. The "frozen" backbone would
        then drift during training. ``nn.Module.eval()`` routes through this method
        as ``train(False)``.
        """
        super(ResNet50Classifier, self).train(mode)
        # self.features shares its modules with self.resnet50, so this covers both.
        self.resnet50.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch *x*."""
        # Extract features using the frozen ResNet50 trunk
        features = self.features(x)
        # Classify using the fully connected head
        output = self.fc(features)
        return output

# Initialize the model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50Classifier(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
# Only the classification head (model.fc) is optimised; the ResNet50 backbone
# parameters were frozen with requires_grad=False.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Training function
def train(model, train_loader, criterion, optimizer, device, epochs=5):
    """Train *model* in place for *epochs* passes over *train_loader*.

    Args:
        model: module mapping an image batch to class logits.
        train_loader: iterable of ``(images, labels)`` batches; ``len()`` is used
            for the per-epoch summary.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the parameters to update.
        device: device each batch is moved to.
        epochs: number of passes over ``train_loader``.

    Shows a tqdm bar with the running mean batch loss and running accuracy,
    and prints a summary line after each epoch.
    """
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        # Count batches ourselves: tqdm only syncs `progress_bar.n` when it
        # redraws (rate-limited by `mininterval`), so reading it inside the loop
        # gives a stale count and an inflated running loss.
        for num_batches, (images, labels) in enumerate(progress_bar, start=1):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Class index with the highest logit; argmax yields no gradient,
            # so the deprecated `.data` access is unnecessary.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            progress_bar.set_postfix({
                'loss': running_loss / num_batches,
                'acc': 100 * correct / total
            })

        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100*correct/total:.2f}%')

# Evaluation function
def evaluate(model, test_loader, device):
    """Measure top-1 accuracy of *model* over *test_loader*.

    The model is put in eval mode and run without gradient tracking. The
    accuracy (a percentage) is printed with two decimals and returned.
    """
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(test_loader, desc='Testing'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Index of the largest logit per row is the predicted class.
            predictions = torch.max(model(batch_images).data, 1)[1]
            seen += batch_labels.size(0)
            hits += (predictions == batch_labels).sum().item()

    accuracy = 100 * hits / seen
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# Train the model (only the head is updated; see the optimizer setup above)
print(f"Training on {device}")
train(model, train_loader, criterion, optimizer, device, epochs=5)

# Evaluate the model on the MNIST test split
evaluate(model, test_loader, device)

# Save the model weights (state_dict only; the class is needed to reload them).
# NOTE(review): the backbone is registered under both `resnet50` and `features`,
# so the saved state_dict carries both key prefixes (including the unused
# resnet50.fc) — confirm this is acceptable for downstream loading.
torch.save(model.state_dict(), 'resnet50_mnist_model.pth')

## Part 2: Multimodal Learning

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os

# Configuration parameters
# Minimum number of occurrences for a token to enter the vocabulary (build_vocab uses >=).
vocab_threshold = 4
# Size of the image embedding and of the word embeddings fed to the LSTM.
embedding_dim = 256
# LSTM hidden state size.
hidden_dim = 512
# Number of images (one sampled caption each) per training batch.
batch_size = 16
# Number of passes over the training split.
num_epochs = 3

# Simple vocabulary class
class Vocabulary:
    """Two-way word/index mapping with four reserved special tokens.

    Indices 0-3 are fixed to '<pad>', '<start>', '<end>' and '<unk>'.
    Words added later receive consecutive indices starting at 4. Looking up
    an unknown word returns the '<unk>' index.
    """

    def __init__(self):
        specials = ('<pad>', '<start>', '<end>', '<unk>')
        self.word2idx = {token: position for position, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        # Next free index to hand out.
        self.idx = len(specials)

    def add_word(self, word):
        """Assign the next free index to *word*; already-known words are left untouched."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __call__(self, word):
        """Return the index of *word*, falling back to the '<unk>' index."""
        return self.word2idx.get(word, self.word2idx['<unk>'])

    def __len__(self):
        return len(self.word2idx)

# Simple tokenization function - using space splitting and punctuation handling
def simple_tokenize(text):
    """Split *text* into lowercase tokens, emitting each of , . ! ? ; : as its own token."""
    # Pad every punctuation mark with spaces in one pass, then split on whitespace.
    spacing_table = str.maketrans({mark: f' {mark} ' for mark in ',.!?;:'})
    return [piece.lower() for piece in text.translate(spacing_table).split()]

# Build simple vocabulary
def build_vocab(token_file, threshold=None):
    """Build a Vocabulary from every caption in *token_file*.

    Args:
        token_file: tab-separated caption file. Only lines that split into
            exactly two tab-separated fields are counted; the second field
            is the caption.
        threshold: minimum number of occurrences for a token to be added.
            Defaults to the module-level ``vocab_threshold``.

    Returns:
        A Vocabulary holding the special tokens plus every token seen at least
        *threshold* times. Tokens are indexed in order of first appearance, so
        the mapping is deterministic for a given file.
    """
    from collections import Counter

    if threshold is None:
        threshold = vocab_threshold

    vocab = Vocabulary()
    word_counts = Counter()

    # Count tokens over all captions. An explicit encoding keeps parsing
    # independent of the platform's default locale.
    with open(token_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 2:
                word_counts.update(simple_tokenize(parts[1].lower()))

    # Add words that appear at least `threshold` times. Counter preserves
    # first-seen order, which fixes the index assigned to each word.
    for word, count in word_counts.items():
        if count >= threshold:
            vocab.add_word(word)

    print(f"Vocabulary size: {len(vocab)}")
    return vocab

# Flickr8k dataset class
class Flickr8kDataset(Dataset):
    """Flickr8k image/caption pairs restricted to one split's image list.

    Args:
        img_dir: directory containing the image files.
        caption_file: tab-separated caption file whose lines look like
            ``<image_name>#<k>\\t<caption>``; other lines are ignored.
        img_list_file: file listing one image name per line (the split).
        vocab: callable mapping a token string to its integer index.
        transform: optional callable applied to each loaded RGB PIL image.

    Each item is ``(image, caption_ids)``. The caption is one of the image's
    reference captions chosen at random, encoded as
    ``[<start>, tokens..., <end>]`` indices.
    """

    def __init__(self, img_dir, caption_file, img_list_file, vocab, transform=None):
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform

        # Read the split's image list (order preserved). The set gives O(1)
        # membership tests while scanning the caption file, which has several
        # lines per image; testing against the list was O(images) per line.
        with open(img_list_file, 'r', encoding='utf-8') as f:
            self.img_list = [line.strip() for line in f]
        wanted = set(self.img_list)

        # Map image name -> list of lowercased captions, in file order.
        self.captions = {}
        with open(caption_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue  # blank or malformed line
                img_id = parts[0].split('#')[0]
                if img_id in wanted:
                    self.captions.setdefault(img_id, []).append(parts[1].lower())

        # Keep only images with captions
        self.img_list = [img for img in self.img_list if img in self.captions]
        print(f"Dataset contains {len(self.img_list)} images")

    def __getitem__(self, idx):
        img_id = self.img_list[idx]

        # Load the image; the context manager releases the file handle
        # (convert() returns a fully loaded copy).
        img_path = os.path.join(self.img_dir, img_id)
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Randomly select one of the image's reference captions
        caption_text = np.random.choice(self.captions[img_id])

        # Tokenize and wrap the token indices in <start> ... <end>
        token_ids = [self.vocab('<start>')]
        token_ids.extend(self.vocab(token) for token in simple_tokenize(caption_text))
        token_ids.append(self.vocab('<end>'))

        # Build the index tensor directly as int64 (no float round-trip).
        return image, torch.tensor(token_ids, dtype=torch.long)

    def __len__(self):
        return len(self.img_list)

# Simplified collate function for DataLoader
def collate_fn(data):
    """Batch ``(image, caption_ids)`` pairs for the DataLoader.

    Captions are right-padded with index 0 to the longest caption in the batch.

    Returns:
        images: the image tensors stacked along a new leading batch dimension.
        targets: int64 tensor of shape (batch, longest_caption) with padded indices.
        lengths: list of unpadded caption lengths, in batch order.
    """
    image_seq, caption_seq = zip(*data)

    lengths = list(map(len, caption_seq))
    targets = torch.zeros((len(caption_seq), max(lengths)), dtype=torch.long)

    # Copy each caption into the front of its row; the tail stays 0 (padding).
    for row, (cap, n_tokens) in enumerate(zip(caption_seq, lengths)):
        targets[row, :n_tokens] = cap[:n_tokens]

    return torch.stack(image_seq, dim=0), targets, lengths

# Image encoder (ResNet50 simplified)
class EncoderCNN(nn.Module):
    """Frozen ImageNet-pretrained ResNet50 followed by a trainable linear projection.

    Args:
        embed_size: dimensionality of the returned image embedding.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # Use pretrained ResNet50. NOTE: `pretrained=True` is deprecated in
        # torchvision >= 0.13 in favour of `weights=...`; kept for compatibility.
        resnet = models.resnet50(pretrained=True)
        # Remove the final fully connected layer, keeping the pooled features.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

        # Freeze backbone parameters; only self.embed is trainable.
        for param in self.resnet.parameters():
            param.requires_grad = False
        # Start the backbone in eval mode so its BatchNorm statistics are fixed.
        self.resnet.eval()

    def train(self, mode=True):
        """Set training mode for the projection while keeping the backbone in eval mode.

        Freezing parameters does not freeze BatchNorm buffers. In train mode,
        BatchNorm normalises with per-batch statistics and updates
        running_mean/running_var on every forward pass, even under
        ``torch.no_grad()``. The "frozen" backbone would then drift.
        ``nn.Module.eval()`` routes through this method as ``train(False)``.
        """
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return image embeddings of shape (batch, embed_size)."""
        # The backbone is frozen, so skip building an autograd graph for it.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

# Text decoder (simplified, avoiding pack_padded_sequence)
class DecoderRNN(nn.Module):
    """Single-layer LSTM caption decoder conditioned on an image embedding.

    The image embedding is the first LSTM input. Later inputs are word
    embeddings: the ground-truth caption during training (teacher forcing),
    or the previous greedy prediction during sampling.

    Args:
        embed_size: word-embedding size; must equal the image embedding size.
        hidden_size: LSTM hidden state size.
        vocab_size: number of words (output classes).
    """

    def __init__(self, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions, lengths):
        """Return next-word logits of shape (batch, seq_len, vocab_size).

        Args:
            features: image embeddings, (batch, embed_size).
            captions: padded word indices, (batch, seq_len).
            lengths: unpadded caption lengths. Unused because no packing is
                done; padded positions must be masked by the loss (e.g. via
                ``ignore_index``). Kept for interface compatibility.

        Input step 0 is the image and input step t > 0 is ``captions[:, t-1]``,
        so output step t is trained to predict ``captions[:, t]``.
        """
        # Embed all caption words
        embeddings = self.embed(captions)

        # Prepend the image and drop the last word so the sequence keeps
        # length seq_len and outputs line up with `captions` as targets.
        inputs = torch.cat((features.unsqueeze(1), embeddings[:, :-1]), 1)

        # LSTM forward pass (without pack_padded_sequence)
        hiddens, _ = self.lstm(inputs)

        # Predict next word
        return self.linear(hiddens)

    def sample(self, features, max_len=20, end_idx=2):
        """Greedily generate word indices for each image (inference mode).

        Args:
            features: image embeddings, (batch, embed_size).
            max_len: maximum number of decoding steps.
            end_idx: index of the '<end>' token. Decoding stops early once
                every sequence has emitted it at least once.

        Returns:
            int64 tensor of shape (batch, steps) with steps <= max_len. Tokens
            after a sequence's first ``end_idx`` are arbitrary; callers should
            cut each row at its first ``end_idx``.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)
        states = None
        # Tracks which sequences have emitted <end> at any step so far; the old
        # check only stopped when all of them emitted it on the same step.
        finished = torch.zeros(features.size(0), dtype=torch.bool, device=features.device)

        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted)

            # Stop once every sequence has reached <end> at least once.
            finished |= predicted == end_idx
            if bool(finished.all()):
                break

            # Input for next time step
            inputs = self.embed(predicted).unsqueeze(1)

        if not sampled_ids:  # max_len <= 0: nothing was decoded
            return torch.empty(features.size(0), 0, dtype=torch.long, device=features.device)
        return torch.stack(sampled_ids, 1)

def _train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over *train_loader* and return the mean batch loss.

    Prints the current batch loss every 10 steps. Returns 0.0 for an empty loader.
    """
    encoder.train()
    decoder.train()
    total_loss = 0.0
    num_batches = 0

    for step, (images, captions, lengths) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Forward pass: the decoder returns (batch, seq_len, vocab) logits whose
        # step t is the prediction for captions[:, t].
        features = encoder(images)
        outputs = decoder(features, captions, lengths)

        # Flatten batch and time so the loss sees (N, vocab) logits vs (N,) targets.
        # Padding is expected to be masked by the criterion (main() builds it
        # with ignore_index=<pad>).
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Show progress
        if step % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(train_loader)}], Loss: {loss.item():.4f}")

    return total_loss / num_batches if num_batches else 0.0


def _ids_to_caption(word_ids, vocab):
    """Turn a sequence of word indices into a space-joined caption string.

    Stops at the first '<end>', drops '<start>' and '<pad>', and maps indices
    missing from ``vocab.idx2word`` to '<unk>'.
    """
    words = []
    for word_id in word_ids:
        word = vocab.idx2word.get(int(word_id), '<unk>')
        if word == '<end>':
            break
        if word not in ('<start>', '<pad>'):
            words.append(word)
    return ' '.join(words)


def _print_sample_captions(encoder, decoder, val_loader, vocab, device, num_samples=3):
    """Greedily caption the first *num_samples* images of *val_loader* and print them."""
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        batch = next(iter(val_loader), None)
        if batch is None:
            print("Validation set is empty; skipping sample captions.")
            return
        images = batch[0][:num_samples].to(device)

        sampled_ids = decoder.sample(encoder(images))
        for n, word_ids in enumerate(sampled_ids.cpu().tolist(), start=1):
            print(f"Generated caption {n}: {_ids_to_caption(word_ids, vocab)}")


def main():
    """Train the Flickr8k image-captioning model and save its weights.

    Reads data from ./flickr8k_data (images/, Flickr8k.token.txt,
    Flickr_8k.trainImages.txt, Flickr_8k.devImages.txt). After each epoch it
    prints the mean training loss and greedy captions for three validation
    images. It then writes encoder.pth and decoder.pth to the working
    directory.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Define paths
    data_dir = 'flickr8k_data'
    img_dir = os.path.join(data_dir, 'images')
    token_file = os.path.join(data_dir, 'Flickr8k.token.txt')
    train_file = os.path.join(data_dir, 'Flickr_8k.trainImages.txt')
    val_file = os.path.join(data_dir, 'Flickr_8k.devImages.txt')

    # Image transformation: resize/crop to 224x224 and normalise with the
    # per-channel mean/std given below.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Build vocabulary from all captions in the token file
    vocab = build_vocab(token_file)

    # Create datasets and data loaders
    train_dataset = Flickr8kDataset(img_dir, token_file, train_file, vocab, transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    val_dataset = Flickr8kDataset(img_dir, token_file, val_file, vocab, transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Initialize models
    encoder = EncoderCNN(embedding_dim).to(device)
    decoder = DecoderRNN(embedding_dim, hidden_dim, len(vocab)).to(device)

    # Loss ignores padded positions; only the decoder and the encoder's
    # projection layer are optimised (the ResNet backbone is frozen).
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])
    params = list(decoder.parameters()) + list(encoder.embed.parameters())
    optimizer = torch.optim.Adam(params, lr=0.001)

    # Training loop
    print("Starting training...")
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, device, epoch)
        print(f"Epoch [{epoch+1}/{num_epochs}] average training loss: {avg_loss:.4f}")
        _print_sample_captions(encoder, decoder, val_loader, vocab, device)

    # Save models
    torch.save(encoder.state_dict(), 'encoder.pth')
    torch.save(decoder.state_dict(), 'decoder.pth')
    print("Models saved.")

# Entry point: run the full captioning pipeline when executed directly
# (also true inside a notebook cell, where __name__ == '__main__').
if __name__ == '__main__':
    main()