In [1]:
import os
import pandas as pd
import numpy as np
import editdistance
import time 


from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

import timm  

from sklearn.model_selection import train_test_split

from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_dir = os.getcwd()

# Same paths as your original code
ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt') 
images_dir        = os.path.join(base_dir, 'balinese_word_train')

filenames = []
labels    = []

with open(ground_truth_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:  # Ensure the line is not empty
            parts = line.split(';')
            if len(parts) == 2:
                filename, label = parts
                label = label.lower()
                filenames.append(filename)
                labels.append(label)
            else:
                print(f"Skipping malformed line: {line}")

data = pd.DataFrame({
    'filename': filenames,
    'label': labels
})

label_counts = data['label'].value_counts()

all_text = ''.join(data['label'])
unique_chars = sorted(list(set(all_text)))

# Create character->index starting from 1
char_to_idx = {char: idx + 1 for idx, char in enumerate(unique_chars)}
# Add special tokens
char_to_idx['<PAD>'] = 0
char_to_idx['<UNK>'] = len(char_to_idx)
char_to_idx['<SOS>'] = len(char_to_idx)
char_to_idx['<EOS>'] = len(char_to_idx)

# Reverse mapping
idx_to_char = {v: k for k, v in char_to_idx.items()}

vocab_size = len(char_to_idx)
print(f"Vocabulary size: {vocab_size}")

def encode_label(label, char_to_idx, max_length):
    """
    Converts a label (string) into a list of indices with <SOS>, <EOS>, padding, etc.
    """
    encoded = (
        [char_to_idx['<SOS>']] +
        [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
        [char_to_idx['<EOS>']]
    )
    # Pad if needed
    if len(encoded) < max_length:
        encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
    else:
        encoded = encoded[:max_length]
    return encoded

max_label_length = max(len(label) for label in data['label']) + 2  # +2 for <SOS> and <EOS>
data['encoded_label'] = data['label'].apply(lambda x: encode_label(x, char_to_idx, max_label_length))
data['label_length']  = data['label'].apply(len)

rare_labels = label_counts[label_counts < 3].index  # NEW: words that appear <3 times

def custom_split(df, rare_label_list, test_size=0.1, random_state=42):
    # Separate rare words from frequent ones
    rare_df     = df[df['label'].isin(rare_label_list)]
    non_rare_df = df[~df['label'].isin(rare_label_list)]

    #  train/val split for non-rare
    train_nr, val_nr = train_test_split(non_rare_df, test_size=test_size, 
                                        random_state=random_state)

    # Combine rare samples entirely into training
    train_df = pd.concat([train_nr, rare_df], ignore_index=True)
    # Shuffle after combining
    train_df = train_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    val_df = val_nr.reset_index(drop=True)
    return train_df, val_df

# Call custom_split instead of direct train_test_split
train_data, val_data = custom_split(data, rare_labels, test_size=0.1, random_state=42) 

print(f"Training size: {len(train_data)}; Validation size: {len(val_data)}")

Vocabulary size: 39
Training size: 13972; Validation size: 1050


In [3]:
class BalineseDataset(Dataset):
    def __init__(self, df, images_dir, transform=None):
        self.data       = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.transform  = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name     = self.data.loc[idx, 'filename']
        label        = self.data.loc[idx, 'encoded_label']
        label_length = self.data.loc[idx, 'label_length']

        img_path = os.path.join(self.images_dir, img_name)
        image    = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label, torch.tensor(label_length, dtype=torch.long)

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)
    )
])

train_dataset = BalineseDataset(train_data, images_dir, transform=transform)
val_dataset   = BalineseDataset(val_data,   images_dir, transform=transform)

batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [5]:
class ResNet18Encoder(nn.Module):
    """
    Encoder that uses a pretrained ResNet18 to extract features of shape 
    [B, H*W, C], which the DecoderRNN can then attend over.
    """

    def __init__(self, pretrained=True):
        super(ResNet18Encoder, self).__init__()
        resnet = models.resnet18(pretrained=pretrained)

        # Remove the classification (fc) layer
        modules = list(resnet.children())[:-2]  # remove the avgpool & fc
        self.cnn = nn.Sequential(*modules)

        # last convolutional block outputs 512 channels
        self.encoder_dim = 512

    def forward(self, x):
        """
        Input shape:  x -> [batch_size, 3, 224, 224]
        Output shape: -> [batch_size, num_patches, encoder_dim]
                       where num_patches = H' * W' from the final feature map
        """
        # pass through ResNet (up to layer4)
        features = self.cnn(x)  # [B, 512, H', W']

        # Flatten the spatial dims
        # shape => [B, 512, H', W'] -> [B, H'*W', 512]
        b, c, h, w = features.shape
        features = features.permute(0, 2, 3, 1)   # [B, H', W', C]
        features = features.reshape(b, -1, c)     # [B, H'*W', C]

        return features

In [6]:
class ResNet50Encoder(nn.Module):
    """
    Encoder that uses a pretrained ResNet50 to extract features of shape 
    [B, H*W, C]. For ResNet50, the final block has 2048 output channels.
    """

    def __init__(self, pretrained=True):
        super(ResNet50Encoder, self).__init__()
        resnet = models.resnet50(pretrained=pretrained)

        # remove the avgpool & fc layers
        modules = list(resnet.children())[:-2]  
        self.cnn = nn.Sequential(*modules)

        # For ResNet50, the last block has 2048 output channels
        self.encoder_dim = 2048

    def forward(self, x):
        # x -> [B, 3, 224, 224]
        features = self.cnn(x)  # [B, 2048, H', W']

        b, c, h, w = features.shape
        features = features.permute(0, 2, 3, 1)   # [B, H', W', C]
        features = features.reshape(b, -1, c)     # [B, H'*W', C]

        return features

In [7]:
class ViTEncoder(nn.Module):
    """
    A simple ViT encoder that extracts patch embeddings as [batch_size, num_patches, hidden_dim].
    We'll use timm to load a pretrained ViT. Then we use .forward_features() to get a
    feature map of shape [B, C, H', W'] for many timm ViT models, which we flatten.
    """
    def __init__(self, model_name="vit_large_patch16_224", pretrained=True):
        super(ViTEncoder, self).__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        # Remove or replace the classification head
        self.vit.head = nn.Identity()

        # timm's ViT typically has an embed_dim attribute
        self.encoder_dim = self.vit.embed_dim

    def forward(self, x):
        """
        :param x: [batch_size, 3, 224, 224]
        :return:  [batch_size, num_patches, encoder_dim]
        """
        # forward_features usually returns [B, hidden_dim, H', W'] or [B, hidden_dim]
        feats = self.vit.forward_features(x)  # [B, hidden_dim, 14, 14] for vit_base_patch16_224

        # Flatten the spatial dimensions
        if feats.dim() == 4:  # [B, C, H, W]
            b, c, h, w = feats.shape
            feats = feats.permute(0, 2, 3, 1).reshape(b, -1, c)  # => [B, H*W, C]

        return feats

