In [None]:
import os
import pandas as pd
import numpy as np
import editdistance
import time 


from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm  

from sklearn.model_selection import train_test_split

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Global configuration: RNG seed, compute device and training-data locations.
SEED = 42  # same value as the random_state=42 passed to custom_split

# Seed the RNGs this notebook draws from (numpy, and torch on CPU and all CUDA
# devices) so weight init, DataLoader shuffling and dropout are reproducible
# across Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_dir = os.getcwd()

# Training split: ground-truth file ("<filename>;<label>" per line) and image folder.
ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_train.txt')
images_dir        = os.path.join(base_dir, 'balinese_word_train')

# Parse the training ground truth: each non-empty line is "<filename>;<label>".
# Labels are lower-cased; any line that does not split into exactly two
# ';'-separated fields is reported and skipped.
filenames = []
labels    = []

with open(ground_truth_path, 'r', encoding='utf-8') as gt_file:
    for raw_line in gt_file:
        stripped = raw_line.strip()
        if not stripped:
            continue
        fields = stripped.split(';')
        if len(fields) != 2:
            print(f"Skipping malformed line: {stripped}")
            continue
        filenames.append(fields[0])
        labels.append(fields[1].lower())

data = pd.DataFrame({'filename': filenames, 'label': labels})

# Frequency of every distinct word (used below to pick out rare labels).
label_counts = data['label'].value_counts()

# Character vocabulary: every distinct character seen in the labels, sorted.
all_text = ''.join(data['label'])
unique_chars = sorted(set(all_text))

# Real characters take indices 1..len(unique_chars); index 0 is reserved for
# padding, and the remaining special tokens are numbered after the characters.
char_to_idx = {char: position for position, char in enumerate(unique_chars, start=1)}
char_to_idx['<PAD>'] = 0
for special_token in ('<UNK>', '<SOS>', '<EOS>'):
    char_to_idx[special_token] = len(char_to_idx)

# Inverse lookup used to turn predicted indices back into text.
idx_to_char = {index: token for token, index in char_to_idx.items()}

vocab_size = len(char_to_idx)
print(f"Vocabulary size: {vocab_size}")

def encode_label(label, char_to_idx, max_length):
    """
    Encode a transcription string as a fixed-length list of token indices.

    The sequence is ``<SOS> chars... <EOS>``; characters missing from
    ``char_to_idx`` map to ``<UNK>``. The result is right-padded with ``<PAD>``
    up to ``max_length``, or truncated to ``max_length`` entries (which drops
    the trailing ``<EOS>`` when the label is too long).
    """
    unk = char_to_idx['<UNK>']
    tokens = [char_to_idx['<SOS>']]
    tokens.extend(char_to_idx.get(ch, unk) for ch in label)
    tokens.append(char_to_idx['<EOS>'])

    # Truncation is a no-op for short sequences; padding is a no-op for long ones.
    tokens = tokens[:max_length]
    tokens.extend([char_to_idx['<PAD>']] * (max_length - len(tokens)))
    return tokens

# Fixed length of every encoded label: longest label plus <SOS> and <EOS>.
max_label_length = max(map(len, data['label'])) + 2

# Padded index sequence and raw character count for every training row.
data['encoded_label'] = data['label'].apply(
    lambda text: encode_label(text, char_to_idx, max_label_length)
)
data['label_length'] = data['label'].map(len)

# Words occurring fewer than 3 times; custom_split puts all of them in training.
rare_labels = label_counts[label_counts.lt(3)].index

def custom_split(df, rare_label_list, test_size=0.1, random_state=42):
    """
    Split ``df`` into train/validation frames while keeping rare labels in training.

    Rows whose ``label`` is in ``rare_label_list`` all go to the training set;
    the remaining rows are split with ``train_test_split(test_size, random_state)``.
    The combined training frame is shuffled with ``random_state``, and both
    returned frames get a fresh 0..n-1 index.
    """
    is_rare = df['label'].isin(rare_label_list)
    common_rows = df[~is_rare]
    rare_rows = df[is_rare]

    common_train, common_val = train_test_split(
        common_rows, test_size=test_size, random_state=random_state
    )

    combined = pd.concat([common_train, rare_rows], ignore_index=True)
    train_df = combined.sample(frac=1, random_state=random_state).reset_index(drop=True)
    val_df = common_val.reset_index(drop=True)
    return train_df, val_df

