In [1]:
import os
import pandas as pd
import numpy as np
import editdistance

from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm  

from sklearn.model_selection import train_test_split

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_dir = os.getcwd()

# Same paths as your original code
ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt') 
images_dir        = os.path.join(base_dir, 'balinese_word_train')

filenames = []
labels    = []

with open(ground_truth_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:  # Ensure the line is not empty
            parts = line.split(';')
            if len(parts) == 2:
                filename, label = parts
                label = label.lower()
                filenames.append(filename)
                labels.append(label)
            else:
                print(f"Skipping malformed line: {line}")

data = pd.DataFrame({
    'filename': filenames,
    'label': labels
})

label_counts = data['label'].value_counts()

all_text = ''.join(data['label'])
unique_chars = sorted(list(set(all_text)))

# Create character->index starting from 1
char_to_idx = {char: idx + 1 for idx, char in enumerate(unique_chars)}
# Add special tokens
char_to_idx['<PAD>'] = 0
char_to_idx['<UNK>'] = len(char_to_idx)
char_to_idx['<SOS>'] = len(char_to_idx)
char_to_idx['<EOS>'] = len(char_to_idx)

# Reverse mapping
idx_to_char = {v: k for k, v in char_to_idx.items()}

vocab_size = len(char_to_idx)
print(f"Vocabulary size: {vocab_size}")

def encode_label(label, char_to_idx, max_length):
    """
    Converts a label (string) into a list of indices with <SOS>, <EOS>, padding, etc.
    """
    encoded = (
        [char_to_idx['<SOS>']] +
        [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
        [char_to_idx['<EOS>']]
    )
    # Pad if needed
    if len(encoded) < max_length:
        encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
    else:
        encoded = encoded[:max_length]
    return encoded

max_label_length = max(len(label) for label in data['label']) + 2  # +2 for <SOS> and <EOS>
data['encoded_label'] = data['label'].apply(lambda x: encode_label(x, char_to_idx, max_label_length))
data['label_length']  = data['label'].apply(len)

rare_labels = label_counts[label_counts < 3].index  # NEW: words that appear <3 times

def custom_split(df, rare_label_list, test_size=0.1, random_state=42):
    # Separate rare words from frequent ones
    rare_df     = df[df['label'].isin(rare_label_list)]
    non_rare_df = df[~df['label'].isin(rare_label_list)]

    #  train/val split for non-rare
    train_nr, val_nr = train_test_split(non_rare_df, test_size=test_size, 
                                        random_state=random_state)

    # Combine rare samples entirely into training
    train_df = pd.concat([train_nr, rare_df], ignore_index=True)
    # Shuffle after combining
    train_df = train_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    val_df = val_nr.reset_index(drop=True)
    return train_df, val_df

# Call custom_split instead of direct train_test_split
train_data, val_data = custom_split(data, rare_labels, test_size=0.1, random_state=42) 

print(f"Training size: {len(train_data)}; Validation size: {len(val_data)}")

Vocabulary size: 39
Training size: 13972; Validation size: 1050


In [3]:
class BalineseDataset(Dataset):
    def __init__(self, df, images_dir, transform=None):
        self.data       = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.transform  = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name     = self.data.loc[idx, 'filename']
        label        = self.data.loc[idx, 'encoded_label']
        label_length = self.data.loc[idx, 'label_length']

        img_path = os.path.join(self.images_dir, img_name)
        image    = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label, torch.tensor(label_length, dtype=torch.long)

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)
    )
])

train_dataset = BalineseDataset(train_data, images_dir, transform=transform)
val_dataset   = BalineseDataset(val_data,   images_dir, transform=transform)

batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [5]:
class ViTEncoder(nn.Module):
    """
    A simple ViT encoder that extracts patch embeddings as [batch_size, num_patches, hidden_dim].
    """
    def __init__(self, model_name="vit_large_patch16_224", pretrained=True):
        super(ViTEncoder, self).__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained)
        self.vit.head = nn.Identity()
        self.encoder_dim = self.vit.embed_dim

    def forward(self, x):
        """
        :param x: [batch_size, 3, 224, 224]
        :return:  [batch_size, num_patches, encoder_dim]
        """
        # forward_features usually returns [B, hidden_dim, H', W'] or [B, hidden_dim]
        feats = self.vit.forward_features(x)  # [B, hidden_dim, H', W'] for vit_large_patch16_224

        # Flatten the spatial dimensions
        if feats.dim() == 4:  # [B, C, H, W]
            b, c, h, w = feats.shape
            feats = feats.permute(0, 2, 3, 1).reshape(b, -1, c)  # [B, H*W, C]

        return feats