In [8]:
class SwinEncoder(nn.Module):
    def __init__(self, model_name="swin_small_patch4_window7_224", pretrained=True):
        """
        A simple Swin Transformer encoder that extracts patch embeddings
        as [batch_size, num_patches, hidden_dim]. We'll use timm to load 
        a pretrained Swin model, remove its classification head, then flatten.
        """
        super().__init__()
        self.swin = timm.create_model(model_name, pretrained=pretrained)
        self.swin.head = nn.Identity()

        # We'll assign encoder_dim dynamically after forward
        self.encoder_dim = None

    def forward(self, x):
        """
        :param x: [batch_size, 3, 224, 224]
        :return:  [batch_size, num_patches, encoder_dim]
        """
        feats = self.swin.forward_features(x)            # [B, C, H, W]
        b, c, h, w = feats.shape
        feats = feats.flatten(2).transpose(1, 2)         # [B, H*W, C]
        # Set encoder_dim once (C)
        if self.encoder_dim is None:
            self.encoder_dim = feats.shape[-1]
        return feats




In [9]:
class HybridEncoder(nn.Module):
    def __init__(self, cnn_encoder, vit_encoder):
        super(HybridEncoder, self).__init__()
        self.cnn_encoder = cnn_encoder
        self.vit_encoder = vit_encoder
        # Combined encoder_dim is the sum of both encoder dimensions.
        # (CNN outputs 512 channels; ViT outputs its own embed_dim.)
        self.encoder_dim = cnn_encoder.encoder_dim + vit_encoder.encoder_dim

    def forward(self, x):
        # Get CNN features: expected shape [B, 49, 512]
        cnn_features = self.cnn_encoder(x)

        # Get ViT features: expected shape [B, 197, vit_dim] for vit_large_patch16_224
        vit_features = self.vit_encoder(x)
        # If the ViT output contains a class token, remove it.
        if vit_features.shape[1] == 197:
            vit_features = vit_features[:, 1:, :]  # Now shape: [B, 196, vit_dim]
            B, tokens, D = vit_features.shape
            # Reshape tokens into a 14x14 grid: [B, 14, 14, D]
            vit_features = vit_features.reshape(B, 14, 14, D)
            # Permute to [B, D, 14, 14] for pooling
            vit_features = vit_features.permute(0, 3, 1, 2)
            # Use adaptive pooling to reduce to a 7x7 grid
            vit_features = F.adaptive_avg_pool2d(vit_features, (7, 7))
            # Permute back to [B, 7, 7, D] and flatten to [B, 49, D]
            vit_features = vit_features.permute(0, 2, 3, 1).reshape(B, -1, D)

        # Concatenate the features along the feature dimension (dim=2)
        # cnn_features: [B, 49, 512] and vit_features: [B, 49, vit_dim]
        hybrid_features = torch.cat([cnn_features, vit_features], dim=2)
        return hybrid_features

