In [1]:
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time

In [2]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


In [3]:
import json
import urllib.request
import os

# Fetch the dataset archive and the caption vocabulary unless they are already
# present in the working directory.
data_urls = {
    'imgcap.npz': 'https://github.com/nongaussian/class-2026-lginnotek-llm/raw/refs/heads/main/imgcap/imgcap.npz',
    'imgcap_output_vocab.json': 'https://raw.githubusercontent.com/nongaussian/class-2026-lginnotek-llm/refs/heads/main/imgcap/imgcap_output_vocab.json'
}

for filename, url in data_urls.items():
    if not os.path.exists(filename):
        print(f'Downloading {filename}...')
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated file that the
        # existence check above would accept on the next run.
        partial_path = filename + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filename)
        finally:
            # Only still present if the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f'{filename} downloaded.')
    else:
        print(f'{filename} already exists.')

# Read both arrays inside a context manager so the archive's file handle is
# closed as soon as the data has been loaded into memory.
with np.load('imgcap.npz') as npzfile:
    input_imgs = npzfile['input_imgs']
    output_seqs = npzfile['output_seqs']

# Caption vocabulary: token -> integer id. json.load accepts a binary file
# and detects the text encoding itself.
with open("imgcap_output_vocab.json", "rb") as f:
    output_vocab = json.load(f)

y_vocab_size = len(output_vocab)
# Reverse mapping (id -> token), used to turn predicted ids back into text.
inverse_output_vocab = {index: token for token, index in output_vocab.items()}

# PyTorch uses (N, C, H, W) format, so insert a singleton channel axis.
# Only an expand is performed here; no transpose is needed.
input_imgs = np.expand_dims(input_imgs, axis=1)  # (N, 1, H, W)

Downloading imgcap.npz...
imgcap.npz downloaded.
Downloading imgcap_output_vocab.json...
imgcap_output_vocab.json downloaded.


In [None]:
# Shape of the image array after the channel axis was inserted: (N, 1, H, W).
input_imgs.shape

In [None]:
# Shape of the caption id array; later cells treat it as (sample, position) — TODO confirm layout.
output_seqs.shape

### Check image and caption

In [None]:
# Display sample 3931 (channel 0) as a grayscale image.
plt.imshow(input_imgs[3931, 0, :, :], cmap='gray')

In [None]:
# Raw token ids of the caption stored at the same index (3931) as the image displayed earlier.
output_seqs[3931]

In [None]:
def text_decoding(line, invvocab):
    """Map a sequence of token ids to their tokens.

    Args:
        line: iterable of token ids.
        invvocab: mapping from token id to token.

    Returns:
        A list with the looked-up token for every id, in input order.

    Raises:
        Whatever ``invvocab`` raises for a missing id (``KeyError`` for a dict).
    """
    tokens = []
    for token_id in line:
        tokens.append(invvocab[token_id])
    return tokens

In [None]:
# The caption of sample 3931 with its ids mapped back to tokens.
print(text_decoding(output_seqs[3931], inverse_output_vocab))

### Building models

In [None]:
# Mini-batch size handed to the DataLoader.
BATCH_SIZE = 16
# Width of the decoder's token embedding vectors.
embedding_dim = 1024
# Size of the encoder output, which is also the decoder GRU's hidden size.
latent_dim = 100
# Number of caption tokens; same value as y_vocab_size.
# NOTE(review): not referenced by any other visible cell — confirm whether it is still needed.
output_vocab_size = len(output_vocab)