class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # transform encoder output
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # transform decoder hidden
        self.full_att    = nn.Linear(attention_dim, 1)
        self.relu        = nn.ReLU()
        self.softmax     = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        encoder_out:    [batch_size, num_patches, encoder_dim]
        decoder_hidden: [batch_size, decoder_dim]
        """
        att1 = self.encoder_att(encoder_out)                  # [batch_size, num_patches, attention_dim]
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)  # [batch_size, 1, attention_dim]

        # sum -> relu -> full_att -> squeeze -> softmax
        att  = self.full_att(self.relu(att1 + att2)).squeeze(2)  # [batch_size, num_patches]
        alpha = self.softmax(att)
        # Weighted sum of the encoder_out
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [batch_size, encoder_dim]
        return attention_weighted_encoding, alpha

class DecoderRNN(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=768, teacher_forcing_ratio=0.5):
        super(DecoderRNN, self).__init__()

        self.attention     = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding     = nn.Embedding(vocab_size, embed_dim)
        self.dropout       = nn.Dropout(p=0.5)
        self.lstm = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta  = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)  # [batch_size, encoder_dim]
        h = self.init_h(mean_encoder_out)         # [batch_size, decoder_dim]
        c = self.init_c(mean_encoder_out)         # [batch_size, decoder_dim]
        return (h, c)

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        encoder_out:      [batch_size, num_patches, encoder_dim]
        encoded_captions: [batch_size, max_label_length]
        caption_lengths:  [batch_size, 1]
        """
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out      = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)

        # Initialize hidden states
        h, c = self.init_hidden_state(encoder_out)

        decode_lengths    = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        predictions = torch.zeros(batch_size, max_decode_length, vocab_size, device=encoder_out.device)
        alphas      = torch.zeros(batch_size, max_decode_length, encoder_out.size(1), device=encoder_out.device)

        # We'll feed the first token from the input (<SOS>) or from the previous prediction
        prev_tokens = encoded_captions[:, 0].clone()

        for t in range(max_decode_length):
            batch_size_t = sum([l > t for l in decode_lengths])

            attention_weighted_encoding, alpha = self.attention(
                encoder_out[:batch_size_t],
                h[:batch_size_t]
            )

            # Apply gating
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding

            # Teacher forcing?
            use_teacher_forcing = (torch.rand(1).item() < self.teacher_forcing_ratio)
            if use_teacher_forcing:
                current_input = embeddings[:batch_size_t, t, :]
            else:
                current_input = self.embedding(prev_tokens[:batch_size_t].detach())

            h_next, c_next = self.lstm(
                torch.cat([current_input, attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )

            preds = self.fc(self.dropout(h_next))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :]      = alpha

            _, next_tokens = preds.max(dim=1)
            prev_tokens_ = prev_tokens.clone()
            prev_tokens_[:batch_size_t] = next_tokens.detach()
            prev_tokens = prev_tokens_

            h_new = torch.zeros_like(h)
            c_new = torch.zeros_like(c)

            h_new[:batch_size_t] = h_next
            c_new[:batch_size_t] = c_next

            h, c = h_new, c_new

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind


encoder = ViTEncoder(model_name="vit_large_patch16_224", pretrained=True)
decoder = DecoderRNN(
    attention_dim=256,
    embed_dim=256,
    decoder_dim=512,
    vocab_size=vocab_size,
    encoder_dim=encoder.encoder_dim,  # typically 768 for ViT-B/16
    teacher_forcing_ratio=0.5
)

encoder = encoder.to(device)
decoder = decoder.to(device)


criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=4e-4)