In [10]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # transform encoder output
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # transform decoder hidden
        self.full_att    = nn.Linear(attention_dim, 1)
        self.relu        = nn.ReLU()
        self.softmax     = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out:    [batch_size, num_patches, encoder_dim]
        decoder_hidden: [batch_size, decoder_dim]
        """
        att1 = self.encoder_att(encoder_out)                  # [batch_size, num_patches, attention_dim]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)  # [batch_size, 1, attention_dim]

        # sum -> relu -> full_att -> squeeze -> softmax
        att  = self.full_att(self.relu(att1 + att2)).squeeze(2)  # [batch_size, num_patches]
        alpha = self.softmax(att)
        # Weighted sum of the encoder_out
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [batch_size, encoder_dim]
        return attention_weighted_encoding, alpha

class DecoderRNN(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=768, teacher_forcing_ratio=0.5):
        super(DecoderRNN, self).__init__()

        self.attention     = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding     = nn.Embedding(vocab_size, embed_dim)
        self.dropout       = nn.Dropout(p=0.5)

        # [embed_dim + encoder_dim] -> decoder_dim
        self.lstm1 = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        # decoder_dim -> decoder_dim
        self.lstm2 = nn.LSTMCell(decoder_dim, decoder_dim)

        # For initializing the hidden states of both LSTM layers
        self.init_h1 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c1 = nn.Linear(encoder_dim, decoder_dim)
        self.init_h2 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c2 = nn.Linear(encoder_dim, decoder_dim)

        # Gating
        self.f_beta  = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        # Final linear layer for output vocab
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        # encoder_out: [batch_size, num_patches, encoder_dim]
        mean_encoder_out = encoder_out.mean(dim=1)  # [batch_size, encoder_dim]
        h1 = self.init_h1(mean_encoder_out)         # [batch_size, decoder_dim]
        c1 = self.init_c1(mean_encoder_out)         # [batch_size, decoder_dim]
        h2 = self.init_h2(mean_encoder_out)         # [batch_size, decoder_dim]
        c2 = self.init_c2(mean_encoder_out)         # [batch_size, decoder_dim]
        return (h1, c1, h2, c2)

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        encoder_out:      [batch_size, num_patches, encoder_dim]
        encoded_captions: [batch_size, max_label_length]
        caption_lengths:  [batch_size, 1]
        """
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out      = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)

        # Initialize hidden states for both LSTM layers
        h1, c1, h2, c2 = self.init_hidden_state(encoder_out)

        decode_lengths    = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        predictions = torch.zeros(batch_size, max_decode_length, vocab_size, device=encoder_out.device)
        alphas      = torch.zeros(batch_size, max_decode_length, encoder_out.size(1), device=encoder_out.device)

        # We'll feed the first token from the input (<SOS>) or from the previous prediction
        prev_tokens = encoded_captions[:, 0].clone()

        for t in range(max_decode_length):
            batch_size_t = sum([l > t for l in decode_lengths])

            attention_weighted_encoding, alpha = self.attention(
                encoder_out[:batch_size_t],
                h1[:batch_size_t]  # use the first LSTM layer's hidden state for attention
            )

            # Apply gating
            gate = self.sigmoid(self.f_beta(h1[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding

            # Teacher forcing?
            use_teacher_forcing = (torch.rand(1).item() < self.teacher_forcing_ratio)
            if use_teacher_forcing:
                current_input = embeddings[:batch_size_t, t, :]
            else:
                current_input = self.embedding(prev_tokens[:batch_size_t].detach())

            # first lstm layer
            h1_next, c1_next = self.lstm1(
                torch.cat([current_input, attention_weighted_encoding], dim=1),
                (h1[:batch_size_t], c1[:batch_size_t])
            )

            # second lstm layer
            h2_next, c2_next = self.lstm2(
                h1_next, (h2[:batch_size_t], c2[:batch_size_t])
            )

            # Use the second LSTM layer's output (h2_next) for final prediction
            preds = self.fc(self.dropout(h2_next))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :]      = alpha

            # Update prev_tokens with the best predicted token
            _, next_tokens = preds.max(dim=1)
            prev_tokens_ = prev_tokens.clone()
            prev_tokens_[:batch_size_t] = next_tokens.detach()
            prev_tokens = prev_tokens_

            # Update hidden states
            # For samples still in the batch, store the new h1, c1, h2, c2
            h1_new = torch.zeros_like(h1)
            c1_new = torch.zeros_like(c1)
            h2_new = torch.zeros_like(h2)
            c2_new = torch.zeros_like(c2)

            h1_new[:batch_size_t] = h1_next
            c1_new[:batch_size_t] = c1_next
            h2_new[:batch_size_t] = h2_next
            c2_new[:batch_size_t] = c2_next

            h1, c1, h2, c2 = h1_new, c1_new, h2_new, c2_new

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [11]:
class ImageCaptioningTrainer:
    def __init__(self, encoder, decoder, 
                 criterion, encoder_optimizer, decoder_optimizer, 
                 train_loader, val_loader, device, 
                 char_to_idx, idx_to_char, max_label_length,
                 model_name, csv_filename="training_results.csv"):
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.criterion = criterion
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        self.max_label_length = max_label_length
        self.model_name = model_name
        self.csv_filename = csv_filename

        self.train_losses = []
        self.val_losses = []

        self.train_cers = []
        self.val_cers = []

    def fit(self, num_epochs):
        import time
        import os
        import pandas as pd
        
        start_time = time.time()
        
        # Early stopping parameters
        best_val_loss = float('inf')
        early_stop_counter = 0
        early_stop_patience = 5  # stop if no improvement for 3 consecutive epochs

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_cer = self.train_one_epoch()
            val_loss, val_cer = self.validate_one_epoch(top_n=5)

            print(f"[{epoch+1}/{num_epochs}] "
                  f"Train Loss: {train_loss:.4f}, Train CER: {train_cer:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val CER: {val_cer:.4f}")

            # Store epoch results
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_cers.append(train_cer)
            self.val_cers.append(val_cer)
            
            # Early stopping check: if current val_loss is better, reset counter; otherwise, increment.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                print(f"Validation loss did not improve. Early stop counter: {early_stop_counter}/{early_stop_patience}")
            
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break
        
        # Calculate total training time
        end_time = time.time() 
        total_time = end_time - start_time
        hours = int(total_time // 3600)
        minutes = int((total_time % 3600) // 60)

        print(f"\nTraining completed in {hours}h {minutes}m.")

        num_epochs_recorded = len(self.train_losses)
        epoch_cols = [f"epoch{i+1}" for i in range(num_epochs_recorded)]

        # Create the new data block to insert
        new_rows = pd.DataFrame([
            [self.model_name, "training loss"] + self.train_losses,
            [self.model_name, "validation loss"] + self.val_losses,
            [self.model_name, "training cer"] + self.train_cers,
            [self.model_name, "validation cer"] + self.val_cers
        ], columns=["model_name", "mode"] + epoch_cols)
        
        # Check if CSV already exists
        if os.path.exists(self.csv_filename):
            df_existing = pd.read_csv(self.csv_filename)
            df_existing = df_existing[df_existing["model_name"] != self.model_name]
            df_updated = pd.concat([df_existing, new_rows], ignore_index=True)
        else:
            df_updated = new_rows
        
        # Save the updated CSV
        df_updated.to_csv(self.csv_filename, index=False)
        print(f"\nResults have been written to: {self.csv_filename}")

        # Save model weights
        # torch.save(self.encoder.state_dict(), f"encoder_{self.model_name}.pth")
        # torch.save(self.decoder.state_dict(), f"decoder_{self.model_name}.pth")
        # print(f"Encoder and decoder models saved: encoder_{self.model_name}.pth, decoder_{self.model_name}.pth")

    def train_one_epoch(self):
        self.encoder.train()
        self.decoder.train()
        running_loss           = 0.0
        total_edit_distance    = 0
        total_ref_length       = 0

        for batch_idx, (images, labels, label_lengths) in enumerate(self.train_loader):
            images        = images.to(self.device, non_blocking=True)
            labels        = labels.to(self.device, non_blocking=True)
            label_lengths = label_lengths.to(self.device, non_blocking=True)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            encoder_out   = self.encoder(images)
            caption_lengths = torch.tensor(
                [self.max_label_length] * labels.size(0)
            ).unsqueeze(1).to(self.device)

            outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                encoder_out, labels, caption_lengths
            )

            # Targets = encoded captions without the <SOS>
            targets = encoded_captions[:, 1:]

            # Flatten for loss
            outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
            targets_flat = targets.contiguous().view(-1)

            loss = self.criterion(outputs_flat, targets_flat)
            loss.backward()

            self.decoder_optimizer.step()
            self.encoder_optimizer.step()

            running_loss += loss.item()

            # Compute CER for the batch (global style)
            batch_size = labels.size(0)
            _, preds_flat = torch.max(outputs_flat, dim=1)
            preds_seq = preds_flat.view(batch_size, -1)

            for i in range(batch_size):
                pred_indices   = preds_seq[i].detach().cpu().numpy()
                target_indices = targets[i].detach().cpu().numpy()

                mask          = (target_indices != self.char_to_idx['<PAD>'])
                pred_indices  = pred_indices[mask]
                target_indices= target_indices[mask]

                pred_chars    = [self.idx_to_char.get(idx, '') for idx in pred_indices]
                target_chars  = [self.idx_to_char.get(idx, '') for idx in target_indices]
                pred_str      = ''.join(pred_chars)
                target_str    = ''.join(target_chars)

                edit_dist           = editdistance.eval(pred_str, target_str)
                total_edit_distance += edit_dist
                total_ref_length    += len(target_str)

            # if (batch_idx + 1) % 50 == 0:
            #     print(f'Batch {batch_idx + 1}/{len(self.train_loader)} - Loss: {loss.item():.4f}')

        avg_loss = running_loss / len(self.train_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0
        return avg_loss, avg_cer

    def validate_one_epoch(self, top_n=5):
        self.encoder.eval()
        self.decoder.eval()
        running_loss         = 0.0
        total_edit_distance  = 0
        total_ref_length     = 0

        # each sample’s CER
        sample_cer_info = []

        with torch.no_grad():
            for batch_idx, (images, labels, label_lengths) in enumerate(self.val_loader):
                images        = images.to(self.device, non_blocking=True)
                labels        = labels.to(self.device, non_blocking=True)
                label_lengths = label_lengths.to(self.device, non_blocking=True)

                encoder_out = self.encoder(images)
                caption_lengths = torch.tensor(
                    [self.max_label_length] * labels.size(0)
                ).unsqueeze(1).to(self.device)

                outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                    encoder_out, labels, caption_lengths
                )
                targets = encoded_captions[:, 1:]

                outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
                targets_flat = targets.contiguous().view(-1)

                loss = self.criterion(outputs_flat, targets_flat)
                running_loss += loss.item()

                batch_size = labels.size(0)
                _, preds_flat = torch.max(outputs_flat, dim=1)
                preds_seq = preds_flat.view(batch_size, -1)

                for i in range(batch_size):
                    pred_indices   = preds_seq[i].detach().cpu().numpy()
                    target_indices = targets[i].detach().cpu().numpy()

                    mask           = (target_indices != self.char_to_idx['<PAD>'])
                    pred_indices   = pred_indices[mask]
                    target_indices = target_indices[mask]

                    pred_chars   = [self.idx_to_char.get(idx, '') for idx in pred_indices]
                    target_chars = [self.idx_to_char.get(idx, '') for idx in target_indices]
                    pred_str     = ''.join(pred_chars)
                    target_str   = ''.join(target_chars)

                    edit_dist = editdistance.eval(pred_str, target_str)
                    ref_len   = len(target_str)
                    cer       = edit_dist / ref_len if ref_len > 0 else 0
    
                    total_edit_distance += edit_dist
                    total_ref_length    += ref_len
    
                    # Store sample info
                    # sample_cer_info.append({
                    #     "pred": pred_str,
                    #     "gt": target_str,
                    #     "cer": cer
                    # })

                    # Print a few samples from the 1st batch
                    # if batch_idx == 0 and i < 3:
                    #     print(f"Sample {i + 1}:")
                    #     print(f"Predicted: {pred_str}")
                    #     print(f"Target   : {target_str}\n")

        avg_loss = running_loss / len(self.val_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0

        # Sort by CER descending
        sample_cer_info.sort(key=lambda x: x["cer"], reverse=True)
        # Take top_n
        worst_samples = sample_cer_info[:top_n]
    
        # print(f"\n=== Top {top_n} Worst Samples by CER ===")
        # for idx, sample in enumerate(worst_samples):
        #     print(f"[{idx+1}] CER: {sample['cer']:.3f}")
        #     print(f"   Predicted: {sample['pred']}")
        #     print(f"   Ground Truth: {sample['gt']}\n")
       
        return avg_loss, avg_cer


In [12]:
test_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_test.txt')
test_images_dir        = os.path.join(base_dir, 'balinese_word_test')

test_filenames = []
test_labels    = []

with open(test_ground_truth_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:
            parts = line.split(';')
            if len(parts) == 2:
                filename, label = parts
                label = label.lower()
                test_filenames.append(filename)
                test_labels.append(label)
            else:
                print(f"Skipping malformed line: {line}")

test_data = pd.DataFrame({
    'filename': test_filenames,
    'label': test_labels
})

# Check for unknown chars in test set
test_chars = set(''.join(test_data['label']))
unknown_chars = test_chars - set(char_to_idx.keys())
print(f"Unknown characters in test labels: {unknown_chars}")

# Encode test labels
max_label_length_test = max(len(lbl) for lbl in test_data['label']) + 2
def encode_label_test(label, char_to_idx, max_length):
    encoded = (
        [char_to_idx['<SOS>']] +
        [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
        [char_to_idx['<EOS>']]
    )
    if len(encoded) > max_length:
        encoded = encoded[:max_length]
    else:
        encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
    return encoded

test_data['encoded_label'] = test_data['label'].apply(lambda x: encode_label_test(x, char_to_idx, max_label_length_test))
test_data['label_length']  = test_data['label'].apply(len)

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)
    )
])

test_dataset = BalineseDataset(test_data, test_images_dir, transform=test_transform)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

Unknown characters in test labels: set()


In [13]:
def inference(encoder, decoder, data_loader, device, char_to_idx, idx_to_char, max_seq_length, test_data):
    encoder.eval()
    decoder.eval()

    eos_idx = char_to_idx['<EOS>']
    results = []

    with torch.no_grad():
        for batch_idx, (images, labels, label_lengths) in enumerate(data_loader):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            batch_size  = images.size(0)
            encoder_out = encoder(images)  # [B, num_patches, encoder_dim]

            # Init LSTM state
            h1, c1, h2, c2 = decoder.init_hidden_state(encoder_out)

            # Start tokens (all <SOS>)
            inputs = torch.full(
                (batch_size,),
                fill_value=char_to_idx['<SOS>'],
                dtype=torch.long,
                device=device
            )

            all_preds = []

            for _ in range(max_seq_length):
                # Embedding
                embeddings = decoder.embedding(inputs)

                # Attention
                attention_weighted_encoding, _ = decoder.attention(encoder_out, h1)

                # Gating
                gate = decoder.sigmoid(decoder.f_beta(h1))
                attention_weighted_encoding = gate * attention_weighted_encoding

                # Pass through LSTM layers
                h1, c1 = decoder.lstm1(
                    torch.cat([embeddings, attention_weighted_encoding], dim=1),
                    (h1, c1)
                )
                h2, c2 = decoder.lstm2(h1, (h2, c2))

                # Get predicted token
                preds = decoder.fc(decoder.dropout(h2))  # [batch_size, vocab_size]
                _, preds_idx = preds.max(dim=1)

                # Feed next token
                all_preds.append(preds_idx.cpu().numpy())
                inputs = preds_idx

            # Reformat predictions to [batch_size, max_seq_length]
            all_preds = np.array(all_preds).T

            for i in range(batch_size):
                pred_indices = all_preds[i]

                # Stop at <EOS> if present
                if eos_idx in pred_indices:
                    first_eos = np.where(pred_indices == eos_idx)[0][0]
                    pred_indices = pred_indices[:first_eos]

                # Convert token indices -> string
                pred_chars = [idx_to_char.get(idx, '') for idx in pred_indices]
                pred_str   = ''.join(pred_chars)

                # Process ground truth
                label_indices = labels[i].cpu().numpy()
                # remove <SOS>
                label_indices = label_indices[1:]

                if eos_idx in label_indices:
                    eos_pos = np.where(label_indices == eos_idx)[0][0]
                    label_indices = label_indices[:eos_pos]
                else:
                    # remove <PAD> if present
                    label_indices = label_indices[label_indices != char_to_idx['<PAD>']]

                label_chars = [idx_to_char.get(idx, '') for idx in label_indices]
                label_str   = ''.join(label_chars)

                global_idx    = batch_idx * batch_size + i
                image_filename= test_data.iloc[global_idx]['filename']

                results.append({
                    'image_filename': image_filename,
                    'predicted_caption': pred_str,
                    'ground_truth_caption': label_str
                })

    return results

In [14]:
def calculate_global_cer(results):
    total_ed   = 0
    total_refs = 0
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        total_ed   += dist
        total_refs += len(ref)
    if total_refs == 0:
        return 0.0
    return total_ed / total_refs

In [15]:
def print_top_worst_samples(results, n=5):
    # Calculate CER for each sample
    results_with_cer = []
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        length = len(ref)
        cer = dist / length if length > 0 else 0
        # Copy the record and add cer
        new_r = r.copy()
        new_r['cer'] = cer
        results_with_cer.append(new_r)

    # Sort by CER (descending) and take the top N
    results_with_cer.sort(key=lambda x: x['cer'], reverse=True)
    worst_samples = results_with_cer[:n]

    print(f"\n=== Top {n} Worst Samples by CER ===")
    for i, sample in enumerate(worst_samples, start=1):
        print(f"{i}) Image: {sample['image_filename']}")
        print(f"   CER: {sample['cer']:.4f}")
        print(f"   Predicted       : {sample['predicted_caption']}")
        print(f"   Ground Truth    : {sample['ground_truth_caption']}")
        print()

In [16]:
training_csv = "training_results.csv"
if not os.path.exists(training_csv) or os.path.getsize(training_csv) == 0:
    pd.DataFrame(columns=["model_name", "mode", "epoch1", "epoch2"]).to_csv(training_csv, index=False)

csv_file = "test_cer_results.csv"
if not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0:
    pd.DataFrame(columns=["model_name", "test_cer"]).to_csv(csv_file, index=False)

def log_test_cer(model_name, cer_value):
    df = pd.read_csv(csv_file)
    # Check if model_name exists
    if model_name in df['model_name'].values:
        # Update existing row
        df.loc[df['model_name'] == model_name, 'test_cer'] = cer_value
    else:
        # Add new row - use concat instead of append
        new_row = pd.DataFrame({"model_name": [model_name], "test_cer": [cer_value]})
        df = pd.concat([df, new_row], ignore_index=True)
    
    df.to_csv(csv_file, index=False)
    print(f"Logged {model_name}: {cer_value:.4f}")

In [17]:
cnn_encoder = ResNet18Encoder(pretrained=True)
cnn_decoder = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=cnn_encoder.encoder_dim,  # 512 for ResNet18
    teacher_forcing_ratio=0.5
)

cnn_encoder.to(device)
cnn_decoder.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(cnn_encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(cnn_decoder.parameters(), lr=4e-4)

trainer = ImageCaptioningTrainer(
    encoder=cnn_encoder,
    decoder=cnn_decoder,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="resnet18_encoder"
)

num_epochs = 30
trainer.fit(num_epochs)

# Testing
cnn_encoder.eval()
cnn_decoder.teacher_forcing_ratio = 0.0

results_resnet18 = inference(
    encoder=cnn_encoder,
    decoder=cnn_decoder,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_resnet18 = calculate_global_cer(results_resnet18)
print(f"ResNet18 — Test CER: {cer_resnet18:.4f}")
print_top_worst_samples(results_resnet18, n=5)

# CSV logging
log_test_cer("resnet18_encoder", cer_resnet18)

# THEN do your existing cleanup
del cnn_encoder
del cnn_decoder
del trainer
torch.cuda.empty_cache()
print("Memory cleared")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/yhuang1/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 432MB/s]



Epoch 1/30
[1/30] Train Loss: 1.9060, Train CER: 0.6279 | Val Loss: 1.0704, Val CER: 0.3051

Epoch 2/30
[2/30] Train Loss: 1.0801, Train CER: 0.3032 | Val Loss: 0.5468, Val CER: 0.1315

Epoch 3/30
[3/30] Train Loss: 0.6787, Train CER: 0.1800 | Val Loss: 0.3947, Val CER: 0.1031

Epoch 4/30
[4/30] Train Loss: 0.4782, Train CER: 0.1234 | Val Loss: 0.3686, Val CER: 0.0994

Epoch 5/30
[5/30] Train Loss: 0.3518, Train CER: 0.0898 | Val Loss: 0.3646, Val CER: 0.0951

Epoch 6/30
[6/30] Train Loss: 0.2487, Train CER: 0.0626 | Val Loss: 0.3649, Val CER: 0.0992
Validation loss did not improve. Early stop counter: 1/5

Epoch 7/30
[7/30] Train Loss: 0.1958, Train CER: 0.0506 | Val Loss: 0.3488, Val CER: 0.0922

Epoch 8/30
[8/30] Train Loss: 0.1543, Train CER: 0.0378 | Val Loss: 0.4222, Val CER: 0.1032
Validation loss did not improve. Early stop counter: 1/5

Epoch 9/30
[9/30] Train Loss: 0.1398, Train CER: 0.0346 | Val Loss: 0.3738, Val CER: 0.0828
Validation loss did not improve. Early stop count

In [18]:
cnn_encoder_50 = ResNet50Encoder(pretrained=True)
cnn_decoder_50 = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=cnn_encoder_50.encoder_dim,  # 2048 for ResNet50
    teacher_forcing_ratio=0.5
)

cnn_encoder_50 = cnn_encoder_50.to(device)
cnn_decoder_50 = cnn_decoder_50.to(device)

trainer = ImageCaptioningTrainer(
    encoder=cnn_encoder_50,
    decoder=cnn_decoder_50,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="resnet50_encoder"
)
num_epochs = 30
trainer.fit(num_epochs)

cnn_encoder_50.eval()
cnn_decoder_50.teacher_forcing_ratio = 0.0

results_resnet50 = inference(
    encoder=cnn_encoder_50,
    decoder=cnn_decoder_50,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_resnet50 = calculate_global_cer(results_resnet50)
print(f"ResNet50 — Test CER: {cer_resnet50:.4f}")
print_top_worst_samples(results_resnet50, n=5)

log_test_cer("resnet50_encoder", cer_resnet50)

# Manually delete references to free GPU memory
del cnn_encoder_50
del cnn_decoder_50
del trainer

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/yhuang1/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 448MB/s]



Epoch 1/30
[1/30] Train Loss: 3.6700, Train CER: 0.9449 | Val Loss: 3.6726, Val CER: 0.9805

Epoch 2/30
[2/30] Train Loss: 3.6704, Train CER: 0.9461 | Val Loss: 3.6724, Val CER: 0.9819

Epoch 3/30
[3/30] Train Loss: 3.6695, Train CER: 0.9457 | Val Loss: 3.6736, Val CER: 0.9810
Validation loss did not improve. Early stop counter: 1/5

Epoch 4/30
[4/30] Train Loss: 3.6697, Train CER: 0.9474 | Val Loss: 3.6741, Val CER: 0.9838
Validation loss did not improve. Early stop counter: 2/5

Epoch 5/30
[5/30] Train Loss: 3.6697, Train CER: 0.9465 | Val Loss: 3.6728, Val CER: 0.9810
Validation loss did not improve. Early stop counter: 3/5

Epoch 6/30
[6/30] Train Loss: 3.6701, Train CER: 0.9466 | Val Loss: 3.6733, Val CER: 0.9817
Validation loss did not improve. Early stop counter: 4/5

Epoch 7/30
[7/30] Train Loss: 3.6701, Train CER: 0.9470 | Val Loss: 3.6727, Val CER: 0.9826
Validation loss did not improve. Early stop counter: 5/5
Early stopping triggered.

Training completed in 0h 6m.

Results

In [19]:
encoder_vit_base = ViTEncoder(model_name="vit_base_patch16_224", pretrained=True)
decoder_vit_base = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=encoder_vit_base.encoder_dim,  
    teacher_forcing_ratio=0.5
).to(device)