In [None]:
class ImageCaptionDataset(Dataset):
    """Pairs each image with the caption stored at the same index.

    Args:
        images: array-like of images; converted to a float32 tensor.
        captions: array-like of integer token ids; converted to an int64
            tensor.

    Raises:
        ValueError: if ``images`` and ``captions`` differ in length.
    """

    def __init__(self, images, captions):
        # torch.as_tensor with an explicit dtype replaces the legacy typed
        # constructors (FloatTensor / LongTensor). It may share memory with
        # the input when no dtype conversion is required.
        self.images = torch.as_tensor(images, dtype=torch.float32)
        self.captions = torch.as_tensor(captions, dtype=torch.int64)

        # Fail early: a length mismatch would otherwise only surface as an
        # index error part-way through an epoch.
        if len(self.images) != len(self.captions):
            raise ValueError(
                'images and captions must have the same length, got '
                f'{len(self.images)} and {len(self.captions)}')

    def __len__(self):
        """Number of (image, caption) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, caption)`` pair at ``idx``."""
        return self.images[idx], self.captions[idx]

# Wrap the arrays in a Dataset and serve them as shuffled mini-batches.
dataset = ImageCaptionDataset(images=input_imgs, captions=output_seqs)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Pull one batch from the loader to use as a running example, and show
# the image at position 5 of that batch.
example_input_batch, example_target_batch = next(iter(dataloader))
plt.imshow(example_input_batch[5, 0].numpy(), cmap='gray');

In [None]:
# Compact the batch: count the non-zero (non-padding) ids in each caption and
# keep only as many leading columns as the longest caption needs.
# NOTE(review): this assumes padding zeros appear only at the end of a row.
nonpad_counts = (example_target_batch != 0).sum(dim=1)
max_len = nonpad_counts.max().item()
example_target_batch = example_target_batch[:, :max_len]
example_target_batch.shape