class ImageCaptioningTrainer:
    def __init__(self, encoder, decoder, 
                 criterion, encoder_optimizer, decoder_optimizer, 
                 train_loader, val_loader, device, 
                 char_to_idx, idx_to_char, max_label_length):
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.criterion = criterion
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.train_loader = train_loader
        self.val_loader   = val_loader
        self.device = device
        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        self.max_label_length = max_label_length

    def fit(self, num_epochs, early_stopping_patience=5):
        """
        Train the model for 'num_epochs' epochs.
        If val_loss doesn't improve for 'early_stopping_patience' consecutive epochs,
        we stop training early.
        """
        best_val_loss = float("inf")  # track minimum validation loss seen
        no_improvement_epochs = 0     # epochs since last improvement

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_cer = self.train_one_epoch()
            val_loss,   val_cer   = self.validate_one_epoch(top_n=5)

            print(f"[{epoch+1}/{num_epochs}] "
                  f"Train Loss: {train_loss:.4f}, Train CER: {train_cer:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val CER: {val_cer:.4f}")

            # Early Stopping Check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improvement_epochs = 0
            else:
                # No improvement this epoch
                no_improvement_epochs += 1
                print(f"No improvement for {no_improvement_epochs} epoch(s).")

                if no_improvement_epochs >= early_stopping_patience:
                    print(f"Early stopping triggered after {no_improvement_epochs} epochs "
                          f"without improvement on validation loss.")
                    break

    def train_one_epoch(self):
        self.encoder.train()
        self.decoder.train()
        running_loss = 0.0
        total_edit_distance = 0
        total_ref_length = 0

        for batch_idx, (images, labels, label_lengths) in enumerate(self.train_loader):
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            label_lengths = label_lengths.to(self.device, non_blocking=True)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            encoder_out = self.encoder(images)
            caption_lengths = torch.tensor(
                [self.max_label_length] * labels.size(0)
            ).unsqueeze(1).to(self.device)

            outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                encoder_out, labels, caption_lengths
            )

            # Targets = encoded captions without the <SOS>
            targets = encoded_captions[:, 1:]

            # Flatten for loss
            outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
            targets_flat = targets.contiguous().view(-1)

            loss = self.criterion(outputs_flat, targets_flat)
            loss.backward()

            self.decoder_optimizer.step()
            self.encoder_optimizer.step()

            running_loss += loss.item()

            # Compute CER
            batch_size = labels.size(0)
            _, preds_flat = torch.max(outputs_flat, dim=1)
            preds_seq = preds_flat.view(batch_size, -1)

            for i in range(batch_size):
                pred_indices = preds_seq[i].detach().cpu().numpy()
                target_indices = targets[i].detach().cpu().numpy()

                mask = (target_indices != self.char_to_idx['<PAD>'])
                pred_indices   = pred_indices[mask]
                target_indices = target_indices[mask]

                pred_str   = ''.join([self.idx_to_char.get(idx, '') for idx in pred_indices])
                target_str = ''.join([self.idx_to_char.get(idx, '') for idx in target_indices])

                edit_dist = editdistance.eval(pred_str, target_str)
                total_edit_distance += edit_dist
                total_ref_length    += len(target_str)

            # Optional: Print intermediate training stats
            if (batch_idx + 1) % 50 == 0:
                print(f"  Batch {batch_idx + 1}/{len(self.train_loader)} - Loss: {loss.item():.4f}")

        avg_loss = running_loss / len(self.train_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0
        return avg_loss, avg_cer

    def validate_one_epoch(self, top_n=5):
        self.encoder.eval()
        self.decoder.eval()
        running_loss = 0.0
        total_edit_distance = 0
        total_ref_length = 0

        sample_cer_info = []

        with torch.no_grad():
            for batch_idx, (images, labels, label_lengths) in enumerate(self.val_loader):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                label_lengths = label_lengths.to(self.device, non_blocking=True)

                encoder_out = self.encoder(images)
                caption_lengths = torch.tensor(
                    [self.max_label_length] * labels.size(0)
                ).unsqueeze(1).to(self.device)

                outputs, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
                    encoder_out, labels, caption_lengths
                )

                targets = encoded_captions[:, 1:]

                outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
                targets_flat = targets.contiguous().view(-1)

                loss = self.criterion(outputs_flat, targets_flat)
                running_loss += loss.item()

                batch_size = labels.size(0)
                _, preds_flat = torch.max(outputs_flat, dim=1)
                preds_seq = preds_flat.view(batch_size, -1)

                for i in range(batch_size):
                    pred_indices = preds_seq[i].detach().cpu().numpy()
                    target_indices = targets[i].detach().cpu().numpy()

                    mask = (target_indices != self.char_to_idx['<PAD>'])
                    pred_indices   = pred_indices[mask]
                    target_indices = target_indices[mask]

                    pred_str   = ''.join(self.idx_to_char.get(idx, '') for idx in pred_indices)
                    target_str = ''.join(self.idx_to_char.get(idx, '') for idx in target_indices)

                    edit_dist = editdistance.eval(pred_str, target_str)
                    ref_len   = len(target_str)
                    cer       = edit_dist / ref_len if ref_len > 0 else 0

                    total_edit_distance += edit_dist
                    total_ref_length    += ref_len

                    sample_cer_info.append({
                        "pred": pred_str,
                        "gt": target_str,
                        "cer": cer
                    })

                    # Print a few samples from the first batch
                    # if batch_idx == 0 and i < 3:
                    #     print(f"Sample {i + 1}:")
                    #     print(f"Predicted: {pred_str}")
                    #     print(f"Target   : {target_str}\n")

        avg_loss = running_loss / len(self.val_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0

        # Sort by CER descending
        sample_cer_info.sort(key=lambda x: x["cer"], reverse=True)
        worst_samples = sample_cer_info[:top_n]

        print(f"\n=== Top {top_n} Worst Samples by CER ===")
        for idx, sample in enumerate(worst_samples):
            print(f"[{idx+1}] CER: {sample['cer']:.3f}")
            print(f"   Predicted: {sample['pred']}")
            print(f"   Ground Truth: {sample['gt']}\n")

        return avg_loss, avg_cer


trainer = ImageCaptioningTrainer(
    encoder=encoder,
    decoder=decoder,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length
)

num_epochs = 30
trainer.fit(num_epochs)


Epoch 1/30
  Batch 50/437 - Loss: 2.4494
  Batch 100/437 - Loss: 2.3068
  Batch 150/437 - Loss: 2.0392
  Batch 200/437 - Loss: 1.8417
  Batch 250/437 - Loss: 1.7403
  Batch 300/437 - Loss: 1.7022
  Batch 350/437 - Loss: 0.9931
  Batch 400/437 - Loss: 1.3448
Sample 1:
Predicted: ,<EOS>
Target   : ,<EOS>

Sample 2:
Predicted: suura<EOS>
Target   : swaha<EOS>

Sample 3:
Predicted: suaha<EOS>
Target   : swaha<EOS>


=== Top 5 Worst Samples by CER ===
[1] CER: 1.778
   Predicted: .<EOS><EOS><EOS><EOS>
   Ground Truth: ring<EOS>

[2] CER: 1.700
   Predicted: aa<EOS><EOS><EOS><EOS>
   Ground Truth: irung<EOS>

[3] CER: 1.500
   Predicted: aasa<EOS><EOS><EOS><EOS>
   Ground Truth: tanhana<EOS>

[4] CER: 1.471
   Predicted: pwiiiign<EOS><EOS><EOS><EOS><EOS>
   Ground Truth: awighnamastu<EOS>

[5] CER: 1.300
   Predicted: a<EOS><EOS>aa<EOS>
   Ground Truth: panti<EOS>

[1/30] Train Loss: 1.8378, Train CER: 0.6011 | Val Loss: 0.8633, Val CER: 0.2567

Epoch 2/30
  Batch 50/437 - Loss: 0.7274
  Ba

In [6]:
import test_balinese_model
print(dir(test_balinese_model))
from test_balinese_model import evaluate_test_set

a
['BalineseDataset', 'DataLoader', 'Dataset', 'Image', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'editdistance', 'evaluate_test_set', 'np', 'os', 'pd', 'torch', 'transforms']


In [7]:
test_ground_truth_path = os.path.join(os.getcwd(), 'balinese_transliteration_test.txt')
test_images_dir        = os.path.join(os.getcwd(), 'balinese_word_test')

In [8]:
test_cer = evaluate_test_set(
    encoder=encoder,
    decoder=decoder,
    device=device,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    max_label_length=max_label_length,
    test_ground_truth_path=test_ground_truth_path,
    test_images_dir=test_images_dir
)

print(f"Final Test CER: {test_cer:.4f}")

Unknown characters in test labels: set()

=== Sample predictions (first 5) ===
Image: test1.png
Predicted: ,
Ground Truth: ,

Image: test2.png
Predicted: biak
Ground Truth: biakta

Image: test3.png
Predicted: antah
Ground Truth: ngantah

Image: test4.png
Predicted: sarina
Ground Truth: sarira

Image: test5.png
Predicted: yu
Ground Truth: yu

Global CER on test set: 0.1409

Top 5 highest CER results:
1) Image: test10070.png
   CER: 8.0000
   Predicted       : .patatnia
   Ground Truth    : .