encoder_vit_base = encoder_vit_base.to(device)
decoder_vit_base = decoder_vit_base.to(device)


criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer_base = optim.Adam(encoder_vit_base.parameters(), lr=1e-4)
decoder_optimizer_base = optim.Adam(decoder_vit_base.parameters(), lr=4e-4)

trainer_base = ImageCaptioningTrainer(
    encoder=encoder_vit_base,
    decoder=decoder_vit_base,
    criterion=criterion,                   
    encoder_optimizer=encoder_optimizer_base,
    decoder_optimizer=decoder_optimizer_base,
    train_loader=train_loader,             
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="vit_base_patch16_224"     
)

num_epochs = 30
trainer_base.fit(num_epochs)

encoder_vit_base.eval()
decoder_vit_base.teacher_forcing_ratio = 0.0

results_vit_base = inference(
    encoder=encoder_vit_base,
    decoder=decoder_vit_base,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_vit_base = calculate_global_cer(results_vit_base)
print(f"ViT Base — Test CER: {cer_vit_base:.4f}")

print_top_worst_samples(results_vit_base, n=5)

log_test_cer("vit_base_patch16_224", cer_vit_base)

# Manually delete references to free GPU memory
del encoder_vit_base
del decoder_vit_base
del trainer_base

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 1.9702, Train CER: 0.6551 | Val Loss: 1.1713, Val CER: 0.3202

Epoch 2/30
[2/30] Train Loss: 1.1482, Train CER: 0.3401 | Val Loss: 0.5902, Val CER: 0.1608

Epoch 3/30
[3/30] Train Loss: 0.7176, Train CER: 0.1992 | Val Loss: 0.3750, Val CER: 0.0950

Epoch 4/30
[4/30] Train Loss: 0.5175, Train CER: 0.1379 | Val Loss: 0.3134, Val CER: 0.0874

Epoch 5/30
[5/30] Train Loss: 0.3995, Train CER: 0.1033 | Val Loss: 0.2901, Val CER: 0.0806

Epoch 6/30
[6/30] Train Loss: 0.3277, Train CER: 0.0877 | Val Loss: 0.2912, Val CER: 0.0704
Validation loss did not improve. Early stop counter: 1/5

Epoch 7/30
[7/30] Train Loss: 0.2854, Train CER: 0.0757 | Val Loss: 0.3134, Val CER: 0.0880
Validation loss did not improve. Early stop counter: 2/5

Epoch 8/30
[8/30] Train Loss: 0.2879, Train CER: 0.0727 | Val Loss: 0.2718, Val CER: 0.0729

Epoch 9/30
[9/30] Train Loss: 0.2016, Train CER: 0.0505 | Val Loss: 0.2758, Val CER: 0.0702
Validation loss did not improve. Early stop count

In [20]:
encoder_vit_large = ViTEncoder(model_name="vit_large_patch16_224", pretrained=True)
decoder_vit_large = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=encoder_vit_large.encoder_dim,
    teacher_forcing_ratio=0.5
)

