In [1]:
import math
import os
import time
from collections import Counter

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

In [2]:
CSV_PATH = "../Cleaned_data/transcription_to_hieroglyphs.csv"
BATCH_SIZE = 16
EPOCHS = 5
LR = 1e-4

D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 3
DROPOUT = 0.1
MAX_POSITIONS = 4096

CHECKPOINT_DIR = "../best_weights/transcription_to_hieroglyphs_transformer_checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

print("Device:", DEVICE)
print("Checkpoint dir:", CHECKPOINT_DIR)


Device: cuda
Checkpoint dir: ../best_weights/transcription_to_hieroglyphs_transformer_checkpoints


In [3]:
df = pd.read_csv(CSV_PATH)

df = df.dropna(subset=["transcription", "hieroglyphs"])
df["transcription"] = df["transcription"].astype(str)
df["hieroglyphs"] = df["hieroglyphs"].astype(str)

train_df, val_df = train_test_split(
    df,
    test_size=0.10,
    random_state=42,
    shuffle=True
)

print("Train size:", len(train_df))
print("Val size:", len(val_df))
train_df.head()

Train size: 31726
Val size: 3526


Unnamed: 0,transcription,hieroglyphs
27276,ḏi̯.t = n =k ꜥš 2 _ḫr 12,D37 X1 N35 V31 D36 N37 Z7 W22 Z7 Z1 Z1 Aa1 D21...
10081,"m bẖs,w ḥr ḥḏ,w","G17 D58 F32 O34 Z7 X4 D2 Z1 T3 I10 ""⸮"" G43 """" ..."
33867,"ꜥnḫ jt =j Rꜥw-Ḥr,w-ꜣḫ,tj-ḥꜥi̯-m-ꜣḫ,t M-rn≡f-m-...",S34 M17 X1 I9 A40 A40 < S34 G9 N27 N27 V28 D36...
11326,"ꜥḥꜥ.n Pꜣ-Rꜥ-Ḥr,w-ꜣḫ,tj.du ḥr ḏd n Ḏḥw,tj",P6 D36 N35 G41 G1 N5 G7 G5 G7 N27 X1 Z4 O1 O1 ...
3451,"snd,w ḥr mꜣj ḥnꜥ ꜥꜣm.w.pl rḫ =f sw r =f jw =f ...",S29 N35 D46 Z7 G54 A2 Z3A D2 Z1 U1 G1 M17 Z7 F...


In [4]:
tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=5000,
    special_tokens=["<pad>", "<sos>", "<eos>", "<unk>"]
)

tokenizer.train_from_iterator(
    train_df["transcription"].tolist() + train_df["hieroglyphs"].tolist(),
    trainer
)

PAD = tokenizer.token_to_id("<pad>")
SOS = tokenizer.token_to_id("<sos>")
EOS = tokenizer.token_to_id("<eos>")
UNK = tokenizer.token_to_id("<unk>")

VOCAB_SIZE = tokenizer.get_vocab_size()

print("Vocab size:", VOCAB_SIZE)
print("Special IDs:", {"PAD": PAD, "SOS": SOS, "EOS": EOS, "UNK": UNK})

# Save tokenizer once
tokenizer_path = os.path.join(CHECKPOINT_DIR, "tokenizer.json")
tokenizer.save(tokenizer_path)
print("Saved tokenizer to:", tokenizer_path)

Vocab size: 5000
Special IDs: {'PAD': 0, 'SOS': 1, 'EOS': 2, 'UNK': 3}
Saved tokenizer to: ../best_weights/transcription_to_hieroglyphs_transformer_checkpoints\tokenizer.json


In [5]:
class Seq2SeqDataset(Dataset):
    def __init__(self, df, tokenizer, sos_id, eos_id):
        self.src = df["transcription"].tolist()
        self.tgt = df["hieroglyphs"].tolist()
        self.tokenizer = tokenizer
        self.sos = sos_id
        self.eos = eos_id

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src_ids = [self.sos] + self.tokenizer.encode(self.src[idx]).ids + [self.eos]
        tgt_ids = [self.sos] + self.tokenizer.encode(self.tgt[idx]).ids + [self.eos]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)


def collate_fn(batch, pad_id):
    src, tgt = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=pad_id)
    tgt = nn.utils.rnn.pad_sequence(tgt, batch_first=True, padding_value=pad_id)
    return src, tgt

In [6]:
train_ds = Seq2SeqDataset(train_df, tokenizer, SOS, EOS)
val_ds   = Seq2SeqDataset(val_df, tokenizer, SOS, EOS)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD)
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD)
)

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class TransformerSeq2Seq(nn.Module):
    def __init__(self, vocab_size, pad_id):
        super().__init__()
        self.pad_id = pad_id

        self.src_emb = nn.Embedding(vocab_size, D_MODEL, padding_idx=pad_id)
        self.tgt_emb = nn.Embedding(vocab_size, D_MODEL, padding_idx=pad_id)

        self.pos_enc = PositionalEncoding(D_MODEL, DROPOUT, MAX_POSITIONS)

        self.transformer = nn.Transformer(
            d_model=D_MODEL,
            nhead=NHEAD,
            num_encoder_layers=NUM_LAYERS,
            num_decoder_layers=NUM_LAYERS,
            dropout=DROPOUT,
            batch_first=True
        )

        self.fc_out = nn.Linear(D_MODEL, vocab_size)

    def forward(self, src, tgt):
        src_pad_mask = (src == self.pad_id)
        tgt_pad_mask = (tgt == self.pad_id)

        src = self.pos_enc(self.src_emb(src) * math.sqrt(D_MODEL))
        tgt = self.pos_enc(self.tgt_emb(tgt) * math.sqrt(D_MODEL))

        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(src.device)

        out = self.transformer(
            src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=src_pad_mask
        )

        return self.fc_out(out)  # (B,T,V)

In [9]:
def batch_decode_ids(batch_ids, tokenizer, skip_special_tokens=True):
    # batch_ids: (B, T) int64 numpy or list of lists
    texts = []
    for ids in batch_ids:
        # tokenizers.Tokenizer supports decode(ids, skip_special_tokens=...)
        texts.append(tokenizer.decode(list(map(int, ids)), skip_special_tokens=skip_special_tokens))
    return texts


def token_f1_order_free(pred_text, gold_text):
    # Same logic as your compute_metrics: whitespace split + Counter overlap
    pred_tokens = pred_text.strip().split()
    gold_tokens = gold_text.strip().split()

    p_cnt = Counter(pred_tokens)
    g_cnt = Counter(gold_tokens)

    tp = sum((p_cnt & g_cnt).values())
    precision = tp / max(1, len(pred_tokens))
    recall = tp / max(1, len(gold_tokens))
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)

In [10]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0

    for src, tgt in loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        tgt_in = tgt[:, :-1]
        tgt_y  = tgt[:, 1:]

        optimizer.zero_grad()
        logits = model(src, tgt_in)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), tgt_y.reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / max(1, len(loader))


@torch.no_grad()
def eval_loss_acc_and_tokenf1(model, loader, criterion, pad_id):
    """
    - val_loss: teacher forcing loss
    - avg_token_acc: micro token accuracy over ALL non-pad tokens in dataset
    - token_f1: average of per-row order-free token F1 on decoded text (no CER)
    """
    model.eval()
    total_loss = 0.0

    total_correct = 0
    total_tokens = 0

    total_f1 = 0.0
    n_rows = 0

    for src, tgt in loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        tgt_in = tgt[:, :-1]
        tgt_y  = tgt[:, 1:]

        logits = model(src, tgt_in)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), tgt_y.reshape(-1))
        total_loss += loss.item()

        preds = logits.argmax(dim=-1)  # (B,T)

        # --- avg token accuracy (micro over dataset) ---
        mask = (tgt_y != pad_id)
        total_correct += (preds[mask] == tgt_y[mask]).sum().item()
        total_tokens += mask.sum().item()

        # --- token_f1 per row (decoded) ---
        # Prepare numpy arrays for decoding
        preds_np = preds.detach().cpu().numpy()
        labels_np = tgt_y.detach().cpu().numpy()

        # replace pad id in labels where pad (already pad_id), decoding with skip_special_tokens removes them
        pred_texts = batch_decode_ids(preds_np, tokenizer, skip_special_tokens=True)
        gold_texts = batch_decode_ids(labels_np, tokenizer, skip_special_tokens=True)

        for ptxt, gtxt in zip(pred_texts, gold_texts):
            total_f1 += token_f1_order_free(ptxt, gtxt)
            n_rows += 1

    val_loss = total_loss / max(1, len(loader))
    avg_token_acc = (total_correct / total_tokens) if total_tokens > 0 else 0.0
    token_f1 = (total_f1 / n_rows) if n_rows > 0 else 0.0

    return val_loss, avg_token_acc, token_f1

In [11]:
model = TransformerSeq2Seq(VOCAB_SIZE, PAD).to(DEVICE)

criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = optim.Adam(model.parameters(), lr=LR)

print("Model initialized")

Model initialized


In [12]:
def save_checkpoint(path, model, optimizer, epoch, train_loss, val_loss, avg_token_acc, token_f1):
    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss,
        "avg_token_acc": avg_token_acc,
        "token_f1": token_f1,
        "special_ids": {"PAD": PAD, "SOS": SOS, "EOS": EOS, "UNK": UNK},
        "config": {
            "D_MODEL": D_MODEL,
            "NHEAD": NHEAD,
            "NUM_LAYERS": NUM_LAYERS,
            "DROPOUT": DROPOUT,
            "VOCAB_SIZE": VOCAB_SIZE,
        }
    }
    torch.save(ckpt, path)


def load_checkpoint(path, model, optimizer=None, map_location=DEVICE):
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None and "optimizer_state" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt

In [None]:
def fmt(sec):
    sec = int(sec)
    h = sec // 3600
    m = (sec % 3600) // 60
    s = sec % 60
    if h > 0:
        return f"{h}h {m}m {s}s"
    if m > 0:
        return f"{m}m {s}s"
    return f"{s}s"


start_all = time.time()
epoch_times = []

best_val_loss = float("inf")
best_path = os.path.join(CHECKPOINT_DIR, "best.pt")

for epoch in range(1, EPOCHS + 1):
    start_epoch = time.time()

    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, avg_token_acc, token_f1 = eval_loss_acc_and_tokenf1(model, val_loader, criterion, PAD)

    epoch_sec = time.time() - start_epoch
    epoch_times.append(epoch_sec)

    elapsed_sec = time.time() - start_all
    avg_epoch_sec = sum(epoch_times) / len(epoch_times)
    eta_sec = (EPOCHS - epoch) * avg_epoch_sec

    ckpt_path = os.path.join(CHECKPOINT_DIR, f"epoch_{epoch:03d}.pt")
    save_checkpoint(ckpt_path, model, optimizer, epoch, train_loss, val_loss, avg_token_acc, token_f1)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(best_path, model, optimizer, epoch, train_loss, val_loss, avg_token_acc, token_f1)

    print(
        f"Epoch {epoch}/{EPOCHS} | "
        f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
        f"avg_token_acc={avg_token_acc:.4f} | token_f1={token_f1:.4f} | "
        f"epoch_time={fmt(epoch_sec)} | elapsed={fmt(elapsed_sec)} | ETA={fmt(eta_sec)} | "
        f"saved={ckpt_path}"
    )

print("Best checkpoint:", best_path, "best_val_loss:", best_val_loss)

  output = torch._nested_tensor_from_mask(


Epoch 1/5 | train_loss=3.9502 | val_loss=3.3469 | avg_token_acc=0.3211 | token_f1=0.3137 | epoch_time=2m 20s | elapsed=2m 20s | ETA=9m 20s | saved=../best_weights/transcription_to_hieroglyphs_transformer_checkpoints\epoch_001.pt
Epoch 2/5 | train_loss=3.1929 | val_loss=2.9559 | avg_token_acc=0.3811 | token_f1=0.3990 | epoch_time=1m 55s | elapsed=4m 16s | ETA=6m 23s | saved=../best_weights/transcription_to_hieroglyphs_transformer_checkpoints\epoch_002.pt
Epoch 3/5 | train_loss=2.8939 | val_loss=2.7323 | avg_token_acc=0.4140 | token_f1=0.5026 | epoch_time=2m 4s | elapsed=6m 21s | ETA=4m 13s | saved=../best_weights/transcription_to_hieroglyphs_transformer_checkpoints\epoch_003.pt