2) Image: test1330.png
   CER: 7.0000
   Predicted       : santiaa
   Ground Truth    : .

3) Image: test580.png
   CER: 6.0000
   Predicted       : .aattia
   Ground Truth    : .

4) Image: test1335.png
   CER: 5.0000
   Predicted       : panti
   Ground Truth    : .

5) Image: test6297.png
   CER: 5.0000
   Predicted       : panti
   Ground Truth    : .

Final Test CER: 0.1409


In [9]:
# test_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_test.txt')
# test_images_dir        = os.path.join(base_dir, 'balinese_word_test')

# test_filenames = []
# test_labels    = []

# with open(test_ground_truth_path, 'r', encoding='utf-8') as file:
#     for line in file:
#         line = line.strip()
#         if line:
#             parts = line.split(';')
#             if len(parts) == 2:
#                 filename, label = parts
#                 label = label.lower()
#                 test_filenames.append(filename)
#                 test_labels.append(label)
#             else:
#                 print(f"Skipping malformed line: {line}")

# test_data = pd.DataFrame({
#     'filename': test_filenames,
#     'label': test_labels
# })

# # Check for unknown chars in test set
# test_chars = set(''.join(test_data['label']))
# unknown_chars = test_chars - set(char_to_idx.keys())
# print(f"Unknown characters in test labels: {unknown_chars}")