encoder_vit_large = encoder_vit_large.to(device)
decoder_vit_large = decoder_vit_large.to(device)

encoder_optimizer_large = optim.Adam(encoder_vit_large.parameters(), lr=1e-4)
decoder_optimizer_large = optim.Adam(decoder_vit_large.parameters(), lr=4e-4)

trainer_large = ImageCaptioningTrainer(
    encoder=encoder_vit_large,
    decoder=decoder_vit_large,
    criterion=criterion,                    
    encoder_optimizer=encoder_optimizer_large,
    decoder_optimizer=decoder_optimizer_large,
    train_loader=train_loader,              
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="vit_large_patch16_224"     
)

num_epochs = 30
trainer_large.fit(num_epochs)

encoder_vit_large.eval()
decoder_vit_large.teacher_forcing_ratio = 0.0

results_vit_large = inference(
    encoder=encoder_vit_large,
    decoder=decoder_vit_large,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_vit_large = calculate_global_cer(results_vit_large)
print(f"ViT Large — Test CER: {cer_vit_large:.4f}")

print_top_worst_samples(results_vit_large, n=5)

# CSV logging 
log_test_cer("vit_large_patch16_224", cer_vit_large)

# Manually delete references to free GPU memory
del encoder_vit_large
del decoder_vit_large
del trainer_large

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 1.9242, Train CER: 0.6307 | Val Loss: 0.9994, Val CER: 0.2956

Epoch 2/30
[2/30] Train Loss: 1.0784, Train CER: 0.3187 | Val Loss: 0.5151, Val CER: 0.1456

Epoch 3/30
[3/30] Train Loss: 0.6415, Train CER: 0.1827 | Val Loss: 0.3316, Val CER: 0.0884

Epoch 4/30
[4/30] Train Loss: 0.4513, Train CER: 0.1222 | Val Loss: 0.2989, Val CER: 0.0832

Epoch 5/30
[5/30] Train Loss: 0.3414, Train CER: 0.0912 | Val Loss: 0.2633, Val CER: 0.0684

Epoch 6/30
[6/30] Train Loss: 0.2759, Train CER: 0.0730 | Val Loss: 0.2482, Val CER: 0.0616

Epoch 7/30
[7/30] Train Loss: 0.2371, Train CER: 0.0615 | Val Loss: 0.2436, Val CER: 0.0613

Epoch 8/30
[8/30] Train Loss: 0.2072, Train CER: 0.0532 | Val Loss: 0.2479, Val CER: 0.0698
Validation loss did not improve. Early stop counter: 1/5

Epoch 9/30
[9/30] Train Loss: 0.1865, Train CER: 0.0476 | Val Loss: 0.2552, Val CER: 0.0694
Validation loss did not improve. Early stop counter: 2/5

Epoch 10/30
[10/30] Train Loss: 0.1770, Train CE

In [21]:
# Instantiate the individual encoders for ViT-Base
cnn_encoder = ResNet18Encoder(pretrained=True)
vit_encoder = ViTEncoder(model_name="vit_base_patch16_224", pretrained=True)

# Instantiate and move the hybrid encoder to device
hybrid_encoder = HybridEncoder(cnn_encoder, vit_encoder)
hybrid_encoder.to(device)

# Instantiate the decoder with the hybrid encoder's output dimension
hybrid_decoder = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=hybrid_encoder.encoder_dim,  # Combined dimension (e.g., 512 + vit_encoder.encoder_dim)
    teacher_forcing_ratio=0.5
)
hybrid_decoder.to(device)

