In [32]:
import os
import pandas as pd
import numpy as np
import editdistance
import time 

from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

import timm  
from sklearn.model_selection import train_test_split

In [33]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [34]:
# Base project path
BASE_DIR = "/projects/qb36/lontar_project"
DATA_DIR = os.path.join(BASE_DIR, "data")

# Define file paths
# balinese_ground_truth_path = os.path.join(DATA_DIR, 'balinese_transliteration_train.txt')
# images_dir = os.path.join(DATA_DIR, 'balinese_word_train')

In [35]:
def load_image_labels(data_dir, ground_truth_file):
    filenames, labels = [], []
    with open(ground_truth_file, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.strip().split(';')
            if len(parts) == 2:
                filename, label = parts
                filenames.append(os.path.join(data_dir, filename))
                labels.append(label.lower())
    return filenames, labels

In [36]:
filenames, labels = [], []
for lang in languages:
    img_dir = os.path.join(DATA_DIR, f'{lang}_word_train')
    gt_path = os.path.join(DATA_DIR, f'{lang}_transliteration_train.txt')
    img_files, lbls = load_image_labels(img_dir, gt_path)
    filenames.extend(img_files)
    labels.extend(lbls)

In [37]:
# Building the vocabulary
all_text = ''.join(labels)
unique_chars = sorted(set(all_text))

# Special tokens
char_to_idx = {char: idx for idx, char in enumerate(['<PAD>', '<UNK>', '<SOS>', '<EOS>'] + unique_chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(char_to_idx)
print(f'Vocabulary size: {vocab_size}')

# Encoding labels
def encode_label(label, char_to_idx, max_length):
    encoded = [char_to_idx['<SOS>']] + [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] + [char_to_idx['<EOS>']]
    encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
    return encoded[:max_length]

max_label_length = max(len(label) for label in labels) + 2
encoded_labels = [encode_label(label, char_to_idx, max_label_length) for label in labels]

# Data preparation
data = pd.DataFrame({
    'filename': filenames,
    'label': labels,
    'encoded_label': encoded_labels,
    'label_length': [len(lbl) for lbl in encoded_labels]
})

# Train-validation split
def custom_split(df, test_size=0.1, random_state=42):
    return train_test_split(df, test_size=test_size, random_state=random_state)

train_data, val_data = custom_split(data)
print(f'Training size: {len(train_data)}; Validation size: {len(val_data)}')

Vocabulary size: 41
Training size: 29502; Validation size: 3279


In [7]:
# # base_dir = os.getcwd()

# # # Same paths as your original code
# # ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt') 
# # images_dir        = os.path.join(base_dir, 'balinese_word_train')


# filenames = []
# labels    = []

# def load_labels(file_path):
#     with open(ground_truth_path, 'r', encoding='utf-8') as file:
#         for line in file:
#             line = line.strip()
#             if line:  # Ensure the line is not empty
#                 parts = line.split(';')
#                 if len(parts) == 2:
#                     filename, label = parts
#                     label = label.lower()
#                     filenames.append(filename)
#                     labels.append(label)
#                 else:
#                     print(f"Skipping malformed line: {line}")

# data = pd.DataFrame({
#     'filename': filenames,
#     'label': labels
# })

# label_counts = data['label'].value_counts()

# all_text = ''.join(data['label'])
# unique_chars = sorted(list(set(all_text)))

# # Create character->index starting from 1
# char_to_idx = {char: idx + 1 for idx, char in enumerate(unique_chars)}
# # Add special tokens
# char_to_idx['<PAD>'] = 0
# char_to_idx['<UNK>'] = len(char_to_idx)
# char_to_idx['<SOS>'] = len(char_to_idx)
# char_to_idx['<EOS>'] = len(char_to_idx)

# # Reverse mapping
# idx_to_char = {v: k for k, v in char_to_idx.items()}

# vocab_size = len(char_to_idx)
# print(f"Vocabulary size: {vocab_size}")

# def encode_label(label, char_to_idx, max_length):
#     """
#     Converts a label (string) into a list of indices with <SOS>, <EOS>, padding, etc.
#     """
#     encoded = (
#         [char_to_idx['<SOS>']] +
#         [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
#         [char_to_idx['<EOS>']]
#     )
#     # Pad if needed
#     if len(encoded) < max_length:
#         encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
#     else:
#         encoded = encoded[:max_length]
#     return encoded

# max_label_length = max(len(label) for label in data['label']) + 2  # +2 for <SOS> and <EOS>
# data['encoded_label'] = data['label'].apply(lambda x: encode_label(x, char_to_idx, max_label_length))
# data['label_length']  = data['label'].apply(len)

# rare_labels = label_counts[label_counts < 3].index  # NEW: words that appear <3 times

# def custom_split(df, rare_label_list, test_size=0.1, random_state=42):
#     # Separate rare words from frequent ones
#     rare_df     = df[df['label'].isin(rare_label_list)]
#     non_rare_df = df[~df['label'].isin(rare_label_list)]

#     #  train/val split for non-rare
#     train_nr, val_nr = train_test_split(non_rare_df, test_size=test_size, 
#                                         random_state=random_state)

#     # Combine rare samples entirely into training
#     train_df = pd.concat([train_nr, rare_df], ignore_index=True)
#     # Shuffle after combining
#     train_df = train_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

#     val_df = val_nr.reset_index(drop=True)
#     return train_df, val_df

# # Call custom_split instead of direct train_test_split
# train_data, val_data = custom_split(data, rare_labels, test_size=0.1, random_state=42) 

# print(f"Training size: {len(train_data)}; Validation size: {len(val_data)}")

In [38]:
class MultilingualDataset(Dataset):
    def __init__(self, df, transform=None):
        self.data = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.loc[idx, 'filename']
        label = self.data.loc[idx, 'encoded_label']
        label_length = self.data.loc[idx, 'label_length']

        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label, torch.tensor(label_length, dtype=torch.long)

In [39]:
# Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 64
train_dataset = MultilingualDataset(train_data, transform=transform)
val_dataset = MultilingualDataset(val_data, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

print("Data loaders created successfully.")

Data loaders created successfully.


In [40]:
class ResNet18Encoder(nn.Module):
    """
    Encoder that uses a pretrained ResNet18 to extract features of shape 
    [B, H*W, C], which the DecoderRNN can then attend over.
    """

    def __init__(self, pretrained=True):
        super(ResNet18Encoder, self).__init__()
        resnet = models.resnet18(pretrained=pretrained)

        # Remove the classification (fc) layer
        modules = list(resnet.children())[:-2]  # remove the avgpool & fc
        self.cnn = nn.Sequential(*modules)

        # last convolutional block outputs 512 channels
        self.encoder_dim = 512

    def forward(self, x):
        """
        Input shape:  x -> [batch_size, 3, 224, 224]
        Output shape: -> [batch_size, num_patches, encoder_dim]
                       where num_patches = H' * W' from the final feature map
        """
        # pass through ResNet (up to layer4)
        features = self.cnn(x)  # [B, 512, H', W']

        # Flatten the spatial dims
        # shape => [B, 512, H', W'] -> [B, H'*W', 512]
        b, c, h, w = features.shape
        features = features.permute(0, 2, 3, 1)   # [B, H', W', C]
        features = features.reshape(b, -1, c)     # [B, H'*W', C]

        return features

In [8]:
class ViTEncoder(nn.Module):
    """
    A simple ViT encoder that extracts patch embeddings as [batch_size, num_patches, hidden_dim].
    We'll use timm to load a pretrained ViT. Then we use .forward_features() to get a
    feature map of shape [B, C, H', W'] for many timm ViT models, which we flatten.
    """
    def __init__(self, model_name, pretrained=True):
        super(ViTEncoder, self).__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        # Remove or replace the classification head
        self.vit.head = nn.Identity()

        # timm's ViT typically has an embed_dim attribute
        self.encoder_dim = self.vit.embed_dim

    def forward(self, x):
        """
        :param x: [batch_size, 3, 224, 224]
        :return:  [batch_size, num_patches, encoder_dim]
        """
        # forward_features usually returns [B, hidden_dim, H', W'] or [B, hidden_dim]
        feats = self.vit.forward_features(x)  # [B, hidden_dim, 14, 14] for vit_base_patch16_224

        # Flatten the spatial dimensions
        if feats.dim() == 4:  # [B, C, H, W]
            b, c, h, w = feats.shape
            feats = feats.permute(0, 2, 3, 1).reshape(b, -1, c)  # => [B, H*W, C]

        return feats

In [9]:
class SwinEncoder(nn.Module):
    def __init__(self, model_name="swin_small_patch4_window7_224", pretrained=True):
        """
        A simple Swin Transformer encoder that extracts patch embeddings
        as [batch_size, num_patches, hidden_dim]. We'll use timm to load 
        a pretrained Swin model, remove its classification head, then flatten.
        """
        super().__init__()
        self.swin = timm.create_model(model_name, pretrained=pretrained)
        self.swin.head = nn.Identity()

        # We'll assign encoder_dim dynamically after forward
        self.encoder_dim = None

    def forward(self, x):
        """
        :param x: [batch_size, 3, 224, 224]
        :return:  [batch_size, num_patches, encoder_dim]
        """
        feats = self.swin.forward_features(x)            # [B, C, H, W]
        b, c, h, w = feats.shape
        feats = feats.flatten(2).transpose(1, 2)         # [B, H*W, C]
        # Set encoder_dim once (C)
        if self.encoder_dim is None:
            self.encoder_dim = feats.shape[-1]
        return feats




In [10]:
class HybridEncoder(nn.Module):
    def __init__(self, cnn_encoder, vit_encoder):
        super(HybridEncoder, self).__init__()
        self.cnn_encoder = cnn_encoder
        self.vit_encoder = vit_encoder
        # Combined encoder_dim is the sum of both encoder dimensions.
        # (CNN outputs 512 channels; ViT outputs its own embed_dim.)
        self.encoder_dim = cnn_encoder.encoder_dim + vit_encoder.encoder_dim

    def forward(self, x):
        # Get CNN features: expected shape [B, 49, 512]
        cnn_features = self.cnn_encoder(x)

        # Get ViT features: expected shape [B, 197, vit_dim] for vit_large_patch16_224
        vit_features = self.vit_encoder(x)
        # If the ViT output contains a class token, remove it.
        if vit_features.shape[1] == 197:
            vit_features = vit_features[:, 1:, :]  # Now shape: [B, 196, vit_dim]
            B, tokens, D = vit_features.shape
            # Reshape tokens into a 14x14 grid: [B, 14, 14, D]
            vit_features = vit_features.reshape(B, 14, 14, D)
            # Permute to [B, D, 14, 14] for pooling
            vit_features = vit_features.permute(0, 3, 1, 2)
            # Use adaptive pooling to reduce to a 7x7 grid
            vit_features = F.adaptive_avg_pool2d(vit_features, (7, 7))
            # Permute back to [B, 7, 7, D] and flatten to [B, 49, D]
            vit_features = vit_features.permute(0, 2, 3, 1).reshape(B, -1, D)

        # Concatenate the features along the feature dimension (dim=2)
        # cnn_features: [B, 49, 512] and vit_features: [B, 49, vit_dim]
        hybrid_features = torch.cat([cnn_features, vit_features], dim=2)
        return hybrid_features

In [41]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # transform encoder output
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # transform decoder hidden
        self.full_att    = nn.Linear(attention_dim, 1)
        self.relu        = nn.ReLU()
        self.softmax     = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out:    [batch_size, num_patches, encoder_dim]
        decoder_hidden: [batch_size, decoder_dim]
        """
        att1 = self.encoder_att(encoder_out)                  # [batch_size, num_patches, attention_dim]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)  # [batch_size, 1, attention_dim]

        # sum -> relu -> full_att -> squeeze -> softmax
        att  = self.full_att(self.relu(att1 + att2)).squeeze(2)  # [batch_size, num_patches]
        alpha = self.softmax(att)
        # Weighted sum of the encoder_out
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [batch_size, encoder_dim]
        return attention_weighted_encoding, alpha

In [43]:
class DecoderRNN(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=768, teacher_forcing_ratio=0.5):
        super(DecoderRNN, self).__init__()

        self.attention     = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding     = nn.Embedding(vocab_size, embed_dim)
        self.dropout       = nn.Dropout(p=0.5)

        # [embed_dim + encoder_dim] -> decoder_dim
        self.lstm1 = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        # decoder_dim -> decoder_dim
        self.lstm2 = nn.LSTMCell(decoder_dim, decoder_dim)

        # For initializing the hidden states of both LSTM layers
        self.init_h1 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c1 = nn.Linear(encoder_dim, decoder_dim)
        self.init_h2 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c2 = nn.Linear(encoder_dim, decoder_dim)

        # Gating
        self.f_beta  = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        # Final linear layer for output vocab
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        # encoder_out: [batch_size, num_patches, encoder_dim]
        mean_encoder_out = encoder_out.mean(dim=1)  # [batch_size, encoder_dim]
        h1 = self.init_h1(mean_encoder_out)         # [batch_size, decoder_dim]
        c1 = self.init_c1(mean_encoder_out)         # [batch_size, decoder_dim]
        h2 = self.init_h2(mean_encoder_out)         # [batch_size, decoder_dim]
        c2 = self.init_c2(mean_encoder_out)         # [batch_size, decoder_dim]
        return (h1, c1, h2, c2)


    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        encoder_out:      [batch_size, num_patches, encoder_dim]
        encoded_captions: [batch_size, max_label_length]
        caption_lengths:  [batch_size, 1]
        """
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out      = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)

        # Initialize hidden states for both LSTM layers
        h1, c1, h2, c2 = self.init_hidden_state(encoder_out)

        decode_lengths    = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        predictions = torch.zeros(batch_size, max_decode_length, vocab_size, device=encoder_out.device)
        alphas      = torch.zeros(batch_size, max_decode_length, encoder_out.size(1), device=encoder_out.device)

        # We'll feed the first token from the input (<SOS>) or from the previous prediction
        prev_tokens = encoded_captions[:, 0].clone()

        for t in range(max_decode_length):
            batch_size_t = sum([l > t for l in decode_lengths])

            attention_weighted_encoding, alpha = self.attention(
                encoder_out[:batch_size_t],
                h1[:batch_size_t]  # use the first LSTM layer's hidden state for attention
            )

            # Apply gating
            gate = self.sigmoid(self.f_beta(h1[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding

            # Teacher forcing?
            use_teacher_forcing = (torch.rand(1).item() < self.teacher_forcing_ratio)
            if use_teacher_forcing:
                current_input = embeddings[:batch_size_t, t, :]
            else:
                current_input = self.embedding(prev_tokens[:batch_size_t].detach())

            # first lstm layer
            h1_next, c1_next = self.lstm1(
                torch.cat([current_input, attention_weighted_encoding], dim=1),
                (h1[:batch_size_t], c1[:batch_size_t])
            )

            # second lstm layer
            h2_next, c2_next = self.lstm2(
                h1_next, (h2[:batch_size_t], c2[:batch_size_t])
            )

            # Use the second LSTM layer's output (h2_next) for final prediction
            preds = self.fc(self.dropout(h2_next))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :]      = alpha

            # Update prev_tokens with the best predicted token
            _, next_tokens = preds.max(dim=1)
            prev_tokens_ = prev_tokens.clone()
            prev_tokens_[:batch_size_t] = next_tokens.detach()
            prev_tokens = prev_tokens_

            # Update hidden states
            # For samples still in the batch, store the new h1, c1, h2, c2
            h1_new = torch.zeros_like(h1)
            c1_new = torch.zeros_like(c1)
            h2_new = torch.zeros_like(h2)
            c2_new = torch.zeros_like(c2)

            h1_new[:batch_size_t] = h1_next
            c1_new[:batch_size_t] = c1_next
            h2_new[:batch_size_t] = h2_next
            c2_new[:batch_size_t] = c2_next

            h1, c1, h2, c2 = h1_new, c1_new, h2_new, c2_new

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [44]:
class ImageCaptioningTrainer:
    def __init__(self, encoder, decoder, 
                 criterion, encoder_optimizer, decoder_optimizer, 
                 train_loader, val_loader, test_loader, test_data, max_label_length_test,
                 device, char_to_idx, idx_to_char, max_label_length,
                 model_name, csv_filename="training_results.csv"):
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.criterion = criterion
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.test_data = test_data
        self.max_label_length_test = max_label_length_test
        self.device = device
        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        self.max_label_length = max_label_length
        self.model_name = model_name
        self.csv_filename = csv_filename

        self.train_losses = []
        self.val_losses = []
        self.train_cers = []
        self.val_cers = []
        self.test_cers = []

    
    def fit(self, num_epochs):
        start_time = time.time()
        best_test_cer = float('inf')
        early_stop_counter = 0
        early_stop_patience = 5
    
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_cer = self.train_one_epoch()
            val_loss, val_cer = self.validate_one_epoch()
    
            results = inference(
                encoder=self.encoder,
                decoder=self.decoder,
                data_loader=self.test_loader,
                device=self.device,
                char_to_idx=self.char_to_idx,
                idx_to_char=self.idx_to_char,
                max_seq_length=self.max_label_length_test,
                test_data=self.test_data
            )
            test_cer = calculate_global_cer(results)
    
            print(f"[{epoch+1}/{num_epochs}] "
                  f"Train Loss: {train_loss:.4f}, Train CER: {train_cer:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val CER: {val_cer:.4f} | "
                  f"Test CER: {test_cer:.4f}")
    
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_cers.append(train_cer)
            self.val_cers.append(val_cer)
            self.test_cers.append(test_cer)
    
            if test_cer < best_test_cer:
                best_test_cer = test_cer
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                print(f"Test CER did not improve. Early stop counter: {early_stop_counter}/{early_stop_patience}")
            
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break
    
        total_time = time.time() - start_time
        hours = int(total_time // 3600)
        minutes = int((total_time % 3600) // 60)
        print(f"\nTraining completed in {hours}h {minutes}m.")
    
        # Save CSV
        num_epochs_recorded = len(self.train_losses)
        epoch_cols = [f"epoch{i+1}" for i in range(num_epochs_recorded)]
    
        new_rows = pd.DataFrame([
            [self.model_name, "training loss"] + self.train_losses,
            [self.model_name, "validation loss"] + self.val_losses,
            [self.model_name, "training cer"] + self.train_cers,
            [self.model_name, "validation cer"] + self.val_cers,
            [self.model_name, "test cer"] + self.test_cers
        ], columns=["model_name", "mode"] + epoch_cols)
    
        if os.path.exists(self.csv_filename):
            df_existing = pd.read_csv(self.csv_filename)
            df_existing = df_existing[df_existing["model_name"] != self.model_name]
            df_updated = pd.concat([df_existing, new_rows], ignore_index=True)
        else:
            df_updated = new_rows
    
        df_updated[epoch_cols] = np.floor(df_updated[epoch_cols] * 100) / 100 
        df_updated.to_csv(self.csv_filename, index=False)
        print(f"\nResults have been written to: {self.csv_filename}")

    def train_one_epoch(self):
        self.encoder.train()
        self.decoder.train()
        running_loss           = 0.0
        total_edit_distance    = 0
        total_ref_length       = 0

        for batch_idx, (images, labels, label_lengths) in enumerate(self.train_loader):
            images        = images.to(self.device, non_blocking=True)
            labels        = labels.to(self.device, non_blocking=True)
            label_lengths = label_lengths.to(self.device, non_blocking=True)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            encoder_out   = self.encoder(images)
            caption_lengths = torch.tensor(
                [self.max_label_length] * labels.size(0)
            ).unsqueeze(1).to(self.device)

            outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                encoder_out, labels, caption_lengths
            )

            # Targets = encoded captions without the <SOS>
            targets = encoded_captions[:, 1:]

            # Flatten for loss
            outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
            targets_flat = targets.contiguous().view(-1)

            loss = self.criterion(outputs_flat, targets_flat)
            loss.backward()

            self.decoder_optimizer.step()
            self.encoder_optimizer.step()

            running_loss += loss.item()

            # Compute CER for the batch (global style)
            batch_size = labels.size(0)
            _, preds_flat = torch.max(outputs_flat, dim=1)
            preds_seq = preds_flat.view(batch_size, -1)

            for i in range(batch_size):
                pred_indices   = preds_seq[i].detach().cpu().numpy()
                target_indices = targets[i].detach().cpu().numpy()

                mask          = (target_indices != self.char_to_idx['<PAD>'])
                pred_indices  = pred_indices[mask]
                target_indices= target_indices[mask]

                pred_chars    = [self.idx_to_char.get(idx, '') for idx in pred_indices]
                target_chars  = [self.idx_to_char.get(idx, '') for idx in target_indices]
                pred_str      = ''.join(pred_chars)
                target_str    = ''.join(target_chars)

                edit_dist           = editdistance.eval(pred_str, target_str)
                total_edit_distance += edit_dist
                total_ref_length    += len(target_str)

            # if (batch_idx + 1) % 50 == 0:
            #     print(f'Batch {batch_idx + 1}/{len(self.train_loader)} - Loss: {loss.item():.4f}')

        avg_loss = running_loss / len(self.train_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0
        return avg_loss, avg_cer

    def validate_one_epoch(self, top_n=5):
        self.encoder.eval()
        self.decoder.eval()
        running_loss         = 0.0
        total_edit_distance  = 0
        total_ref_length     = 0

        # each sample’s CER
        sample_cer_info = []

        with torch.no_grad():
            for batch_idx, (images, labels, label_lengths) in enumerate(self.val_loader):
                images        = images.to(self.device, non_blocking=True)
                labels        = labels.to(self.device, non_blocking=True)
                label_lengths = label_lengths.to(self.device, non_blocking=True)

                encoder_out = self.encoder(images)
                caption_lengths = torch.tensor(
                    [self.max_label_length] * labels.size(0)
                ).unsqueeze(1).to(self.device)

                outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                    encoder_out, labels, caption_lengths
                )
                targets = encoded_captions[:, 1:]

                outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
                targets_flat = targets.contiguous().view(-1)

                loss = self.criterion(outputs_flat, targets_flat)
                running_loss += loss.item()

                batch_size = labels.size(0)
                _, preds_flat = torch.max(outputs_flat, dim=1)
                preds_seq = preds_flat.view(batch_size, -1)

                for i in range(batch_size):
                    pred_indices   = preds_seq[i].detach().cpu().numpy()
                    target_indices = targets[i].detach().cpu().numpy()

                    mask           = (target_indices != self.char_to_idx['<PAD>'])
                    pred_indices   = pred_indices[mask]
                    target_indices = target_indices[mask]

                    pred_chars   = [self.idx_to_char.get(idx, '') for idx in pred_indices]
                    target_chars = [self.idx_to_char.get(idx, '') for idx in target_indices]
                    pred_str     = ''.join(pred_chars)
                    target_str   = ''.join(target_chars)

                    edit_dist = editdistance.eval(pred_str, target_str)
                    ref_len   = len(target_str)
                    cer       = edit_dist / ref_len if ref_len > 0 else 0
    
                    total_edit_distance += edit_dist
                    total_ref_length    += ref_len
    
                    # Store sample info
                    # sample_cer_info.append({
                    #     "pred": pred_str,
                    #     "gt": target_str,
                    #     "cer": cer
                    # })

                    # Print a few samples from the 1st batch
                    # if batch_idx == 0 and i < 3:
                    #     print(f"Sample {i + 1}:")
                    #     print(f"Predicted: {pred_str}")
                    #     print(f"Target   : {target_str}\n")

        avg_loss = running_loss / len(self.val_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0

        # Sort by CER descending
        sample_cer_info.sort(key=lambda x: x["cer"], reverse=True)
        # Take top_n
        worst_samples = sample_cer_info[:top_n]
    
        # print(f"\n=== Top {top_n} Worst Samples by CER ===")
        # for idx, sample in enumerate(worst_samples):
        #     print(f"[{idx+1}] CER: {sample['cer']:.3f}")
        #     print(f"   Predicted: {sample['pred']}")
        #     print(f"   Ground Truth: {sample['gt']}\n")
       
        return avg_loss, avg_cer


In [45]:
test_ground_truth_path = os.path.join(DATA_DIR, 'balinese_transliteration_test.txt')
test_images_dir        = os.path.join(DATA_DIR, 'balinese_word_test')

test_filenames = []
test_labels    = []

with open(test_ground_truth_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:
            parts = line.split(';')
            if len(parts) == 2:
                filename, label = parts
                label = label.lower()
                test_filenames.append(filename)
                test_labels.append(label)
            else:
                print(f"Skipping malformed line: {line}")

test_data = pd.DataFrame({
    'filename': test_filenames,
    'label': test_labels
})

# Check for unknown chars in test set
test_chars = set(''.join(test_data['label']))
unknown_chars = test_chars - set(char_to_idx.keys())
print(f"Unknown characters in test labels: {unknown_chars}")

# Encode test labels
max_label_length_test = max(len(lbl) for lbl in test_data['label']) + 2
def encode_label_test(label, char_to_idx, max_length):
    encoded = (
        [char_to_idx['<SOS>']] +
        [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
        [char_to_idx['<EOS>']]
    )
    if len(encoded) > max_length:
        encoded = encoded[:max_length]
    else:
        encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
    return encoded

test_data['encoded_label'] = test_data['label'].apply(lambda x: encode_label_test(x, char_to_idx, max_label_length_test))
test_data['label_length']  = test_data['label'].apply(len)

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)
    )
])

test_dataset = MultilingualDataset(test_data, test_images_dir, transform=test_transform)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

Unknown characters in test labels: set()


TypeError: __init__() got multiple values for argument 'transform'

In [46]:
def inference(encoder, decoder, data_loader, device, char_to_idx, idx_to_char, max_seq_length, test_data):
    encoder.eval()
    decoder.eval()

    eos_idx = char_to_idx['<EOS>']
    results = []

    with torch.no_grad():
        for batch_idx, (images, labels, label_lengths) in enumerate(data_loader):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            batch_size  = images.size(0)
            encoder_out = encoder(images)  # [B, num_patches, encoder_dim]

            # Init LSTM state
            h1, c1, h2, c2 = decoder.init_hidden_state(encoder_out)

            # Start tokens (all <SOS>)
            inputs = torch.full(
                (batch_size,),
                fill_value=char_to_idx['<SOS>'],
                dtype=torch.long,
                device=device
            )

            all_preds = []

            for _ in range(max_seq_length):
                # Embedding
                embeddings = decoder.embedding(inputs)

                # Attention
                attention_weighted_encoding, _ = decoder.attention(encoder_out, h1)

                # Gating
                gate = decoder.sigmoid(decoder.f_beta(h1))
                attention_weighted_encoding = gate * attention_weighted_encoding

                # Pass through LSTM layers
                h1, c1 = decoder.lstm1(
                    torch.cat([embeddings, attention_weighted_encoding], dim=1),
                    (h1, c1)
                )
                h2, c2 = decoder.lstm2(h1, (h2, c2))

                # Get predicted token
                preds = decoder.fc(decoder.dropout(h2))  # [batch_size, vocab_size]
                _, preds_idx = preds.max(dim=1)

                # Feed next token
                all_preds.append(preds_idx.cpu().numpy())
                inputs = preds_idx

            # Reformat predictions to [batch_size, max_seq_length]
            all_preds = np.array(all_preds).T

            for i in range(batch_size):
                pred_indices = all_preds[i]

                # Stop at <EOS> if present
                if eos_idx in pred_indices:
                    first_eos = np.where(pred_indices == eos_idx)[0][0]
                    pred_indices = pred_indices[:first_eos]

                # Convert token indices -> string
                pred_chars = [idx_to_char.get(idx, '') for idx in pred_indices]
                pred_str   = ''.join(pred_chars)

                # Process ground truth
                label_indices = labels[i].cpu().numpy()
                # remove <SOS>
                label_indices = label_indices[1:]

                if eos_idx in label_indices:
                    eos_pos = np.where(label_indices == eos_idx)[0][0]
                    label_indices = label_indices[:eos_pos]
                else:
                    # remove <PAD> if present
                    label_indices = label_indices[label_indices != char_to_idx['<PAD>']]

                label_chars = [idx_to_char.get(idx, '') for idx in label_indices]
                label_str   = ''.join(label_chars)

                global_idx    = batch_idx * batch_size + i
                image_filename= test_data.iloc[global_idx]['filename']

                results.append({
                    'image_filename': image_filename,
                    'predicted_caption': pred_str,
                    'ground_truth_caption': label_str
                })

    return results

In [47]:
def calculate_global_cer(results):
    total_ed   = 0
    total_refs = 0
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        total_ed   += dist
        total_refs += len(ref)
    if total_refs == 0:
        return 0.0
    return total_ed / total_refs

In [48]:
def print_top_worst_samples(results, n=10):
    # Calculate CER for each sample
    results_with_cer = []
    for r in results:
        ref = r['ground_truth_caption']
        hyp = r['predicted_caption']
        dist = editdistance.eval(ref, hyp)
        length = len(ref)
        cer = dist / length if length > 0 else 0
        # Copy the record and add cer
        new_r = r.copy()
        new_r['cer'] = cer
        results_with_cer.append(new_r)

    # Sort by CER (descending) and take the top N
    results_with_cer.sort(key=lambda x: x['cer'], reverse=True)
    worst_samples = results_with_cer[:n]

    print(f"\n=== Top {n} Worst Samples by CER ===")
    for i, sample in enumerate(worst_samples, start=1):
        print(f"{i}) Image: {sample['image_filename']}")
        print(f"   CER: {sample['cer']:.4f}")
        print(f"   Predicted       : {sample['predicted_caption']}")
        print(f"   Ground Truth    : {sample['ground_truth_caption']}")
        print()

In [49]:
# Ensure CSV files exist
training_csv = "training_results.csv"
if not os.path.exists(training_csv) or os.path.getsize(training_csv) == 0:
    pd.DataFrame(columns=["model_name", "mode", "epoch1", "epoch2"]).to_csv(training_csv, index=False)

csv_file = "test_cer_results.csv"
if not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0:
    pd.DataFrame(columns=["model_name", "test_cer"]).to_csv(csv_file, index=False)

def log_test_cer(model_name, cer_value):
    """
    Logs or updates the test CER for a given model, rounding values to 4 decimals.
    """
    df = pd.read_csv(csv_file)
    
    # Round the new CER value to 4 decimals
    cer_rounded = round(cer_value, 4)
    
    if model_name in df['model_name'].values:
        # Update existing row
        df.loc[df['model_name'] == model_name, 'test_cer'] = cer_rounded
    else:
        # Add new row
        new_row = pd.DataFrame({
            "model_name": [model_name],
            "test_cer":   [cer_rounded]
        })
        df = pd.concat([df, new_row], ignore_index=True)
    
    # Ensure all stored values are rounded to 4 decimals
    df['test_cer'] = df['test_cer'].round(4)
    
    # Save back to CSV
    df.to_csv(csv_file, index=False)
    print(f"Logged {model_name}: {cer_rounded:.4f}")

In [50]:
def run_training_pipeline(encoder_class,encoder_kwargs, model_name,vocab_size,encoder_lr, decoder_lr,train_loader,
                          val_loader,test_loader,char_to_idx,idx_to_char,max_label_length,max_label_length_test,test_data,
                          device, num_epochs=100):
    #build encoder & grab its dimension
    encoder = encoder_class(**encoder_kwargs).to(device)
    # if encoder_dim is None (e.g. Swin), prime it with a dummy batch
    if encoder.encoder_dim is None:
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224, device=device)
            _ = encoder(dummy)
    enc_dim = encoder.encoder_dim

    #build decoder 
    decoder = DecoderRNN(
        attention_dim=256,
        embed_dim=256,
        decoder_dim=512,
        vocab_size=vocab_size,
        encoder_dim=enc_dim,
        teacher_forcing_ratio=0.5
    ).to(device)

    # loss, optimizers, trainer
    criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=encoder_lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=decoder_lr)

    # trainer = ImageCaptioningTrainer(
    #     encoder=encoder,
    #     decoder=decoder,
    #     criterion=criterion,
    #     encoder_optimizer=encoder_optimizer,
    #     decoder_optimizer=decoder_optimizer,
    #     train_loader=train_loader,
    #     val_loader=val_loader,
    #     device=device,
    #     char_to_idx=char_to_idx,
    #     idx_to_char=idx_to_char,
    #     max_label_length=max_label_length,
    #     model_name=model_name)
    trainer = ImageCaptioningTrainer(
        encoder=encoder,
        decoder=decoder,
        criterion=criterion,
        encoder_optimizer=encoder_optimizer,
        decoder_optimizer=decoder_optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,  
        test_data=test_data,      
        max_label_length_test=max_label_length_test,  
        device=device,
        char_to_idx=char_to_idx,
        idx_to_char=idx_to_char,
        max_label_length=max_label_length,
        model_name=model_name
    )

    trainer.fit(num_epochs)

    encoder.eval()
    decoder.teacher_forcing_ratio = 0.0

    results = inference(
        encoder=encoder,
        decoder=decoder,
        data_loader=test_loader,
        device=device,
        char_to_idx=char_to_idx,
        idx_to_char=idx_to_char,
        max_seq_length=max_label_length_test,
        test_data=test_data
    )

    cer = calculate_global_cer(results)
    print(f"{model_name} — Test CER: {cer:.4f}")
    print_top_worst_samples(results, n=5)
    log_test_cer(model_name, cer)

    del encoder, decoder, trainer
    torch.cuda.empty_cache()
    print(f"Memory cleared for {model_name}")

In [None]:
run_training_pipeline(
    encoder_class = ResNet18Encoder,
    encoder_kwargs= {},                              
    model_name    = "resnet18_encoder",
    vocab_size    = vocab_size,
    encoder_lr    = 1e-4,
    decoder_lr    = 4e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)


Epoch 1/100




[1/100] Train Loss: 2.0083, Train CER: 0.7437 | Val Loss: 1.4585, Val CER: 0.5184 | Test CER: 0.5641

Epoch 2/100
[2/100] Train Loss: 1.1719, Train CER: 0.3896 | Val Loss: 0.9481, Val CER: 0.3237 | Test CER: 0.3936

Epoch 3/100


In [21]:
run_training_pipeline(
    encoder_class = ViTEncoder,
    encoder_kwargs= {"model_name":"vit_base_patch16_224","pretrained":True},
    model_name    = "vit_base_patch16_224",
    vocab_size    = vocab_size,
    encoder_lr    = 1e-4,
    decoder_lr    = 4e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)


Epoch 1/100
[1/100] Train Loss: 2.1396, Train CER: 0.7234 | Val Loss: 1.4418, Val CER: 0.3835 | Test CER: 0.6478

Epoch 2/100
[2/100] Train Loss: 1.3897, Train CER: 0.4306 | Val Loss: 0.7124, Val CER: 0.1980 | Test CER: 0.3452

Epoch 3/100
[3/100] Train Loss: 0.8715, Train CER: 0.2544 | Val Loss: 0.4056, Val CER: 0.1177 | Test CER: 0.2278

Epoch 4/100
[4/100] Train Loss: 0.6361, Train CER: 0.1764 | Val Loss: 0.3208, Val CER: 0.0992 | Test CER: 0.1948

Epoch 5/100
[5/100] Train Loss: 0.4953, Train CER: 0.1335 | Val Loss: 0.2361, Val CER: 0.0660 | Test CER: 0.1727

Epoch 6/100
[6/100] Train Loss: 0.4371, Train CER: 0.1153 | Val Loss: 0.2708, Val CER: 0.0671 | Test CER: 0.1700

Epoch 7/100
[7/100] Train Loss: 0.3879, Train CER: 0.1041 | Val Loss: 0.2272, Val CER: 0.0592 | Test CER: 0.1437

Epoch 8/100
[8/100] Train Loss: 0.3178, Train CER: 0.0803 | Val Loss: 0.2144, Val CER: 0.0630 | Test CER: 0.1396

Epoch 9/100
[9/100] Train Loss: 0.2816, Train CER: 0.0725 | Val Loss: 0.1892, Val CER: 

In [22]:
run_training_pipeline(
    encoder_class = ViTEncoder,
    encoder_kwargs= {"model_name": "vit_large_patch16_224", "pretrained": True},
    model_name    = "vit_large_patch16_224",
    vocab_size    = vocab_size,
    encoder_lr    = 1e-4,
    decoder_lr    = 4e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)


Epoch 1/100
[1/100] Train Loss: 2.0230, Train CER: 0.6837 | Val Loss: 1.1348, Val CER: 0.4194 | Test CER: 0.4678

Epoch 2/100
[2/100] Train Loss: 1.1678, Train CER: 0.3627 | Val Loss: 0.5553, Val CER: 0.1762 | Test CER: 0.3010

Epoch 3/100
[3/100] Train Loss: 0.7450, Train CER: 0.2168 | Val Loss: 0.3342, Val CER: 0.1001 | Test CER: 0.2110

Epoch 4/100
[4/100] Train Loss: 0.5427, Train CER: 0.1501 | Val Loss: 0.2868, Val CER: 0.0844 | Test CER: 0.1774

Epoch 5/100
[5/100] Train Loss: 0.4189, Train CER: 0.1148 | Val Loss: 0.2487, Val CER: 0.0707 | Test CER: 0.1463

Epoch 6/100
[6/100] Train Loss: 0.3530, Train CER: 0.0924 | Val Loss: 0.2456, Val CER: 0.0658 | Test CER: 0.1473
Test CER did not improve. Early stop counter: 1/5

Epoch 7/100
[7/100] Train Loss: 0.3102, Train CER: 0.0818 | Val Loss: 0.2199, Val CER: 0.0658 | Test CER: 0.1482
Test CER did not improve. Early stop counter: 2/5

Epoch 8/100
[8/100] Train Loss: 0.2700, Train CER: 0.0705 | Val Loss: 0.2035, Val CER: 0.0506 | Test 

In [23]:
run_training_pipeline(
    encoder_class = HybridEncoder,
    encoder_kwargs= {
        "cnn_encoder": ResNet18Encoder(pretrained=True),
        "vit_encoder": ViTEncoder(model_name="vit_base_patch16_224", pretrained=True)
    },
    model_name    = "hybrid_cnn_vit_base_encoder",
    vocab_size    = vocab_size,
    encoder_lr    = 1e-4,
    decoder_lr    = 4e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)




Epoch 1/100
[1/100] Train Loss: 2.0932, Train CER: 0.6990 | Val Loss: 1.2920, Val CER: 0.3978 | Test CER: 0.5327

Epoch 2/100
[2/100] Train Loss: 1.3350, Train CER: 0.4082 | Val Loss: 0.6659, Val CER: 0.1602 | Test CER: 0.3497

Epoch 3/100
[3/100] Train Loss: 0.9057, Train CER: 0.2584 | Val Loss: 0.4504, Val CER: 0.1095 | Test CER: 0.2627

Epoch 4/100
[4/100] Train Loss: 0.6700, Train CER: 0.1819 | Val Loss: 0.3908, Val CER: 0.1009 | Test CER: 0.2275

Epoch 5/100
[5/100] Train Loss: 0.5363, Train CER: 0.1417 | Val Loss: 0.2907, Val CER: 0.0752 | Test CER: 0.1824

Epoch 6/100
[6/100] Train Loss: 0.4387, Train CER: 0.1141 | Val Loss: 0.3141, Val CER: 0.0931 | Test CER: 0.1951
Test CER did not improve. Early stop counter: 1/5

Epoch 7/100
[7/100] Train Loss: 0.3710, Train CER: 0.0954 | Val Loss: 0.2351, Val CER: 0.0642 | Test CER: 0.1651

Epoch 8/100
[8/100] Train Loss: 0.3179, Train CER: 0.0826 | Val Loss: 0.2575, Val CER: 0.0628 | Test CER: 0.1599

Epoch 9/100
[9/100] Train Loss: 0.284

In [24]:
run_training_pipeline(
    encoder_class = HybridEncoder,
    encoder_kwargs= {
        "cnn_encoder": ResNet18Encoder(pretrained=True),
        "vit_encoder": ViTEncoder(model_name="vit_large_patch16_224", pretrained=True)
    },
    model_name    = "hybrid_cnn_vit_encoder",
    vocab_size    = vocab_size,
    encoder_lr    = 1e-4,
    decoder_lr    = 4e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)




Epoch 1/100
[1/100] Train Loss: 2.0939, Train CER: 0.7058 | Val Loss: 1.2399, Val CER: 0.3603 | Test CER: 0.5064

Epoch 2/100
[2/100] Train Loss: 1.2878, Train CER: 0.4050 | Val Loss: 0.6466, Val CER: 0.1887 | Test CER: 0.3357

Epoch 3/100
[3/100] Train Loss: 0.8379, Train CER: 0.2507 | Val Loss: 0.4145, Val CER: 0.1103 | Test CER: 0.2478

Epoch 4/100
[4/100] Train Loss: 0.5956, Train CER: 0.1680 | Val Loss: 0.3330, Val CER: 0.0937 | Test CER: 0.1969

Epoch 5/100
[5/100] Train Loss: 0.4610, Train CER: 0.1257 | Val Loss: 0.2497, Val CER: 0.0771 | Test CER: 0.1634

Epoch 6/100
[6/100] Train Loss: 0.3711, Train CER: 0.0992 | Val Loss: 0.2188, Val CER: 0.0546 | Test CER: 0.1534

Epoch 7/100
[7/100] Train Loss: 0.3147, Train CER: 0.0835 | Val Loss: 0.1868, Val CER: 0.0532 | Test CER: 0.1354

Epoch 8/100
[8/100] Train Loss: 0.2813, Train CER: 0.0757 | Val Loss: 0.2086, Val CER: 0.0570 | Test CER: 0.1357
Test CER did not improve. Early stop counter: 1/5

Epoch 9/100
[9/100] Train Loss: 0.244

In [None]:
run_training_pipeline(
    encoder_class=SwinEncoder,
    encoder_kwargs={"model_name":"swin_small_patch4_window7_224","pretrained":True},
    model_name="swin_small_encoder",
    vocab_size=vocab_size,
    encoder_lr=1e-4,
    decoder_lr=4e-4,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    max_label_length_test=max_label_length_test,
    test_data=test_data,
    device=device)


Epoch 1/100
[1/100] Train Loss: 2.3823, Train CER: 0.8995 | Val Loss: 1.9020, Val CER: 0.6407 | Test CER: 0.7088

Epoch 2/100
[2/100] Train Loss: 2.0552, Train CER: 0.6276 | Val Loss: 1.8092, Val CER: 0.4584 | Test CER: 0.7408
Test CER did not improve. Early stop counter: 1/5

Epoch 3/100
[3/100] Train Loss: 1.9867, Train CER: 0.5952 | Val Loss: 1.7910, Val CER: 0.4578 | Test CER: 0.7592
Test CER did not improve. Early stop counter: 2/5

Epoch 4/100
[4/100] Train Loss: 1.9236, Train CER: 0.5719 | Val Loss: 1.6809, Val CER: 0.4856 | Test CER: 0.6874

Epoch 5/100
[5/100] Train Loss: 1.8307, Train CER: 0.5312 | Val Loss: 1.5610, Val CER: 0.3943 | Test CER: 0.6616

Epoch 6/100
[6/100] Train Loss: 1.7445, Train CER: 0.4980 | Val Loss: 1.3811, Val CER: 0.3518 | Test CER: 0.6267

Epoch 7/100
[7/100] Train Loss: 1.6613, Train CER: 0.4675 | Val Loss: 1.2994, Val CER: 0.3073 | Test CER: 0.6037

Epoch 8/100
[8/100] Train Loss: 1.5546, Train CER: 0.4258 | Val Loss: 1.1579, Val CER: 0.2847 | Test 

In [None]:
run_training_pipeline(
    encoder_class = SwinEncoder,
    encoder_kwargs= {"model_name": "swin_base_patch4_window7_224", "pretrained": True},
    model_name    = "swin_base_encoder",
    vocab_size    = vocab_size,
    encoder_lr    = 8e-5,    # lower LR for the larger Swin
    decoder_lr    = 3e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)

In [None]:
run_training_pipeline(
    encoder_class = SwinEncoder,
    encoder_kwargs= {"model_name": "swin_large_patch4_window7_224", "pretrained": True},
    model_name    = "swin_large_encoder",
    vocab_size    = vocab_size,
    encoder_lr    = 5e-5,
    decoder_lr    = 2e-4,
    train_loader  = train_loader,
    val_loader    = val_loader,
    test_loader   = test_loader,
    char_to_idx   = char_to_idx,
    idx_to_char   = idx_to_char,
    max_label_length      = max_label_length,
    max_label_length_test = max_label_length_test,
    test_data      = test_data,
    device         = device)