In [None]:
class Encoder(nn.Module):
    """LeNet-like CNN that maps a single-channel image to a latent vector.

    ``fc1`` expects 16 feature maps of 5x5, which is what the two conv/pool
    stages produce for e.g. 28x28 inputs.

    Args:
        latent_dim: size of the output vector.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5, stride=1, padding=2)  # keeps H, W
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                               kernel_size=5, stride=1, padding=0)  # shrinks H, W by 4
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Fully connected head projecting down to the latent size.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=latent_dim)

    def forward(self, x):
        """Encode a batch ``x`` of shape (batch, 1, H, W) into (batch, latent_dim)."""
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        hidden = self.flatten(features)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        # The final projection is left linear (no activation).
        return self.fc3(hidden)

In [None]:
# Instantiate the encoder and move its parameters to the selected device.
encoder = Encoder(latent_dim).to(device)

In [None]:
# Inspect the shape of the example image batch.
example_input_batch.shape

In [None]:
# Sanity check: encode the example batch. The encoder's last layer has
# latent_dim outputs, so the printed shape should be (batch, latent_dim).
last_state = encoder(example_input_batch.to(device))
print(last_state.shape)

In [None]:
class Decoder(nn.Module):
    """Single-layer GRU decoder that predicts token logits for every input position.

    Args:
        vocab_size: number of tokens; id 0 is the padding token of the embedding.
        embedding_dim: size of the token embedding vectors.
        latent_dim: GRU hidden size (must match the initial state's last dimension).
    """

    def __init__(self, vocab_size, embedding_dim, latent_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim,
                                      padding_idx=0)
        self.gru = nn.GRU(input_size=embedding_dim,
                          hidden_size=latent_dim,
                          batch_first=True)
        self.fc = nn.Linear(in_features=latent_dim, out_features=vocab_size)

    def forward(self, target_seq, initial_state):
        """Run the decoder over ``target_seq`` starting from ``initial_state``.

        Args:
            target_seq: token ids, shape (batch_size, seq_len).
            initial_state: initial hidden state, shape (batch_size, latent_dim).

        Returns:
            Tuple ``(logits, hidden)`` with shapes
            (batch_size, seq_len, vocab_size) and (batch_size, latent_dim).
        """
        # nn.GRU wants its hidden state as (num_layers, batch_size, latent_dim).
        h0 = torch.unsqueeze(initial_state, 0)

        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        token_vectors = self.embedding(target_seq)

        # step_outputs: (batch_size, seq_len, latent_dim); h_n: (1, batch_size, latent_dim)
        step_outputs, h_n = self.gru(token_vectors, h0)

        # Project every time step onto the vocabulary.
        logits = self.fc(step_outputs)

        # Drop the layer axis again so the state can be fed straight back in.
        return logits, torch.squeeze(h_n, 0)

In [None]:
# Sanity check: build a decoder and run it over the example captions, feeding
# the ground-truth tokens as inputs and the encoder output as the GRU's
# initial hidden state.
decoder = Decoder(y_vocab_size, embedding_dim, latent_dim).to(device)
logits, s1 = decoder(example_target_batch.to(device), last_state)

In [None]:
# Per the decoder's forward pass: logits is (batch, seq_len, vocab_size) and s1 is (batch, latent_dim).
print(logits.shape, s1.shape)

### Loss function

In [None]:
def batch_loss(y_true, y_pred, pad_id=0):
    """Mean cross-entropy over the non-padding target tokens of a batch.

    Args:
        y_true: target token ids, shape (batch_size, seq_len).
        y_pred: unnormalised logits, shape (batch_size, seq_len, vocab_size).
        pad_id: token id that marks padding; those positions are excluded
            from both the loss sum and the token count. Defaults to 0.

    Returns:
        Scalar tensor: summed cross-entropy of the non-padding positions
        divided by their count. A batch made up entirely of padding yields
        0 rather than NaN.
    """
    # reshape (not view) because callers pass sliced, non-contiguous logits.
    flat_logits = y_pred.reshape(-1, y_pred.size(-1))
    flat_targets = y_true.reshape(-1)

    # ignore_index makes padded positions contribute exactly zero to the sum.
    total = F.cross_entropy(flat_logits, flat_targets,
                            ignore_index=pad_id, reduction='sum')

    # Clamp the denominator so an all-padding batch does not divide by zero.
    n_tokens = (y_true != pad_id).sum().clamp(min=1)

    return total / n_tokens

In [None]:
# Next-token objective: the logits at position t are scored against the token
# at position t+1, so the first target and the last prediction are dropped.
batch_loss(example_target_batch[:, 1:].to(device), logits[:, :-1, :])

### Training

In [None]:
# Start from freshly initialised weights so the sanity-check calls made
# earlier do not influence training.
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(y_vocab_size, embedding_dim, latent_dim).to(device)

# One Adam optimizer updates the encoder and decoder parameters together.
params = [*encoder.parameters(), *decoder.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-4)

In [None]:
def predict(x_batch, y_batch, training=True):
    """Compute decoder logits for ``y_batch`` conditioned on the images in ``x_batch``.

    The module-level ``encoder`` output is used as the initial hidden state
    of the module-level ``decoder``.

    Args:
        x_batch: image batch passed to the encoder.
        y_batch: token id batch passed to the decoder as its input sequence.
        training: accepted for interface compatibility; this function does
            not read it.

    Returns:
        The decoder logits; the decoder's final hidden state is discarded.
    """
    logits, _final_state = decoder(y_batch, encoder(x_batch))
    return logits

In [None]:
def train_step(x_batch, y_batch):
    """Run one optimisation step on a single batch and return its loss.

    Uses the module-level ``optimizer``, ``predict`` and ``batch_loss``.

    Args:
        x_batch: image batch.
        y_batch: caption token ids, shape (batch_size, seq_len), with 0 as padding.

    Returns:
        The batch loss as a Python float.
    """
    # Trim columns beyond the longest caption (by non-zero id count) in the batch.
    longest = (y_batch != 0).sum(dim=1).max().item()
    targets = y_batch[:, :longest]

    optimizer.zero_grad()

    # Forward pass through encoder and decoder.
    logits = predict(x_batch, targets)

    # Position t predicts token t+1: drop the first target and the last prediction.
    step_loss = batch_loss(targets[:, 1:], logits[:, :-1, :])

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

In [None]:
# Number of full passes over the training data.
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
    start = time.time()

    encoder.train()
    decoder.train()

    # Accumulate the per-batch losses so the epoch average can be reported.
    loss_sum = 0.0
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        loss = train_step(x_batch, y_batch)
        loss_sum += loss

    # Report the mean over all batches of the epoch; the loss of the last
    # batch alone is a noisy indicator of progress. max(..., 1) guards
    # against an empty loader.
    mean_loss = loss_sum / max(len(dataloader), 1)

    print('Time for epoch {} is {:.2f} sec: training loss = {:.4f}'.format(
        epoch + 1, time.time() - start, mean_loss))

### Test caption generation (beam search)

In [None]:
def beam_translate(x_test, max_steps=100, k=16):
    """Generate captions for a single image with beam search.

    Uses the module-level ``encoder``, ``decoder``, ``output_vocab``,
    ``inverse_output_vocab`` and ``device``; both models are switched to
    eval mode.

    Args:
        x_test: image batch containing exactly one image (the decoder is fed
            one token at a time with batch size 1), on the models' device.
        max_steps: maximum number of decoding steps.
        k: beam width. Capped at the vocabulary size when expanding a beam.

    Returns:
        A list of ``(score, text)`` tuples sorted by descending score.
        ``score`` is the sum of the token log-probabilities and ``text`` is
        the space-joined tokens, starting with ``<bos>``. The list holds at
        most ``k`` entries, and fewer when fewer candidates exist (for
        example when ``max_steps`` is 0).
    """
    encoder.eval()
    decoder.eval()

    # Look the special tokens up once instead of inside the search loops.
    bos_token = output_vocab['<bos>']
    eos_token = output_vocab['<eos>']

    with torch.no_grad():
        # The encoder output is the initial hidden state of the decoder.
        s1 = encoder(x_test)

        # A candidate is (score, last token, hidden state, token ids, finished).
        last_token = torch.tensor([[bos_token]], device=device)
        candidates = [(0., last_token, s1, [bos_token], False)]

        for _ in range(max_steps):
            # Once every beam has emitted <eos>, further steps change nothing.
            if all(finished for _, _, _, _, finished in candidates):
                break

            new_candidates = []

            for score, token, hidden, output_seq, finished in candidates:
                # Finished candidates are carried over unchanged.
                if finished:
                    new_candidates.append(
                        (score, token, hidden, output_seq, finished))
                    continue

                # logits has shape (1, 1, vocab_size).
                logits, next_hidden = decoder(token, hidden)
                log_probs = F.log_softmax(logits, dim=2).reshape(-1)

                # torch.topk raises if asked for more entries than exist.
                n_expand = min(k, log_probs.numel())
                values, indices = torch.topk(log_probs, k=n_expand)

                for log_prob, idx_int in zip(values.tolist(), indices.tolist()):
                    next_token = torch.tensor([[idx_int]], device=device)
                    new_candidates.append(
                        (score + log_prob, next_token, next_hidden,
                         output_seq + [idx_int], idx_int == eos_token))

            # Keep only the k best-scoring candidates.
            candidates = sorted(new_candidates, key=lambda t: -t[0])[:k]

        # Iterate over the candidates that actually exist rather than
        # range(k), which would raise IndexError when fewer than k survive.
        return [(score, ' '.join(inverse_output_vocab[x] for x in output_seq))
                for score, _, _, output_seq, _ in candidates]

In [None]:
# Draw a new batch and caption one of its images.
# NOTE(review): the batch comes from the same loader that was used for
# training, so this is not a held-out example.
example_input_batch, example_target_batch = next(iter(dataloader))
example_input_batch = example_input_batch.to(device)

# Slice with 1:2 (not [1]) so the batch dimension is kept.
res = beam_translate(example_input_batch[1:2])
res[0]

In [None]:
# Show the image (index 1 of the batch) that was passed to beam_translate; .cpu() is needed before converting to NumPy when the batch lives on the GPU.
plt.imshow(example_input_batch[1, 0].cpu().numpy(), cmap='gray')