# # Encode test labels
# max_label_length_test = max(len(lbl) for lbl in test_data['label']) + 2
# def encode_label_test(label, char_to_idx, max_length):
#     encoded = (
#         [char_to_idx['<SOS>']] +
#         [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in label] +
#         [char_to_idx['<EOS>']]
#     )
#     if len(encoded) > max_length:
#         encoded = encoded[:max_length]
#     else:
#         encoded += [char_to_idx['<PAD>']] * (max_length - len(encoded))
#     return encoded

# test_data['encoded_label'] = test_data['label'].apply(lambda x: encode_label_test(x, char_to_idx, max_label_length_test))
# test_data['label_length']  = test_data['label'].apply(len)

# test_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=(0.5, 0.5, 0.5),
#         std=(0.5, 0.5, 0.5)
#     )
# ])

# test_dataset = BalineseDataset(test_data, test_images_dir, transform=test_transform)
# test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)


# def inference(encoder, decoder, data_loader, device, char_to_idx, idx_to_char, max_seq_length, test_data):
#     encoder.eval()
#     decoder.eval()

#     eos_idx = char_to_idx['<EOS>']
#     results = []

#     with torch.no_grad():
#         for batch_idx, (images, labels, label_lengths) in enumerate(data_loader):
#             images = images.to(device, non_blocking=True)
#             labels = labels.to(device, non_blocking=True)

#             batch_size  = images.size(0)
#             encoder_out = encoder(images)  # [B, num_patches, encoder_dim]

#             # Init LSTM state
#             h, c = decoder.init_hidden_state(encoder_out)


#             # Start tokens
#             inputs = torch.full(
#                 (batch_size,),
#                 fill_value=char_to_idx['<SOS>'],
#                 dtype=torch.long,
#                 device=device
#             )

#             all_preds = []