# Set up loss and optimizers
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(hybrid_encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(hybrid_decoder.parameters(), lr=4e-4)

# Instantiate the trainer with the hybrid encoder and decoder
trainer = ImageCaptioningTrainer(
    encoder=hybrid_encoder,
    decoder=hybrid_decoder,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="hybrid_cnn_vit_base_encoder"
)

# Train the model
num_epochs = 30
trainer.fit(num_epochs)

# Testing: disable teacher forcing
hybrid_encoder.eval()
hybrid_decoder.teacher_forcing_ratio = 0.0

results_hybrid = inference(
    encoder=hybrid_encoder,
    decoder=hybrid_decoder,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_hybrid = calculate_global_cer(results_hybrid)
print(f"Hybrid CNN+ViT-Base — Test CER: {cer_hybrid:.4f}")
print_top_worst_samples(results_hybrid, n=5)

# Log the results to CSV
log_test_cer("hybrid_cnn_vit_base_encoder", cer_hybrid)

# Cleanup
del hybrid_encoder, hybrid_decoder, trainer
torch.cuda.empty_cache()
print("Memory cleared for ViT-Base experiment")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/yhuang1/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 94.5MB/s]


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 1.9750, Train CER: 0.6500 | Val Loss: 1.1689, Val CER: 0.3334

Epoch 2/30
[2/30] Train Loss: 1.1627, Train CER: 0.3274 | Val Loss: 0.6739, Val CER: 0.1773

Epoch 3/30
[3/30] Train Loss: 0.7453, Train CER: 0.1954 | Val Loss: 0.4441, Val CER: 0.1282

Epoch 4/30
[4/30] Train Loss: 0.5231, Train CER: 0.1347 | Val Loss: 0.3928, Val CER: 0.1047

Epoch 5/30
[5/30] Train Loss: 0.3801, Train CER: 0.0957 | Val Loss: 0.3590, Val CER: 0.0938

Epoch 6/30
[6/30] Train Loss: 0.2803, Train CER: 0.0682 | Val Loss: 0.3403, Val CER: 0.0874

Epoch 7/30
[7/30] Train Loss: 0.2327, Train CER: 0.0563 | Val Loss: 0.3651, Val CER: 0.0912
Validation loss did not improve. Early stop counter: 1/5

Epoch 8/30
[8/30] Train Loss: 0.1961, Train CER: 0.0486 | Val Loss: 0.3011, Val CER: 0.0774

Epoch 9/30
[9/30] Train Loss: 0.1517, Train CER: 0.0352 | Val Loss: 0.3370, Val CER: 0.0799
Validation loss did not improve. Early stop counter: 1/5

Epoch 10/30
[10/30] Train Loss: 0.1264, Train CE

In [22]:
# Instantiate the individual encoders for ViT-Large
cnn_encoder = ResNet18Encoder(pretrained=True)
vit_encoder = ViTEncoder(model_name="vit_large_patch16_224", pretrained=True)

# Instantiate and move the hybrid encoder to device
hybrid_encoder = HybridEncoder(cnn_encoder, vit_encoder)
hybrid_encoder.to(device)

# Instantiate the decoder with the hybrid encoder's output dimension
hybrid_decoder = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=hybrid_encoder.encoder_dim,  # Combined dimension (e.g., 512 + vit_encoder.encoder_dim)
    teacher_forcing_ratio=0.5
)
hybrid_decoder.to(device)