# Hold out 10% of the non-rare rows for validation; rare words stay in training.
train_data, val_data = custom_split(
    data,
    rare_labels,
    test_size=0.1,
    random_state=42,
)

n_train, n_val = len(train_data), len(val_data)
print(f"Training size: {n_train}; Validation size: {n_val}")
Vocabulary size: 39
Training size: 13972; Validation size: 1050
class BalineseDataset(Dataset):
    """
    Map-style dataset pairing word images with their encoded transliterations.

    ``df`` must provide ``filename``, ``encoded_label`` and ``label_length``
    columns. Images are read from ``images_dir``, converted to RGB and passed
    through ``transform`` when one is given.
    """

    def __init__(self, df, images_dir, transform=None):
        # Re-index to 0..n-1 so __getitem__ can address rows with .loc[idx].
        self.data = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, encoded_label, label_length)`` for row ``idx``."""
        rows = self.data.loc
        path = os.path.join(self.images_dir, rows[idx, 'filename'])
        target = rows[idx, 'encoded_label']
        length = rows[idx, 'label_length']

        picture = Image.open(path).convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        return (
            picture,
            torch.tensor(target, dtype=torch.long),
            torch.tensor(length, dtype=torch.long),
        )
# Shared preprocessing: resize to 224x224, convert to a tensor, then normalise
# each RGB channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

train_dataset = BalineseDataset(train_data, images_dir, transform=transform)
val_dataset   = BalineseDataset(val_data, images_dir, transform=transform)

batch_size = 32

# Identical loader settings; only training data is shuffled.
loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader   = DataLoader(val_dataset, shuffle=False, **loader_options)

In [None]:
class ImageCaptioningTrainer:
    """
    Train/validate an encoder-decoder OCR model and log per-epoch metrics.

    Each epoch runs a teacher-forced training pass and a validation pass,
    reporting the mean batch loss and a global character error rate (CER).
    At the end of :meth:`fit` the loss/CER history is written to
    ``csv_filename``.

    NOTE(review): the decoder is assumed to be called as
    ``decoder(encoder_out, captions, caption_lengths)`` and to return
    ``(scores, captions, decode_lengths, alphas, sort_ind)``, with ``scores``
    aligned to ``captions[:, 1:]``. It must also expose ``decoder.fc`` as the
    vocabulary projection. Confirm against the decoder's definition.
    """

    def __init__(self, encoder, decoder,
                 criterion, encoder_optimizer, decoder_optimizer,
                 train_loader, val_loader, device,
                 char_to_idx, idx_to_char, max_label_length,
                 model_name, csv_filename="training_results.csv"):
        self.encoder           = encoder.to(device)
        self.decoder           = decoder.to(device)
        self.criterion         = criterion
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.train_loader      = train_loader
        self.val_loader        = val_loader
        self.device            = device
        self.char_to_idx       = char_to_idx
        self.idx_to_char       = idx_to_char
        self.max_label_length  = max_label_length
        self.model_name        = model_name
        self.csv_filename      = csv_filename

        # Per-epoch history, appended to by fit().
        self.train_losses = []
        self.val_losses   = []
        self.train_cers   = []
        self.val_cers     = []

        # Highest-CER validation samples from the most recent validate_one_epoch().
        self.worst_val_samples = []

    def fit(self, num_epochs):
        """
        Train for ``num_epochs`` epochs, then write the loss/CER history to CSV.

        History accumulates across repeated ``fit`` calls on the same trainer.
        The CSV gets one row per metric (training/validation loss and CER) and
        one ``epochN`` column per recorded epoch. Rows already in the file for
        the same ``model_name`` are replaced; rows for other models are kept.
        """
        start_time = time.time()

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_cer = self.train_one_epoch()
            val_loss,   val_cer   = self.validate_one_epoch(top_n=5)

            print(f"[{epoch+1}/{num_epochs}] "
                  f"Train Loss: {train_loss:.4f}, Train CER: {train_cer:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val CER: {val_cer:.4f}")

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_cers.append(train_cer)
            self.val_cers.append(val_cer)

        total_time = time.time() - start_time
        hours = int(total_time // 3600)
        minutes = int((total_time % 3600) // 60)
        print(f"\nTraining completed in {hours}h {minutes}m.")

        # Columns cover the full recorded history, not only this call's epochs.
        epoch_cols = [f"epoch{i+1}" for i in range(len(self.train_losses))]
        new_rows = pd.DataFrame([
            [self.model_name, "training loss"] + self.train_losses,
            [self.model_name, "validation loss"] + self.val_losses,
            [self.model_name, "training cer"] + self.train_cers,
            [self.model_name, "validation cer"] + self.val_cers
        ], columns=["model_name", "mode"] + epoch_cols)

        if os.path.exists(self.csv_filename):
            df_existing = pd.read_csv(self.csv_filename)
            # Drop stale rows for this model so a re-run replaces its curves.
            df_existing = df_existing[df_existing["model_name"] != self.model_name]
            df_updated = pd.concat([df_existing, new_rows], ignore_index=True)
        else:
            df_updated = new_rows

        df_updated.to_csv(self.csv_filename, index=False)
        print(f"\nResults have been written to: {self.csv_filename}")

    def _tokens_to_text(self, indices):
        """
        Convert a 1-D sequence of token indices into a plain string.

        Decoding stops at the first ``<EOS>`` (as ``inference`` does before the
        test CER is computed). ``<PAD>`` and ``<SOS>`` are dropped, and indices
        missing from ``idx_to_char`` map to ''. Without this, the multi-character
        token names (e.g. '<EOS>') would be counted as five characters of text
        in the CER.
        """
        eos = self.char_to_idx['<EOS>']
        skipped = {self.char_to_idx['<PAD>'], self.char_to_idx['<SOS>']}
        chars = []
        for idx in indices:
            idx = int(idx)
            if idx == eos:
                break
            if idx in skipped:
                continue
            chars.append(self.idx_to_char.get(idx, ''))
        return ''.join(chars)

    def _forward_batch(self, images, labels):
        """
        Run a teacher-forced forward pass for one batch.

        Returns ``(outputs_flat, targets, loss)``. ``outputs_flat`` is
        ``[batch * steps, vocab]``. ``targets`` is the decoder-returned captions
        without the leading ``<SOS>``, shape ``[batch, steps]``. ``loss`` is the
        criterion applied to the flattened scores and targets.
        """
        encoder_out = self.encoder(images)
        # Every label is padded to max_label_length, so all samples share one length.
        caption_lengths = torch.tensor(
            [self.max_label_length] * labels.size(0)
        ).unsqueeze(1).to(self.device)

        outputs, encoded_captions, _decode_lengths, _alphas, _sort_ind = self.decoder(
            encoder_out, labels, caption_lengths
        )
        # Compare against the captions the decoder returned (kept in the same
        # order as `outputs`), minus the <SOS> input token.
        targets = encoded_captions[:, 1:]

        outputs_flat = outputs.view(-1, self.decoder.fc.out_features)
        targets_flat = targets.contiguous().view(-1)
        loss = self.criterion(outputs_flat, targets_flat)
        return outputs_flat, targets, loss

    def _batch_edit_stats(self, outputs_flat, targets):
        """
        Return ``(pred_str, target_str, edit_distance)`` for every sample in a batch.

        Predictions are the arg-max token at each step. Only positions whose
        target is not ``<PAD>`` are compared, so each prediction stays aligned
        with its teacher-forced reference. Both sides are then decoded with
        :meth:`_tokens_to_text`.
        """
        batch_size = targets.size(0)
        _, preds_flat = torch.max(outputs_flat.detach(), dim=1)
        preds_seq = preds_flat.view(batch_size, -1).cpu().numpy()
        targets_np = targets.detach().cpu().numpy()
        pad_idx = self.char_to_idx['<PAD>']

        stats = []
        for pred_row, target_row in zip(preds_seq, targets_np):
            mask = target_row != pad_idx
            pred_str = self._tokens_to_text(pred_row[mask])
            target_str = self._tokens_to_text(target_row[mask])
            stats.append((pred_str, target_str, editdistance.eval(pred_str, target_str)))
        return stats

    def train_one_epoch(self):
        """Run one teacher-forced training epoch; return ``(mean batch loss, global CER)``."""
        self.encoder.train()
        self.decoder.train()
        running_loss        = 0.0
        total_edit_distance = 0
        total_ref_length    = 0

        for images, labels, _label_lengths in self.train_loader:
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            outputs_flat, targets, loss = self._forward_batch(images, labels)
            loss.backward()

            self.decoder_optimizer.step()
            self.encoder_optimizer.step()

            running_loss += loss.item()

            # Global CER: total edits over total reference characters.
            for _pred_str, target_str, edit_dist in self._batch_edit_stats(outputs_flat, targets):
                total_edit_distance += edit_dist
                total_ref_length    += len(target_str)

        avg_loss = running_loss / len(self.train_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0
        return avg_loss, avg_cer

    def validate_one_epoch(self, top_n=5):
        """
        Evaluate on the validation loader without gradients.

        Returns ``(mean batch loss, global CER)``. The ``top_n`` samples with the
        highest per-sample CER are stored in ``self.worst_val_samples`` as dicts
        with ``pred``, ``gt`` and ``cer`` keys.
        """
        self.encoder.eval()
        self.decoder.eval()
        running_loss        = 0.0
        total_edit_distance = 0
        total_ref_length    = 0
        sample_cer_info     = []

        with torch.no_grad():
            for images, labels, _label_lengths in self.val_loader:
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                outputs_flat, targets, loss = self._forward_batch(images, labels)
                running_loss += loss.item()

                for pred_str, target_str, edit_dist in self._batch_edit_stats(outputs_flat, targets):
                    ref_len = len(target_str)
                    total_edit_distance += edit_dist
                    total_ref_length    += ref_len
                    sample_cer_info.append({
                        "pred": pred_str,
                        "gt": target_str,
                        "cer": edit_dist / ref_len if ref_len > 0 else 0,
                    })

        avg_loss = running_loss / len(self.val_loader)
        avg_cer  = total_edit_distance / total_ref_length if total_ref_length > 0 else 0.0

        sample_cer_info.sort(key=lambda sample: sample["cer"], reverse=True)
        self.worst_val_samples = sample_cer_info[:top_n]
        return avg_loss, avg_cer

In [None]:
# Test split: same "<filename>;<label>" format as the training ground truth.
test_ground_truth_path = os.path.join(base_dir, 'balinese_transliteration_test.txt')
test_images_dir        = os.path.join(base_dir, 'balinese_word_test')

test_filenames = []
test_labels    = []

with open(test_ground_truth_path, 'r', encoding='utf-8') as gt_file:
    for stripped in map(str.strip, gt_file):
        if not stripped:
            continue
        fields = stripped.split(';')
        if len(fields) == 2:
            test_filenames.append(fields[0])
            test_labels.append(fields[1].lower())
        else:
            print(f"Skipping malformed line: {stripped}")

test_data = pd.DataFrame({'filename': test_filenames, 'label': test_labels})

# Characters absent from the training vocabulary would be encoded as <UNK>.
test_chars = set(''.join(test_data['label']))
unknown_chars = test_chars - set(char_to_idx)
print(f"Unknown characters in test labels: {unknown_chars}")

# Test labels are padded to the longest *test* label, plus <SOS> and <EOS>.
max_label_length_test = max(len(lbl) for lbl in test_data['label']) + 2
def encode_label_test(label, char_to_idx, max_length):
    """
    Encode a test-set label exactly like a training label.

    Thin alias of ``encode_label``: same ``<SOS> chars... <EOS>`` layout,
    ``<UNK>`` for unknown characters, and truncation or ``<PAD>`` padding to
    ``max_length``. It was previously a line-for-line copy; delegating keeps
    the train and test encodings from drifting apart.
    """
    return encode_label(label, char_to_idx, max_length)

# Padded index sequence and raw character count for every test row.
test_data['encoded_label'] = test_data['label'].apply(
    lambda text: encode_label_test(text, char_to_idx, max_label_length_test)
)
test_data['label_length'] = test_data['label'].map(len)

# Evaluation preprocessing: the same resize/normalise pipeline as training.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

test_dataset = BalineseDataset(test_data, test_images_dir, transform=test_transform)
# Unshuffled, so batches follow test_data row order (inference relies on this).
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)
Unknown characters in test labels: set()
def inference(encoder, decoder, data_loader, device, char_to_idx, idx_to_char, max_seq_length, test_data):
    """
    Greedy-decode every batch and pair each prediction with its ground truth.

    Args:
        encoder, decoder: trained modules. The decoder's sub-modules
            (``init_hidden_state``, ``embedding``, ``attention``, ``f_beta``,
            ``sigmoid``, ``lstm1``, ``lstm2``, ``dropout``, ``fc``) are driven
            step by step here instead of through ``decoder.forward``.
        data_loader: yields ``(images, labels, label_lengths)``. It must visit
            ``test_data`` rows in order (``shuffle=False``), because filenames
            are looked up by position.
        device: device the models run on.
        char_to_idx, idx_to_char: vocabulary mappings including ``<SOS>``,
            ``<EOS>`` and ``<PAD>``.
        max_seq_length: number of greedy decoding steps per sample.
        test_data: DataFrame whose ``filename`` column is aligned with the loader.

    Returns:
        List of dicts with ``image_filename``, ``predicted_caption`` and
        ``ground_truth_caption``. Both captions are cut at the first ``<EOS>``.
    """
    encoder.eval()
    decoder.eval()

    eos_idx = char_to_idx['<EOS>']
    pad_idx = char_to_idx['<PAD>']
    results = []
    # Row in test_data of the current batch's first sample. Kept as a running
    # count because the final batch can be smaller than the others, so
    # batch_idx * <this batch's size> would point at the wrong rows.
    sample_offset = 0

    with torch.no_grad():
        for images, labels, _label_lengths in data_loader:
            images = images.to(device, non_blocking=True)

            batch_size  = images.size(0)
            encoder_out = encoder(images)  # presumably [B, num_patches, encoder_dim]

            # Initial LSTM states derived from the encoder output.
            h1, c1, h2, c2 = decoder.init_hidden_state(encoder_out)

            # Every sequence starts from <SOS>.
            inputs = torch.full(
                (batch_size,),
                fill_value=char_to_idx['<SOS>'],
                dtype=torch.long,
                device=device
            )

            step_preds = []  # one [batch_size] array of token ids per decoding step

            for _ in range(max_seq_length):
                embeddings = decoder.embedding(inputs)

                # Attention context over the encoder output, gated by f_beta(h1).
                attention_weighted_encoding, _ = decoder.attention(encoder_out, h1)
                gate = decoder.sigmoid(decoder.f_beta(h1))
                attention_weighted_encoding = gate * attention_weighted_encoding

                h1, c1 = decoder.lstm1(
                    torch.cat([embeddings, attention_weighted_encoding], dim=1),
                    (h1, c1)
                )
                h2, c2 = decoder.lstm2(h1, (h2, c2))

                preds = decoder.fc(decoder.dropout(h2))  # [batch_size, vocab_size]
                _, preds_idx = preds.max(dim=1)

                # Greedy decoding: feed the arg-max token back in as the next input.
                step_preds.append(preds_idx.cpu().numpy())
                inputs = preds_idx

            # [steps, batch] -> [batch, steps]; an empty [batch, 0] grid if no steps ran.
            if step_preds:
                all_preds = np.array(step_preds).T
            else:
                all_preds = np.zeros((batch_size, 0), dtype=np.int64)

            label_rows = labels.cpu().numpy()

            for i in range(batch_size):
                pred_indices = all_preds[i]
                # Stop at the first predicted <EOS>, if any.
                if eos_idx in pred_indices:
                    first_eos = np.where(pred_indices == eos_idx)[0][0]
                    pred_indices = pred_indices[:first_eos]
                pred_str = ''.join(idx_to_char.get(idx, '') for idx in pred_indices)

                # Ground truth: drop <SOS>, then cut at <EOS> (or strip <PAD> if no <EOS>).
                label_indices = label_rows[i][1:]
                if eos_idx in label_indices:
                    eos_pos = np.where(label_indices == eos_idx)[0][0]
                    label_indices = label_indices[:eos_pos]
                else:
                    label_indices = label_indices[label_indices != pad_idx]
                label_str = ''.join(idx_to_char.get(idx, '') for idx in label_indices)

                results.append({
                    'image_filename': test_data.iloc[sample_offset + i]['filename'],
                    'predicted_caption': pred_str,
                    'ground_truth_caption': label_str
                })

            sample_offset += batch_size

    return results
def calculate_global_cer(results):
    """
    Corpus-level character error rate over a list of inference results.

    Sums the edit distances between each ``ground_truth_caption`` and its
    ``predicted_caption``, then divides by the total ground-truth length, so
    longer words carry more weight. Returns 0.0 when there is no reference text.
    """
    pairs = [(r['ground_truth_caption'], r['predicted_caption']) for r in results]
    total_ed = sum(editdistance.eval(ref, hyp) for ref, hyp in pairs)
    total_refs = sum(len(ref) for ref, _ in pairs)
    return total_ed / total_refs if total_refs else 0.0
def print_top_worst_samples(results, n=5):
    """
    Print the ``n`` inference results with the highest per-sample CER.

    Per-sample CER is edit distance divided by ground-truth length (0 for an
    empty reference). The input dicts are not modified: copies with an added
    ``cer`` key are ranked, and ties keep their input order.
    """
    def _with_cer(record):
        reference = record['ground_truth_caption']
        distance = editdistance.eval(reference, record['predicted_caption'])
        scored = record.copy()
        scored['cer'] = distance / len(reference) if len(reference) > 0 else 0
        return scored

    ranked = sorted(map(_with_cer, results), key=lambda s: s['cer'], reverse=True)

    print(f"\n=== Top {n} Worst Samples by CER ===")
    for rank, sample in enumerate(ranked[:n], start=1):
        print(f"{rank}) Image: {sample['image_filename']}")
        print(f"   CER: {sample['cer']:.4f}")
        print(f"   Predicted       : {sample['predicted_caption']}")
        print(f"   Ground Truth    : {sample['ground_truth_caption']}")
        print()
# Make sure both result CSVs exist with a header row, so later readers
# (ImageCaptioningTrainer.fit and log_test_cer) never hit a missing or empty file.
training_csv = "training_results.csv"
csv_file = "test_cer_results.csv"

for _csv_path, _csv_header in (
    (training_csv, ["model_name", "mode", "epoch1", "epoch2"]),
    (csv_file, ["model_name", "test_cer"]),
):
    if not os.path.exists(_csv_path) or os.path.getsize(_csv_path) == 0:
        pd.DataFrame(columns=_csv_header).to_csv(_csv_path, index=False)

def log_test_cer(model_name, cer_value):
    """
    Record ``cer_value`` as the test CER of ``model_name`` in ``csv_file``.

    If the model already has a row, its ``test_cer`` is overwritten; otherwise
    a new row is appended. The whole CSV is rewritten and a confirmation line
    is printed.
    """
    table = pd.read_csv(csv_file)

    if model_name not in table['model_name'].values:
        addition = pd.DataFrame({"model_name": [model_name], "test_cer": [cer_value]})
        table = pd.concat([table, addition], ignore_index=True)
    else:
        table.loc[table['model_name'] == model_name, 'test_cer'] = cer_value

    table.to_csv(csv_file, index=False)
    print(f"Logged {model_name}: {cer_value:.4f}")