#             for _ in range(max_seq_length):
#                 embeddings = decoder.embedding(inputs)  # [batch_size, embed_dim]
            
#                 # Attention
#                 attention_weighted_encoding, alpha = decoder.attention(encoder_out, h)
            
#                 # Gating
#                 gate = decoder.sigmoid(decoder.f_beta(h))
#                 attention_weighted_encoding = gate * attention_weighted_encoding
            
#                 h, c = decoder.lstm(
#                     torch.cat([embeddings, attention_weighted_encoding], dim=1),
#                     (h, c))
            
#                 # Predict next token
#                 preds = decoder.fc(decoder.dropout(h))  # [batch_size, vocab_size]
#                 _, preds_idx = preds.max(dim=1)
            
#                 all_preds.append(preds_idx.cpu().numpy())
#                 inputs = preds_idx


#             # Convert shape from [max_seq_length, batch_size] -> [batch_size, max_seq_length]
#             all_preds = np.array(all_preds).T

#             for i in range(batch_size):
#                 pred_indices = all_preds[i]

#                 # Stop at <EOS> if present
#                 if eos_idx in pred_indices:
#                     first_eos = np.where(pred_indices == eos_idx)[0][0]
#                     pred_indices = pred_indices[:first_eos]

#                 pred_chars = [idx_to_char.get(idx, '') for idx in pred_indices]
#                 pred_str   = ''.join(pred_chars)

#                 # Ground truth
#                 label_indices = labels[i].detach().cpu().numpy()
#                 # remove <SOS>
#                 label_indices = label_indices[1:]
#                 if eos_idx in label_indices:
#                     eos_pos = np.where(label_indices == eos_idx)[0][0]
#                     label_indices = label_indices[:eos_pos]
#                 else:
#                     label_indices = label_indices[label_indices != char_to_idx['<PAD>']]

#                 label_chars = [idx_to_char.get(idx, '') for idx in label_indices]
#                 label_str   = ''.join(label_chars)

#                 global_idx = batch_idx * batch_size + i
#                 image_filename = test_data.iloc[global_idx]['filename']

#                 results.append({
#                     'image_filename': image_filename,
#                     'predicted_caption': pred_str,
#                     'ground_truth_caption': label_str
#                 })

#     return results

# test_results = inference(
#     encoder=encoder,
#     decoder=decoder,
#     data_loader=test_loader,
#     device=device,
#     char_to_idx=char_to_idx,
#     idx_to_char=idx_to_char,
#     max_seq_length=max_label_length_test,
#     test_data=test_data
# )

# # Print first few
# for r in test_results[:5]:
#     print("Image:", r['image_filename'])
#     print("Predicted:", r['predicted_caption'])
#     print("Ground Truth:", r['ground_truth_caption'])
#     print()


# def calculate_global_cer(results):
#     total_ed   = 0
#     total_refs = 0
#     for r in results:
#         ref = r['ground_truth_caption']
#         hyp = r['predicted_caption']
#         dist = editdistance.eval(ref, hyp)
#         total_ed   += dist
#         total_refs += len(ref)
#     if total_refs == 0:
#         return 0.0
#     return total_ed / total_refs

# global_cer = calculate_global_cer(test_results)
# print(f"Global CER on test set: {global_cer:.4f}")


In [10]:
# n = 5

# results_with_cer = []
# for r in test_results:
#     ref = r['ground_truth_caption']
#     hyp = r['predicted_caption']
#     dist = editdistance.eval(ref, hyp)
#     length = len(ref)
#     cer = dist / length if length > 0 else 0
#     new_r = r.copy()
#     new_r['cer'] = cer
#     results_with_cer.append(new_r)

# # 2. Sort by CER in descending order
# results_with_cer.sort(key=lambda x: x['cer'], reverse=True)

# # 3. Print the top N highest CER
# print(f"\nTop {n} highest CER results:")
# for i, r in enumerate(results_with_cer[:n], start=1):
#     print(f"{i}) Image: {r['image_filename']}")
#     print(f"   CER: {r['cer']:.4f}")
#     print(f"   Predicted       : {r['predicted_caption']}")
#     print(f"   Ground Truth    : {r['ground_truth_caption']}")
#     print()