# Set up loss and optimizers
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(hybrid_encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(hybrid_decoder.parameters(), lr=4e-4)

# Instantiate the trainer with the hybrid encoder and decoder
trainer = ImageCaptioningTrainer(
    encoder=hybrid_encoder,
    decoder=hybrid_decoder,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="hybrid_cnn_vit_encoder"
)

num_epochs = 30
trainer.fit(num_epochs)

# Testing: disable teacher forcing
hybrid_encoder.eval()
hybrid_decoder.teacher_forcing_ratio = 0.0

results_hybrid = inference(
    encoder=hybrid_encoder,
    decoder=hybrid_decoder,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_hybrid = calculate_global_cer(results_hybrid)
print(f"Hybrid CNN+ViT-Large — Test CER: {cer_hybrid:.4f}")
print_top_worst_samples(results_hybrid, n=5)

# Log the results to CSV
log_test_cer("hybrid_cnn_vit_encoder", cer_hybrid)

# Cleanup
del hybrid_encoder, hybrid_decoder, trainer
torch.cuda.empty_cache()
print("Memory cleared for ViT-Large experiment")




model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 1.8731, Train CER: 0.6196 | Val Loss: 0.8771, Val CER: 0.2368

Epoch 2/30
[2/30] Train Loss: 0.9443, Train CER: 0.2786 | Val Loss: 0.4623, Val CER: 0.1296

Epoch 3/30
[3/30] Train Loss: 0.5846, Train CER: 0.1612 | Val Loss: 0.3012, Val CER: 0.0735

Epoch 4/30
[4/30] Train Loss: 0.4191, Train CER: 0.1088 | Val Loss: 0.2980, Val CER: 0.0825

Epoch 5/30
[5/30] Train Loss: 0.3248, Train CER: 0.0842 | Val Loss: 0.2865, Val CER: 0.0743

Epoch 6/30
[6/30] Train Loss: 0.2810, Train CER: 0.0714 | Val Loss: 0.2462, Val CER: 0.0693

Epoch 7/30
[7/30] Train Loss: 0.2416, Train CER: 0.0607 | Val Loss: 0.2567, Val CER: 0.0591
Validation loss did not improve. Early stop counter: 1/5

Epoch 8/30
[8/30] Train Loss: 0.1943, Train CER: 0.0489 | Val Loss: 0.2405, Val CER: 0.0614

Epoch 9/30
[9/30] Train Loss: 0.1824, Train CER: 0.0436 | Val Loss: 0.2373, Val CER: 0.0559

Epoch 10/30
[10/30] Train Loss: 0.1534, Train CER: 0.0372 | Val Loss: 0.2431, Val CER: 0.0538
Validation 

In [23]:
# Second model: Swin Small
swin_encoder_small = SwinEncoder(model_name="swin_small_patch4_window7_224", pretrained=True).to(device)

# Dummy forward to set encoder_dim
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    _ = swin_encoder_small(dummy)

# Build decoder with the correct encoder_dim
swin_decoder_small = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=swin_encoder_small.encoder_dim,  
    teacher_forcing_ratio=0.5
).to(device)

# Loss + optimizers
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(swin_encoder_small.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(swin_decoder_small.parameters(), lr=4e-4)

# Trainer
trainer_small = ImageCaptioningTrainer(
    encoder=swin_encoder_small,
    decoder=swin_decoder_small,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="swin_small_encoder"
)

num_epochs = 30
trainer_small.fit(num_epochs)

swin_encoder_small.eval()
swin_decoder_small.teacher_forcing_ratio = 0.0

results_swin_small = inference(
    encoder=swin_encoder_small,
    decoder=swin_decoder_small,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_swin_small = calculate_global_cer(results_swin_small)
print(f"Swin Small — Test CER: {cer_swin_small:.4f}")

print_top_worst_samples(results_swin_small, n=5)

# CSV logging
log_test_cer("swin_small_encoder", cer_swin_small)

# Manually delete references to free GPU memory
del swin_encoder_small
del swin_decoder_small
del trainer_small
del encoder_optimizer
del decoder_optimizer

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

model.safetensors:   0%|          | 0.00/200M [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 2.2225, Train CER: 0.7487 | Val Loss: 1.7182, Val CER: 0.5131

Epoch 2/30
[2/30] Train Loss: 1.9042, Train CER: 0.5678 | Val Loss: 1.5531, Val CER: 0.4494

Epoch 3/30
[3/30] Train Loss: 1.7362, Train CER: 0.4964 | Val Loss: 1.3065, Val CER: 0.3279

Epoch 4/30
[4/30] Train Loss: 1.5807, Train CER: 0.4299 | Val Loss: 1.2472, Val CER: 0.2896

Epoch 5/30
[5/30] Train Loss: 1.4402, Train CER: 0.3777 | Val Loss: 1.0321, Val CER: 0.2445

Epoch 6/30
[6/30] Train Loss: 1.3185, Train CER: 0.3347 | Val Loss: 0.9819, Val CER: 0.2156

Epoch 7/30
[7/30] Train Loss: 1.1905, Train CER: 0.2968 | Val Loss: 0.8573, Val CER: 0.1908

Epoch 8/30
[8/30] Train Loss: 1.0613, Train CER: 0.2601 | Val Loss: 0.7094, Val CER: 0.1593

Epoch 9/30
[9/30] Train Loss: 0.9381, Train CER: 0.2263 | Val Loss: 0.7394, Val CER: 0.1588
Validation loss did not improve. Early stop counter: 1/5

Epoch 10/30
[10/30] Train Loss: 0.8395, Train CER: 0.2014 | Val Loss: 0.6225, Val CER: 0.1355

Epoch 11/3

In [24]:
# Third model: Swin Base
swin_encoder_base = SwinEncoder(model_name="swin_base_patch4_window7_224", pretrained=True).to(device)

# Dummy forward to set encoder_dim
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    _ = swin_encoder_base(dummy)

# Build decoder with the correct encoder_dim
swin_decoder_base = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=swin_encoder_base.encoder_dim,  
    teacher_forcing_ratio=0.5
).to(device)

# Loss + optimizers
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(swin_encoder_base.parameters(), lr=8e-5)  # Lower learning rate for larger model
decoder_optimizer = optim.Adam(swin_decoder_base.parameters(), lr=3e-4)

# Trainer
trainer_base = ImageCaptioningTrainer(
    encoder=swin_encoder_base,
    decoder=swin_decoder_base,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="swin_base_encoder"
)

num_epochs = 30
trainer_base.fit(num_epochs)

swin_encoder_base.eval()
swin_decoder_base.teacher_forcing_ratio = 0.0

results_swin_base = inference(
    encoder=swin_encoder_base,
    decoder=swin_decoder_base,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_swin_base = calculate_global_cer(results_swin_base)
print(f"Swin Base — Test CER: {cer_swin_base:.4f}")

print_top_worst_samples(results_swin_base, n=5)

log_test_cer("swin_base_encoder", cer_swin_base)


# Manually delete references to free GPU memory
del swin_encoder_base
del swin_decoder_base
del trainer_base
del encoder_optimizer
del decoder_optimizer

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 2.3137, Train CER: 0.8551 | Val Loss: 1.8726, Val CER: 0.4961

Epoch 2/30
[2/30] Train Loss: 2.0529, Train CER: 0.6394 | Val Loss: 1.7891, Val CER: 0.4837

Epoch 3/30
[3/30] Train Loss: 2.0186, Train CER: 0.6180 | Val Loss: 1.7764, Val CER: 0.4867

Epoch 4/30
[4/30] Train Loss: 1.9623, Train CER: 0.5910 | Val Loss: 1.6797, Val CER: 0.4769

Epoch 5/30
[5/30] Train Loss: 1.8922, Train CER: 0.5620 | Val Loss: 1.6142, Val CER: 0.4624

Epoch 6/30
[6/30] Train Loss: 1.8104, Train CER: 0.5254 | Val Loss: 1.5687, Val CER: 0.4064

Epoch 7/30
[7/30] Train Loss: 1.7434, Train CER: 0.4936 | Val Loss: 1.4280, Val CER: 0.3594

Epoch 8/30
[8/30] Train Loss: 1.6659, Train CER: 0.4586 | Val Loss: 1.3494, Val CER: 0.3252

Epoch 9/30
[9/30] Train Loss: 1.5834, Train CER: 0.4241 | Val Loss: 1.2852, Val CER: 0.3160

Epoch 10/30
[10/30] Train Loss: 1.5034, Train CER: 0.3853 | Val Loss: 1.2078, Val CER: 0.2786

Epoch 11/30
[11/30] Train Loss: 1.4159, Train CER: 0.3456 | Val Los

In [25]:
# Fourth model: Swin Large
swin_encoder_large = SwinEncoder(model_name="swin_large_patch4_window7_224", pretrained=True).to(device)

# Dummy forward to set encoder_dim
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    _ = swin_encoder_large(dummy)

# Build decoder with the correct encoder_dim
swin_decoder_large = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=swin_encoder_large.encoder_dim,  
    teacher_forcing_ratio=0.5
).to(device)

# Loss + optimizers
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(swin_encoder_large.parameters(), lr=5e-5)  # Even lower learning rate for largest model
decoder_optimizer = optim.Adam(swin_decoder_large.parameters(), lr=2e-4)

# Trainer
trainer_large = ImageCaptioningTrainer(
    encoder=swin_encoder_large,
    decoder=swin_decoder_large,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    model_name="swin_large_encoder"
)

num_epochs = 30
trainer_large.fit(num_epochs)

swin_encoder_large.eval()
swin_decoder_large.teacher_forcing_ratio = 0.0

results_swin_large = inference(
    encoder=swin_encoder_large,
    decoder=swin_decoder_large,
    data_loader=test_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_seq_length=max_label_length_test,
    test_data=test_data
)

cer_swin_large = calculate_global_cer(results_swin_large)
print(f"Swin Large — Test CER: {cer_swin_large:.4f}")

print_top_worst_samples(results_swin_large, n=5)

log_test_cer("swin_large_encoder", cer_swin_large)

# Manually delete references to free GPU memory
del swin_encoder_large
del swin_decoder_large
del trainer_large
del encoder_optimizer
del decoder_optimizer

# Empty the PyTorch CUDA cache
torch.cuda.empty_cache()
print("Memory cleared")

model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]


Epoch 1/30
[1/30] Train Loss: 2.3465, Train CER: 0.8409 | Val Loss: 1.8442, Val CER: 0.5287

Epoch 2/30
[2/30] Train Loss: 2.0448, Train CER: 0.6145 | Val Loss: 1.7828, Val CER: 0.4694

Epoch 3/30
[3/30] Train Loss: 1.9856, Train CER: 0.5869 | Val Loss: 1.7048, Val CER: 0.4688

Epoch 4/30
[4/30] Train Loss: 1.9113, Train CER: 0.5519 | Val Loss: 1.6333, Val CER: 0.3977

Epoch 5/30
[5/30] Train Loss: 1.8199, Train CER: 0.5136 | Val Loss: 1.5385, Val CER: 0.4034

Epoch 6/30
[6/30] Train Loss: 1.7166, Train CER: 0.4657 | Val Loss: 1.3686, Val CER: 0.3236

Epoch 7/30
[7/30] Train Loss: 1.5978, Train CER: 0.4089 | Val Loss: 1.2937, Val CER: 0.3017

Epoch 8/30
[8/30] Train Loss: 1.4785, Train CER: 0.3567 | Val Loss: 1.1596, Val CER: 0.2672

Epoch 9/30
[9/30] Train Loss: 1.3533, Train CER: 0.3095 | Val Loss: 1.0926, Val CER: 0.2487

Epoch 10/30
[10/30] Train Loss: 1.2533, Train CER: 0.2774 | Val Loss: 1.0467, Val CER: 0.2351

Epoch 11/30
[11/30] Train Loss: 1.1578, Train CER: 0.2518 | Val Los