# Build source and configuration files


## 0. Remove existing directories and content


In [1]:
import shutil
from pathlib import Path

conf_root = Path("conf")
src_root = Path("src")

# Start from a clean slate: wipe any previously generated config/source trees.
for stale_dir in (conf_root, src_root):
    shutil.rmtree(stale_dir, ignore_errors=True)

# Recreate the two roots first, then the sub-directories that the later
# %%writefile cells write into.
conf_root.mkdir(exist_ok=True)
src_root.mkdir(exist_ok=True)
for sub_dir in (
    conf_root / "tuning",
    src_root / "tokenizer",
    src_root / "data",
    src_root / "models",
    src_root / "utils",
):
    sub_dir.mkdir(exist_ok=True)

## 1. Configuration files


### 1.1 Default train config (seq2seq)


In [2]:
%%writefile conf/train.yaml
# Default training recipe.
# `xlit` names the model architecture; it matches a key of MODEL_REGISTRY in
# src/models/__init__.py ("attention" -> AttnSeq2Seq).
# NOTE(review): presumably the training script passes `xlit_conf` to the model
# as `model_conf` — confirm.
xlit: attention
xlit_conf:
    # Data Setting
    langx: ben                    # source-language code
    langy: mni                    # target-language code
    token_type: char              # only "char" is registered in TOKENIZER_REGISTRY
    db_file: db/transcribed.txt   # pair file read by load_pairs (tab-separated by default)
    max_len: 100                  # sequence length for encoding / default decode length
    val_ratio: 0.25               # fraction of pairs held out for validation
    
    # Model Setting
    idim: 64                      # rows of the source embedding; NOTE(review): must be >= source vocab size
    odim: 48                      # rows of the target embedding and number of output logits
    hidden_dim: 128
    embed_dim: 128
    elayers: 2                    # encoder LSTM layers
    dlayers: 2                    # decoder LSTM layers
    dropout: 0.25
    # teacher_forcing_ratio is not set: AttnSeq2Seq defaults it to 0.5.

    # Optimizer Setting
    optim: adam
    optim_conf:
        lr: 1.0e-03
        eps: 1.0e-06
        weight_decay: 0.0
    
    # Training Setting
    max_epoch: 50
    batch_size: 32
    keep_nbest_models: 5          # presumably the n_best passed to save_models — confirm
    seed: 248

Writing conf/train.yaml


### 1.2 LSTM


In [3]:
%%writefile conf/tuning/lstm.yaml
# Plain LSTM encoder/decoder without attention ("lstm" -> LSTMSeq2Seq).
# NOTE(review): presumably `xlit_conf` is passed to the model as `model_conf` — confirm.
xlit: lstm
xlit_conf:
    langx: ben                    # source-language code
    langy: mni                    # target-language code
    token_type: char
    db_file: db/transcribed.txt
    max_len: 100
    val_ratio: 0.25

    # Model
    idim: 64                      # LSTMSeq2Seq embeds BOTH source and target ids with this
                                  # table, so every target id must also be < idim
    odim: 48                      # size of the output projection (logits)
    hidden_dim: 128
    embed_dim: 128
    nlayers: 2                    # layers of both the encoder and the decoder LSTM
    dropout: 0.25
    # teacher_forcing_ratio is not set: LSTMSeq2Seq defaults it to 0.5.

    # Optimizer
    optim: adam
    optim_conf:
        lr: 1.0e-3
        eps: 1.0e-6
        weight_decay: 0.0

    # Training
    max_epoch: 50
    batch_size: 32
    keep_nbest_models: 5
    seed: 248


Writing conf/tuning/lstm.yaml


### 1.3 CNN + Positional Encoding + GLU (gated convolution) with attention


In [4]:
%%writefile conf/tuning/cnn.yaml
# conf/tuning/cnn.yaml — convolutional encoder/decoder with attention
# ("cnn_attn" -> CNNSeq2SeqAttn).
xlit: cnn_attn
xlit_conf:
    langx: ben
    langy: mni
    token_type: char
    db_file: db/transcribed.txt
    max_len: 100
    val_ratio: 0.25

    # Model
    idim: 64
    odim: 47                      # NOTE(review): the other configs use 48 — confirm this matches the target vocab size
    embed_dim: 128
    hidden_dim: 256
    kernel_size: 3                # odd sizes keep the sequence length (padding = kernel_size // 2)
    dropout: 0.25
    teacher_forcing_ratio: 0.5    # per-sample probability of feeding the gold token during training

    # Optimizer
    optim: adam
    optim_conf:
        lr: 1.0e-3
        eps: 1.0e-6
        weight_decay: 0.0

    # Training
    max_epoch: 50
    batch_size: 32
    keep_nbest_models: 5
    seed: 248



Writing conf/tuning/cnn.yaml


### 1.4 Transformer


In [5]:
%%writefile conf/tuning/transformer.yaml
# nn.Transformer encoder/decoder ("transformer" -> TransformerSeq2Seq).
xlit: transformer
xlit_conf:
    langx: ben
    langy: mni
    token_type: char
    db_file: db/transcribed.txt
    max_len: 100
    val_ratio: 0.25

    # Model
    idim: 64
    odim: 48
    embed_dim: 256                # d_model; must be divisible by num_heads
    num_heads: 4 
    num_encoder_layers: 4
    num_decoder_layers: 4 
    dim_feedforward: 512 
    dropout: 0.1
    # teacher_forcing_ratio is not set: TransformerSeq2Seq defaults it to 0.5.

    # Optimizer
    optim: adam
    optim_conf:
        lr: 1.0e-4
        eps: 1.0e-6
        weight_decay: 0.0

    # Training
    max_epoch: 20
    batch_size: 16
    keep_nbest_models: 5
    seed: 248


Writing conf/tuning/transformer.yaml


## 2. Tokenizer


In [6]:
%%writefile src/tokenizer/__init__.py
from pathlib import Path
from typing import Iterable

from ..data.utils import load_pairs

from .char_tokenizer import CharTokenizer
# from .phn_tokenizer import PhonemeTokenizer

# Token type (as parsed from a tokens-file name by _infer_token_type) -> class.
# A phoneme tokenizer (PhonemeTokenizer, "phn") is not registered yet.
TOKENIZER_REGISTRY = {"char": CharTokenizer}

def _infer_token_type(tokens_file: Path) -> str:
    parts = tokens_file.stem.split("_")
    if len(parts) < 2:
        raise ValueError(f"Cannot infer token type from file name: {tokens_file.name}")
    return parts[-2]

def build_tokenizer(xs: Iterable[str], tokens_file: Path) -> None:
    """Fit a tokenizer on ``xs`` and write its vocabulary to ``tokens_file``.

    The tokenizer class is looked up in ``TOKENIZER_REGISTRY`` using the token
    type parsed from the file name by ``_infer_token_type``.

    Raises:
        ValueError: if the token type cannot be inferred or is not registered.
    """
    token_type = _infer_token_type(tokens_file)
    try:
        tokenizer_cls = TOKENIZER_REGISTRY[token_type]
    except KeyError:
        raise ValueError(f"Unsupported token type: {token_type}") from None

    fitted = tokenizer_cls(xs)
    fitted.to_token_file(tokens_file)

def load_tokenizer(tokens_file: Path):
    """Load a tokenizer previously written to ``tokens_file``.

    Raises:
        FileNotFoundError: if ``tokens_file`` does not exist (checked first).
        ValueError: if the token type in the file name cannot be inferred or
            is not registered in ``TOKENIZER_REGISTRY``.
    """
    if tokens_file.exists():
        token_type = _infer_token_type(tokens_file)
        try:
            tokenizer_cls = TOKENIZER_REGISTRY[token_type]
        except KeyError:
            raise ValueError(f"Unsupported token type: {token_type}") from None
        return tokenizer_cls.from_token_file(tokens_file)

    raise FileNotFoundError(f"Tokenizer file not found: {tokens_file.as_posix()}")

def prepare_tokenizers(x_tokens_file, y_tokens_file, db_file):
    """Load (or build and cache) the source/target tokenizers and the data.

    The pairs in ``db_file`` are always loaded and returned. Previously they
    were only read when the tokenizer files were missing, so a cache hit
    returned empty ``xs``/``ys``. If either tokenizer file is missing, both
    tokenizers are rebuilt from the loaded pairs and written to disk first.

    Args:
        x_tokens_file: Path of the source tokenizer file (token type is parsed
            from its name).
        y_tokens_file: Path of the target tokenizer file.
        db_file: Pair file read by ``load_pairs``.

    Returns:
        ``(x_tokenizer, y_tokenizer, xs, ys)``.
    """
    xs, ys = load_pairs(db_file)
    try:
        x_tokenizer = load_tokenizer(x_tokens_file)
        y_tokenizer = load_tokenizer(y_tokens_file)
    except FileNotFoundError:
        # Rebuild both so the pair of vocabularies always comes from the same data.
        build_tokenizer(xs, x_tokens_file)
        build_tokenizer(ys, y_tokens_file)
        x_tokenizer = load_tokenizer(x_tokens_file)
        y_tokenizer = load_tokenizer(y_tokens_file)
    return x_tokenizer, y_tokenizer, xs, ys

Writing src/tokenizer/__init__.py


### 2.1 Base Tokenizer


In [7]:
%%writefile src/tokenizer/base.py
from pathlib import Path
from typing import List, Dict
from abc import ABC, abstractmethod

class Tokenizer(ABC):
    def __init__(self, vocab: List[str]) -> None:
        self.idx2tok = self._create_idx2tok(tokens=vocab)
        self.tok2idx = {tok: idx for idx, tok in self.idx2tok.items()}
    
    @abstractmethod
    def _create_idx2tok(self, tokens: List[str]) -> Dict[int, str]:
        pass
    
    def __len__(self) -> int:
        return len(self.tok2idx)
    
    def encode(self, text: str, max_len: int = 100) -> List[int]: 
        tokens = (
            [self.tok2idx.get("<sos>")]
            + [self.tok2idx.get(ch) for ch in text if ch in self.tok2idx]
            + [self.tok2idx.get("<eos>")]
        )
        tokens += [self.tok2idx.get("<pad>")] * (max_len - len(tokens))
        return tokens[:max_len]
    
    def decode(self, indices: list[int]) -> str:
        return "".join(
            self.idx2tok[idx]
            for idx in indices
            if idx in self.idx2tok and idx not in {
                self.tok2idx["<pad>"],
                self.tok2idx["<sos>"],
                self.tok2idx["<eos>"]
            }
        )
    
    @classmethod
    def from_token_file(cls, path: str | Path) -> "Tokenizer":
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Tokenizer file not found: {path.as_posix()}")
        tokenizer = cls(vocab=[])
        tokenizer.idx2tok = enum_str(path.read_text(encoding="utf-8"))
        tokenizer.tok2idx = {tok: idx for idx, tok in tokenizer.idx2tok.items()}
        
        return tokenizer

    def to_token_file(self, path: str | Path) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True) 
        tokens = [self.idx2tok[i] for i in range(len(self))]
        path.write_text("\n".join(tokens), encoding="utf-8")
    
    
def enum_str(s: str, start: int = 0) -> dict[int, str]:
    """Map consecutive integers, counting from ``start``, to the lines of ``s``.

    Only the two ends of the whole string are stripped, so blank lines before
    the first or after the last token are dropped.
    """
    lines = s.strip().splitlines()
    return dict(enumerate(lines, start))


Writing src/tokenizer/base.py


### 2.2 Character tokenizer


In [8]:
%%writefile src/tokenizer/char_tokenizer.py
from pathlib import Path
from typing import List

from .base import Tokenizer

class CharTokenizer(Tokenizer):
    """Tokenizer whose vocabulary is every distinct character of the corpus.

    When built from a corpus, ids 0, 1 and 2 are ``<pad>``, ``<sos>`` and
    ``<eos>``. The remaining characters follow in sorted (code-point) order.
    """

    def __init__(self, vocab: List[str]) -> None:
        # ``vocab`` is a list of strings (e.g. words); flatten it to unique chars.
        unique_chars = sorted({ch for text in vocab for ch in text})
        super().__init__(vocab=["<pad>", "<sos>", "<eos>", *unique_chars])

    def _create_idx2tok(self, tokens: List[str]) -> dict[int, str]:
        # Position in the token list is the id.
        return dict(enumerate(tokens))

Writing src/tokenizer/char_tokenizer.py


## 3. Data Preparation and Manipulation


### 3.1 Data Utils


In [9]:
%%writefile src/data/utils.py
from pathlib import Path


def save_pairs(pairs: list[tuple], path: str | Path, sep:str="\t") -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join([f"{x}{sep}{y}" for x, y in pairs]), encoding="utf-8")
    
    
def load_pairs(path: str | Path, sep:str="\t") -> tuple[list, list]:
    xs, ys = [], []
    for line in Path(path).read_text(encoding="utf-8").strip().split("\n"):
        x,  y, *_ = line.strip().split(sep, maxsplit=1)
        xs.append(x)
        ys.append(y)
    return xs, ys

Writing src/data/utils.py


### 3.2 XlitDataset and load_dataloaders


In [10]:
%%writefile src/data/loader.py
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union

from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch


class XlitDataset(Dataset):
    """Map-style dataset over parallel source/target strings.

    Items are encoded lazily in ``__getitem__`` by calling each tokenizer's
    ``encode(text, max_len)``.

    Args:
        xs: Source strings.
        ys: Target strings, aligned index-by-index with ``xs``.
        x_tokenizer: Object exposing ``encode(text, max_len)`` for ``xs``.
        y_tokenizer: Same, for ``ys``.
        max_len: Length forwarded to both tokenizers' ``encode``.
    """

    def __init__(
        self,
        xs: List[str],
        ys: List[str],
        x_tokenizer,
        y_tokenizer,
        max_len: int,
    ) -> None:
        self.xs, self.ys = xs, ys
        self.x_tokenizer, self.y_tokenizer = x_tokenizer, y_tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        # The dataset size is taken from the source side.
        return len(self.xs)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str]]:
        src_text, tgt_text = self.xs[idx], self.ys[idx]
        src_ids = self.x_tokenizer.encode(src_text, self.max_len)
        tgt_ids = self.y_tokenizer.encode(tgt_text, self.max_len)
        sample = dict(
            input=torch.tensor(src_ids, dtype=torch.long),
            target=torch.tensor(tgt_ids, dtype=torch.long),
            input_text=src_text,
            target_text=tgt_text,
        )
        return sample

    def save_pairs_to_file(
        self, save_path: Union[str, Path], indices: Optional[List[int]] = None
    ) -> None:
        """Write pairs as newline-terminated ``x<TAB>y`` lines to ``save_path``.

        When ``indices`` is given, only those pairs are written, in that order.
        Parent directories are created as needed.
        """
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        if indices is None:
            rows = zip(self.xs, self.ys)
        else:
            rows = ((self.xs[i], self.ys[i]) for i in indices)
        with open(save_path, "w", encoding="utf-8") as handle:
            handle.writelines(f"{x}\t{y}\n" for x, y in rows)


def load_dataloaders(
    xs: List[str],
    ys: List[str],
    x_tokenizer,
    y_tokenizer,
    max_len: int = 100,
    batch_size: int = 32,
    val_ratio: float = 0.25,
    train_file: Optional[Union[str, Path]] = None,
    val_file: Optional[Union[str, Path]] = None,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """Split the pairs into train/validation loaders.

    The validation set gets ``int(val_ratio * N)`` pairs and training gets the
    rest. The split is reproducible when ``seed`` is given. If ``train_file`` /
    ``val_file`` are set, the pairs of each split are also written there.
    The training loader shuffles; the validation loader does not.
    """
    dataset = XlitDataset(xs, ys, x_tokenizer, y_tokenizer, max_len)
    n_total = len(dataset)
    n_val = int(val_ratio * n_total)
    n_train = n_total - n_val

    rng = torch.Generator()
    if seed is not None:
        rng.manual_seed(seed)
    train_subset, val_subset = random_split(dataset, [n_train, n_val], generator=rng)

    # Optionally persist each split so the exact partition can be inspected later.
    for out_file, subset in ((train_file, train_subset), (val_file, val_subset)):
        if out_file:
            dataset.save_pairs_to_file(out_file, indices=subset.indices)

    return (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(val_subset, batch_size=batch_size, shuffle=False),
    )


Writing src/data/loader.py


## 4. Models


In [11]:
%%writefile src/models/__init__.py
from .attn import AttnSeq2Seq
from .lstm import LSTMSeq2Seq
from .cnn import CNNSeq2SeqAttn
from .transformer import TransformerSeq2Seq

# Maps the `xlit` value used in conf/*.yaml to the model class to instantiate.
MODEL_REGISTRY = {
    "attention": AttnSeq2Seq,
    "lstm": LSTMSeq2Seq,
    "cnn_attn": CNNSeq2SeqAttn,
    "transformer": TransformerSeq2Seq,
}


def load_model(model_name, model_conf, device):
    """Instantiate the model registered as ``model_name`` with ``(model_conf, device)``.

    Raises:
        ValueError: if ``model_name`` is not a key of ``MODEL_REGISTRY``.
    """
    try:
        model_cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' not supported.") from None
    return model_cls(model_conf, device)


Writing src/models/__init__.py


### 4.0 Base Model


In [12]:

%%writefile src/models/base.py
import torch
import torch.nn as nn

class XlitModel(nn.Module):
    """Shared base for all transliteration models.

    Stores ``device`` plus four required ``model_conf`` entries as attributes:
    ``max_len`` (subclasses use it as the default decoding length) and the
    special-token ids ``sos_token``, ``eos_token`` and ``pad_token``.
    A missing entry raises ``KeyError``.
    """

    def __init__(self, model_conf: dict, device: torch.device) -> None:
        super().__init__()
        self.device = device
        for key in ("max_len", "sos_token", "eos_token", "pad_token"):
            setattr(self, key, model_conf[key])
            


Writing src/models/base.py


### 4.1 Attention Seq2Seq Model


In [13]:
%%writefile src/models/attn.py
from typing import Optional
import torch
from torch import nn


from src.models.base import XlitModel


class Encoder(nn.Module):
    """Embedding + multi-layer LSTM over the source ids (batch_first).

    ``forward`` returns the per-step outputs of the top layer together with
    the final hidden and cell states of every layer.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
    ):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        rnn_in = self.dropout(self.embedding(x))
        enc_states, (h_n, c_n) = self.rnn(rnn_in)
        return enc_states, h_n, c_n


class Attention(nn.Module):
    """Additive (concat) attention of the top decoder layer over encoder outputs.

    Returns softmax weights over source positions. Positions where ``mask``
    is 0/False get a -1e10 score, which makes their weight effectively zero.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(
        self,
        hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        _, src_len, _ = encoder_outputs.size()
        # Use the top layer's state as the query, tiled across every source step.
        query = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
        features = torch.tanh(self.attn(torch.cat([query, encoder_outputs], dim=2)))
        scores = self.v(features).squeeze(2)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)

        return scores.softmax(dim=1)


class Decoder(nn.Module):
    """Single-step LSTM decoder with additive attention over the encoder outputs.

    Each call takes one token per batch element. It attends with the current
    top-layer hidden state, feeds ``[token embedding; context]`` to the LSTM,
    and projects ``[LSTM output; context]`` to ``output_dim`` logits.
    """

    def __init__(
        self,
        output_dim: int,
        embed_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
    ):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.LSTM(
            input_size=embed_dim + hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(2 * hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_dim)

    def forward(
        self,
        inp: torch.Tensor,
        hidden: torch.Tensor,
        cell: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        step_emb = self.dropout(self.embedding(inp.unsqueeze(1)))       # (B, 1, E)
        weights = self.attention(hidden, encoder_outputs, mask)         # (B, T_src)
        context = weights.unsqueeze(1).bmm(encoder_outputs)             # (B, 1, H)
        lstm_in = torch.cat((step_emb, context), dim=2)                 # (B, 1, E+H)
        lstm_out, (hidden, cell) = self.rnn(lstm_in, (hidden, cell))
        features = torch.cat((lstm_out.squeeze(1), context.squeeze(1)), dim=1)
        return self.fc_out(features), hidden, cell                      # (B, output_dim)


class AttnSeq2Seq(XlitModel):
    """LSTM encoder/decoder with additive attention over the encoder outputs.

    Reads ``idim``, ``odim``, ``embed_dim``, ``hidden_dim``, ``elayers``,
    ``dlayers`` and ``dropout`` from ``model_conf``, plus the keys consumed by
    ``XlitModel``. ``teacher_forcing_ratio`` is optional and defaults to 0.5.
    """

    def __init__(self, model_conf: dict, device: torch.device):
        super().__init__(model_conf, device)
        self.encoder = Encoder(
            model_conf["idim"],
            model_conf["embed_dim"],
            model_conf["hidden_dim"],
            model_conf["elayers"],
            model_conf["dropout"],
        ).to(device)

        self.decoder = Decoder(
            model_conf["odim"],
            model_conf["embed_dim"],
            model_conf["hidden_dim"],
            model_conf["dlayers"],
            model_conf["dropout"],
        ).to(device)

        self.teacher_forcing_ratio = model_conf.get("teacher_forcing_ratio", 0.5)

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        max_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Decode ``x`` step by step and return the logits of every step.

        Args:
            x: Source ids, (batch_size, seq_len). Positions equal to
                ``self.pad_token`` are masked out of the attention.
            y: Optional target ids. When given, decoding runs ``y.size(1)``
                steps starting from ``y[:, 0]``. The next input is the gold token
                with probability ``teacher_forcing_ratio``, decided by one coin
                flip per step that is shared by the whole batch.
            max_len: Number of steps when ``y`` is None. A falsy value falls back
                to ``self.max_len``.

        Returns:
            Logits of shape (batch_size, target_len, odim). Position 0 is never
            written and stays zero. In free-running mode (``y`` is None),
            decoding stops once every sequence emitted ``eos_token``, and the
            remaining positions stay zero.
        """

        max_len = max_len or self.max_len
        batch_size = x.size(0)
        y_vocab_size = self.decoder.fc_out.out_features
        target_len = y.size(1) if y is not None else max_len

        assert target_len is not None, "max_len must be provided when y is None"

        mask = x != self.pad_token  # (batch_size, seq_len)
        encoder_outputs, hidden, cell = self.encoder(x)
        outputs = torch.zeros(batch_size, target_len, y_vocab_size, device=self.device)

        # First decoder input: the target's first token (expected to be <sos>)
        # or an explicit <sos> column when decoding freely.
        inp = (
            y[:, 0]
            if y is not None
            else torch.full(
                (batch_size,), self.sos_token, dtype=torch.long, device=self.device
            )
        )

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(
                inp, hidden, cell, encoder_outputs, mask
            )
            outputs[:, t] = output
            top1 = output.argmax(1)

            if y is not None:
                # A single (1,)-shaped draw broadcasts over the batch:
                # either all samples are teacher-forced this step or none are.
                teacher_force = (
                    torch.rand(1, device=self.device) < self.teacher_forcing_ratio
                )
                inp = torch.where(
                    teacher_force.unsqueeze(1), y[:, t].unsqueeze(1), top1.unsqueeze(1)
                ).squeeze(1)
            else:
                inp = top1
                # Early stop only when *every* sequence has produced <eos>.
                if self.eos_token is not None and (inp == self.eos_token).all():
                    break

        return outputs


Writing src/models/attn.py


### 4.2 LSTM Seq2Seq


In [14]:
%%writefile src/models/lstm.py
import torch
from torch import nn

from .base import XlitModel

class LSTMSeq2Seq(XlitModel):
    """Plain LSTM encoder/decoder without attention.

    Reads ``idim``, ``odim``, ``embed_dim``, ``hidden_dim``, ``nlayers`` and
    ``dropout`` from ``model_conf``. ``teacher_forcing_ratio`` is optional and
    defaults to 0.5. The encoder's final (hidden, cell) states initialise the
    decoder, so both LSTMs use ``nlayers`` layers.

    NOTE(review): the decoder embeds *target* ids with ``self.embedding``, the
    same ``idim``-row table used for source ids. Target ids must therefore be
    < idim, and source and target share embedding vectors. Confirm this is
    intended; a separate target embedding would change state_dict keys.
    """

    def __init__(self, model_conf: dict, device: torch.device):
        super().__init__(model_conf, device)
        self.embedding = nn.Embedding(model_conf["idim"], model_conf["embed_dim"])
        self.encoder = nn.LSTM(
            model_conf["embed_dim"],
            model_conf["hidden_dim"],
            num_layers=model_conf["nlayers"],
            dropout=model_conf["dropout"],
            batch_first=True,
        )
        self.decoder = nn.LSTM(
            model_conf["embed_dim"],
            model_conf["hidden_dim"],
            num_layers=model_conf["nlayers"],
            dropout=model_conf["dropout"],
            batch_first=True,
        )
        self.output_layer = nn.Linear(model_conf["hidden_dim"], model_conf["odim"])
        self.dropout = nn.Dropout(model_conf["dropout"])
        self.teacher_forcing_ratio = model_conf.get("teacher_forcing_ratio", 0.5)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None, max_len: int = None):
        """Return per-step logits of shape (batch, target_len, odim).

        With ``y``, decoding runs ``y.size(1)`` steps starting from ``y[:, 0]``.
        Without it, decoding runs ``max_len or self.max_len`` steps starting
        from ``sos_token``. Position 0 is never written and stays zero.

        Unlike the other models, this one neither masks source padding nor
        stops early at ``eos_token``.
        """
        batch_size = x.size(0)
        embedded = self.dropout(self.embedding(x))
        _, (hidden, cell) = self.encoder(embedded)

        if y is not None:
            target_len = y.size(1)
            outputs = torch.zeros(batch_size, target_len, self.output_layer.out_features, device=self.device)
            inp = y[:, 0]
        else:
            target_len = max_len or self.max_len
            assert target_len is not None, "max_len must be provided when y is None"
            outputs = torch.zeros(batch_size, target_len, self.output_layer.out_features, device=self.device)
            inp = torch.full((batch_size,), self.sos_token, dtype=torch.long, device=self.device)

        for t in range(1, target_len):
            # Target ids go through the shared (source) embedding; see class NOTE.
            embedded_input = self.dropout(self.embedding(inp)).unsqueeze(1)
            output, (hidden, cell) = self.decoder(embedded_input, (hidden, cell))
            output = self.output_layer(output.squeeze(1))
            outputs[:, t] = output

            top1 = output.argmax(1)

            # One CPU-side draw per step: the whole batch is teacher-forced or not.
            if y is not None and torch.rand(1).item() < self.teacher_forcing_ratio:
                inp = y[:, t]
            else:
                inp = top1

        return outputs


Writing src/models/lstm.py


### 4.3 CNN Seq2Seq + PE + GLU


In [15]:
%%writefile src/models/positional_encoding.py
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal position table to a batch of embeddings.

    Even columns hold ``sin(pos / 10000^(2i/d))`` and odd columns the matching
    cosine. Odd ``embed_dim`` is supported; the last cosine column is simply
    absent.

    Args:
        embed_dim: Size of the inputs' last dimension.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, embed_dim: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd embed_dim there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        # A non-persistent buffer moves with model.to(device), so forward no longer
        # copies it each call, but it stays out of state_dict, so existing
        # checkpoints still load.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, embed_dim)

    def forward(self, x: torch.Tensor):
        """Return ``x + pe[:, :seq_len]`` for ``x`` of shape (batch, seq_len, embed_dim)."""
        seq_len = x.size(1)
        # .to() is a no-op once the module lives on x's device; kept as a safeguard.
        return x + self.pe[:, :seq_len].to(x.device)

Writing src/models/positional_encoding.py


In [16]:
%%writefile src/models/cnn.py
# src/models/cnn.py
import torch
import torch.nn as nn
import math
from .base import XlitModel
from .positional_encoding import PositionalEncoding


class ConvEncoder(nn.Module):
    """Embedding + sinusoidal positions + one same-length 1-D convolution.

    Maps source ids (B, T) to features (B, T, hidden_dim). Length is preserved
    for odd ``kernel_size`` because of the ``kernel_size // 2`` padding.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, kernel_size, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.pe = PositionalEncoding(embed_dim)
        self.conv = nn.Conv1d(embed_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        tokens = self.dropout(self.pe(self.embedding(x)))        # (B, T, E)
        # Conv1d wants channels first: move E to dim 1, convolve, move it back.
        features = self.conv(tokens.transpose(1, 2))             # (B, H, T)
        return features.transpose(1, 2)                          # (B, T, H)


class ConvAttention(nn.Module):
    """Scaled dot-product attention of decoder states over encoder outputs.

    Scores are divided by ``sqrt(enc_dim)``. ``dec_dim`` is accepted for
    signature symmetry but unused. No source mask is applied.
    Returns ``(context, weights)`` with shapes (B, T_dec, D) and (B, T_dec, T_enc).
    """

    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.scale = math.sqrt(enc_dim)

    def forward(self, query, encoder_outputs):
        keys = encoder_outputs.transpose(1, 2)                        # (B, D, T_enc)
        weights = torch.softmax(torch.bmm(query, keys) / self.scale, dim=2)
        context = torch.bmm(weights, encoder_outputs)                 # (B, T_dec, D)
        return context, weights


class ConvDecoder(nn.Module):
    """Gated (GLU-style) convolutional decoder with attention.

    The target prefix (B, T) is embedded, position-encoded and convolved to
    2H channels. The first H channels are gated by the sigmoid of the last H,
    the result attends over the encoder outputs, and
    ``[gated; context]`` is projected to (B, T, output_dim) logits.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, kernel_size, dropout):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.pe = PositionalEncoding(embed_dim)
        self.conv = nn.Conv1d(embed_dim, hidden_dim * 2, kernel_size, padding=kernel_size // 2)
        self.attn = ConvAttention(hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, encoder_outputs):
        tokens = self.dropout(self.pe(self.embedding(tgt)))            # (B, T, E)
        feats = self.conv(tokens.transpose(1, 2)).transpose(1, 2)      # (B, T, 2H)
        # 2H is always even, so chunk(2) splits exactly at H.
        values, gates = feats.chunk(2, dim=2)
        gated = values * torch.sigmoid(gates)                          # (B, T, H)
        context, _ = self.attn(gated, encoder_outputs)                 # (B, T, H)
        return self.fc_out(torch.cat([gated, context], dim=2))         # (B, T, vocab)

class CNNSeq2SeqAttn(XlitModel):
    """Convolutional encoder/decoder with dot-product attention.

    Reads ``idim``, ``odim``, ``embed_dim``, ``hidden_dim``, ``kernel_size``
    and ``dropout`` from ``model_conf``. ``teacher_forcing_ratio`` is optional
    and defaults to 0.5. Unlike ``AttnSeq2Seq``, submodules are not moved to
    ``device`` here; presumably the caller does ``model.to(device)``. Confirm.
    """

    def __init__(self, model_conf: dict, device: torch.device):
        super().__init__(model_conf, device)
        self.encoder = ConvEncoder(
            input_dim=model_conf["idim"],
            embed_dim=model_conf["embed_dim"],
            hidden_dim=model_conf["hidden_dim"],
            kernel_size=model_conf["kernel_size"],
            dropout=model_conf["dropout"],
        )

        self.decoder = ConvDecoder(
            output_dim=model_conf["odim"],
            embed_dim=model_conf["embed_dim"],
            hidden_dim=model_conf["hidden_dim"],
            kernel_size=model_conf["kernel_size"],
            dropout=model_conf["dropout"],
        )

        self.teacher_forcing_ratio = model_conf.get("teacher_forcing_ratio", 0.5)

    def forward(self, x, y=None, max_len=None):
        """Return per-step logits of shape (batch, target_len, odim).

        At every step the decoder is re-run over the whole prefix generated so
        far; only the last position's logits are kept. Cost is therefore
        quadratic in the target length. Position 0 is never written.

        With ``y``, each sample is independently teacher-forced with probability
        ``teacher_forcing_ratio``. Source padding is not masked in the attention.

        NOTE(review): the all-``eos`` early stop below also runs when ``y`` is
        given (training). If every chosen next input is ``eos`` at the same
        step, later positions stay zero. Confirm the loss handles this.
        """
        batch_size = x.size(0)
        max_len = max_len or self.max_len
        target_len = y.size(1) if y is not None else max_len
        vocab_size = self.decoder.fc_out.out_features

        encoder_outputs = self.encoder(x)  # (B, T_src, H)
        outputs = torch.zeros(batch_size, target_len, vocab_size, device=self.device)

        # Initialize first decoder input
        inp = (
            y[:, 0] if y is not None
            else torch.full((batch_size,), self.sos_token, dtype=torch.long, device=self.device)
        )

        tgt_tokens = [inp]  # List of tokens to build decoder input

        for t in range(1, target_len):
            # The prefix never contains future tokens, so the decoder's symmetric
            # conv padding cannot leak them into the last position.
            decoder_input = torch.stack(tgt_tokens, dim=1)  # (B, t)
            logits = self.decoder(decoder_input, encoder_outputs)  # (B, t, vocab)
            output_t = logits[:, -1, :]  # (B, vocab)
            outputs[:, t] = output_t

            top1 = output_t.argmax(1)

            if y is not None:
                # Per-sample coin flips (one draw per batch element).
                teacher_force = torch.rand(batch_size, device=self.device) < self.teacher_forcing_ratio
                inp = torch.where(teacher_force, y[:, t], top1)
            else:
                inp = top1

            tgt_tokens.append(inp)

            if self.eos_token is not None and (inp == self.eos_token).all():
                break

        return outputs



Writing src/models/cnn.py


### 4.4 Transformer


In [17]:
%%writefile src/models/transformer.py
import torch
import torch.nn as nn

from .base import XlitModel
from .positional_encoding import PositionalEncoding
    

class TransformerSeq2Seq(XlitModel):
    """``nn.Transformer`` encoder/decoder with sinusoidal positions and greedy decoding.

    Reads ``idim``, ``odim``, ``embed_dim``, ``num_heads``,
    ``num_encoder_layers``, ``num_decoder_layers``, ``dim_feedforward`` and
    ``dropout`` from ``model_conf``, plus the keys consumed by ``XlitModel``.
    ``teacher_forcing_ratio`` is optional and defaults to 0.5.
    """

    def __init__(self, model_conf: dict, device: torch.device):
        super().__init__(model_conf, device)

        self.embed_dim = model_conf["embed_dim"]
        self.encoder_embed = nn.Embedding(model_conf["idim"], self.embed_dim)
        self.decoder_embed = nn.Embedding(model_conf["odim"], self.embed_dim)

        self.pos_encoder = PositionalEncoding(self.embed_dim)
        self.pos_decoder = PositionalEncoding(self.embed_dim)

        self.transformer = nn.Transformer(
            d_model=self.embed_dim,
            nhead=model_conf["num_heads"],
            num_encoder_layers=model_conf["num_encoder_layers"],
            num_decoder_layers=model_conf["num_decoder_layers"],
            dim_feedforward=model_conf["dim_feedforward"],
            dropout=model_conf["dropout"],
            batch_first=True,
        )

        self.generator = nn.Linear(self.embed_dim, model_conf["odim"])
        self.dropout = nn.Dropout(model_conf["dropout"])
        self.teacher_forcing_ratio = model_conf.get("teacher_forcing_ratio", 0.5)

    def forward(self, src, tgt=None, max_len=None):
        """Decode ``src`` token by token and return per-step logits.

        Source positions equal to ``self.pad_token`` are masked out of both the
        encoder self-attention and the decoder cross-attention. Previously no
        mask was applied, so the padding to ``max_len`` dominated attention.

        Args:
            src: Source ids, shape (batch, src_len).
            tgt: Optional target ids, shape (batch, tgt_len). When given,
                decoding runs ``tgt_len`` steps starting from ``tgt[:, 0]``. Each
                sample independently takes the gold next token with probability
                ``teacher_forcing_ratio``.
            max_len: Number of steps when ``tgt`` is None. A falsy value falls
                back to ``self.max_len``.

        Returns:
            Logits of shape (batch, tgt_len, odim). Position 0 is never written.
            In free-running mode, positions after an all-``eos`` early stop
            stay zero.
        """
        batch_size, _ = src.shape
        max_len = max_len or self.max_len
        tgt_len = tgt.size(1) if tgt is not None else max_len

        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len).to(self.device)
        # True marks source padding that attention must ignore.
        src_pad_mask = (src == self.pad_token) if self.pad_token is not None else None

        src_emb = self.dropout(self.pos_encoder(self.encoder_embed(src)))
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_pad_mask)

        outputs = torch.zeros(batch_size, tgt_len, self.generator.out_features, device=self.device)

        ys = (
            tgt[:, 0] if tgt is not None else
            torch.full((batch_size,), self.sos_token, dtype=torch.long, device=self.device)
        )
        ys = ys.unsqueeze(1)  # (batch, 1): the growing decoder prefix

        for t in range(1, tgt_len):
            # Re-run the decoder over the whole prefix (length t) each step.
            tgt_emb = self.dropout(self.pos_decoder(self.decoder_embed(ys)))
            out = self.transformer.decoder(
                tgt_emb,
                memory,
                tgt_mask=tgt_mask[:t, :t],
                memory_key_padding_mask=src_pad_mask,
            )
            output = self.generator(out[:, -1])
            outputs[:, t] = output
            top1 = output.argmax(1).unsqueeze(1)

            if tgt is not None:
                # Per-sample teacher forcing.
                use_teacher = torch.rand(batch_size, device=self.device) < self.teacher_forcing_ratio
                next_input = torch.where(use_teacher, tgt[:, t], top1.squeeze(1))
            else:
                next_input = top1.squeeze(1)
                if self.eos_token is not None and (next_input == self.eos_token).all():
                    break

            ys = torch.cat([ys, next_input.unsqueeze(1)], dim=1)

        return outputs


Writing src/models/transformer.py


## 5. Utilities


In [18]:

%%writefile src/utils/__init__.py
from .seed import set_seed
from .logger import setup_logger
from .plot import plot_metrics
from .save import save_models, save_best_predictions

Writing src/utils/__init__.py


### 5.1 Set Seed


In [19]:
%%writefile src/utils/seed.py
import random
import numpy as np
import torch

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning so
    runs are reproducible.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

Writing src/utils/seed.py


### 5.2 Logger


In [20]:
%%writefile src/utils/logger.py
import logging
import sys
import shutil

def setup_logger(log_file=None, backup_file=None):
    """Configure and return the shared ``"XlitTask"`` logger.

    INFO and above go to stdout and, when ``log_file`` is given, to that
    file. Calling this again reconfigures the same logger. Previous handlers
    are closed and replaced, so repeated calls neither duplicate output nor
    leak file handles.

    Args:
        log_file: Path (or str) of the log file, or ``None`` to log to stdout
            only. Parent directories are created as needed.
        backup_file: Where to copy an existing ``log_file`` before starting a
            fresh one. When ``None``, an existing log is appended to rather than
            discarded.

    Returns:
        The configured ``logging.Logger``.
    """
    from pathlib import Path

    logger = logging.getLogger("XlitTask")
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Close old handlers *before* rotating the file. Otherwise the previous
    # FileHandler keeps writing to the unlinked file (and blocks unlink on Windows).
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(asctime)s (%(module)s:%(lineno)d)  %(levelname)s: %(message)s")

    handlers = []
    if log_file is not None:
        log_file = Path(log_file)
        log_file.parent.mkdir(parents=True, exist_ok=True)
        if backup_file is not None and log_file.exists():
            Path(backup_file).parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(log_file, backup_file)
            log_file.unlink()
        handlers.append(logging.FileHandler(log_file))
    handlers.append(logging.StreamHandler(sys.stdout))

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


Writing src/utils/logger.py


### 5.3 Plot Utils


In [21]:
%%writefile src/utils/plot.py
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns


def _save_epoch_plot(path, curves, ylabel, title):
    """Draw per-epoch line chart(s) and save them to ``path`` at 300 dpi.

    ``curves`` is a sequence of ``(values, label, marker, color)`` tuples, each
    drawn as one line in order.
    """
    plt.figure(figsize=(10, 6))
    for values, label, marker, color in curves:
        plt.plot(values, label=label, marker=marker, color=color)
    plt.xlabel("Epoch", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def plot_metrics(image_dir: Path, train_losses, val_losses, val_cers, val_accs):
    """Save loss, CER and word-accuracy curves as PNGs under ``image_dir``.

    Writes ``losses.png`` (train and validation loss), ``cer.png`` and
    ``wa.png``, creating ``image_dir`` if needed.
    """
    image_dir.mkdir(parents=True, exist_ok=True)

    sns.set(style="whitegrid", font_scale=1.4)
    colors = sns.color_palette("Set2")

    _save_epoch_plot(
        image_dir / "losses.png",
        [
            (train_losses, "Train Loss", 'o', colors[0]),
            (val_losses, "Val Loss", 's', colors[1]),
        ],
        "Loss",
        "Training and Validation Loss",
    )
    _save_epoch_plot(
        image_dir / "cer.png",
        [(val_cers, "Val CER", '^', colors[2])],
        "Character Error Rate",
        "Validation CER Over Epochs",
    )
    _save_epoch_plot(
        image_dir / "wa.png",
        [(val_accs, "Val Word Accuracy", 'D', colors[3])],
        "Word Accuracy",
        "Validation Word Accuracy Over Epochs",
    )


Writing src/utils/plot.py


### 5.4 Save models


In [22]:
%%writefile src/utils/save.py
from pathlib import Path
from typing import Tuple, Dict, List
import copy
import heapq
import logging

import torch


def save_models(
    exp_dir: Path,
    logger: logging.Logger,
    current_model: torch.nn.Module,
    model_conf: dict,
    device: torch.device,
    epoch: int,
    max_epoch: int,
    avg_train_loss: float,
    avg_val_loss: float,
    wa_score: float,
    best_train_loss: float,
    best_val_loss: float,
    best_wa: float,
    best_models: list,
    saved_epochs: set,
    n_best: int,
) -> Tuple[float, float, float, list, set]:
    """Write this epoch's checkpoints and update the running "best" trackers.

    Files written under ``exp_dir``:
      * ``train.loss.best.pth`` / ``val.loss.best.pth`` / ``wa.best.pth`` when
        this epoch improves the corresponding metric,
      * ``latest.pth`` every epoch,
      * ``{epoch}epoch.pth`` for each epoch currently kept in the top-``n_best``
        by validation loss (an evicted epoch's file is deleted),
      * ``val.loss.ave.pth`` when ``epoch == max_epoch``: the weight average
        of the kept top-``n_best`` snapshots (skipped if none are kept).

    ``best_models`` is a ``heapq`` min-heap of ``(-val_loss, epoch, state_dict)``.
    Because the loss is negated, ``best_models[0]`` is the kept entry with the
    *highest* validation loss, i.e. the one to evict. ``n_best <= 0`` disables
    top-N tracking and averaging.

    ``best_models`` and ``saved_epochs`` are mutated in place and also returned.

    Returns:
        ``(best_train_loss, best_val_loss, best_wa, best_models, saved_epochs)``.
    """
    # A single state_dict() serves all saves below; torch.save serialises the
    # live tensors at call time, so this matches calling state_dict() per save.
    state = current_model.state_dict()

    if avg_train_loss < best_train_loss:
        best_train_loss = avg_train_loss
        torch.save(state, exp_dir / "train.loss.best.pth")
        logger.info(f"Saved best training loss model at epoch {epoch}.")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(state, exp_dir / "val.loss.best.pth")
        logger.info(f"Saved best validation loss model at epoch {epoch}.")

    if wa_score > best_wa:
        best_wa = wa_score
        torch.save(state, exp_dir / "wa.best.pth")
        logger.info(f"Saved best word accuracy model at epoch {epoch}.")

    torch.save(state, exp_dir / "latest.pth")
    logger.info(f"Saved latest model at epoch {epoch}.")

    # Top-N by validation loss. Guarding on n_best > 0 keeps an empty heap
    # from ever being indexed below.
    if n_best > 0 and epoch not in saved_epochs:
        neg_loss = -avg_val_loss
        if len(best_models) < n_best:
            heapq.heappush(best_models, (neg_loss, epoch, copy.deepcopy(state)))
            saved_epochs.add(epoch)
            torch.save(state, exp_dir / f"{epoch}epoch.pth")
            logger.info(f"Kept epoch {epoch} in the top-{n_best} models.")
        elif neg_loss > best_models[0][0]:
            # The new entry is strictly better than the root, so pushpop
            # returns the previous worst entry.
            _, evicted_epoch, _ = heapq.heappushpop(
                best_models, (neg_loss, epoch, copy.deepcopy(state))
            )
            saved_epochs.discard(evicted_epoch)
            saved_epochs.add(epoch)

            evicted_path = exp_dir / f"{evicted_epoch}epoch.pth"
            if evicted_path.exists():
                evicted_path.unlink()
                logger.info(f"Removed evicted model from epoch {evicted_epoch}.")

            torch.save(state, exp_dir / f"{epoch}epoch.pth")
            logger.info(f"Kept epoch {epoch} in the top-{n_best} models.")

    if epoch == max_epoch:
        if best_models:
            avg_model = average_model_weights(
                best_models, current_model, model_conf, device
            )
            torch.save(avg_model.state_dict(), exp_dir / "val.loss.ave.pth")
            logger.info("Saved averaged model at the end of training.")
        else:
            logger.info("No top-N models were kept; skipped weight averaging.")

    return best_train_loss, best_val_loss, best_wa, best_models, saved_epochs


def average_model_weights(
    model_heap: list[tuple[float, int, dict]],
    current_model: torch.nn.Module,
    model_conf: Dict,
    device: torch.device,
) -> torch.nn.Module:
    """Build a model whose weights are the element-wise mean of several snapshots.

    Args:
        model_heap: Entries ``(score, epoch, state_dict)``; only the
            ``state_dict`` (third element) of each entry is used.
        current_model: Template module. It is deep-copied and left unmodified.
        model_conf: Unused; kept for interface compatibility.
        device: Device the averaged model is moved to.

    Returns:
        A new module carrying the averaged weights, placed on ``device``.

    Raises:
        ValueError: If ``model_heap`` is empty.

    Only floating-point and complex tensors are averaged. Other entries, e.g.
    integer counters such as BatchNorm's ``num_batches_tracked``, would fail
    an in-place true division with a dtype error. They keep the value from the
    first snapshot instead. The snapshots in ``model_heap`` are not modified.
    """
    if not model_heap:
        raise ValueError("No models to average.")

    state_dicts = [entry[2] for entry in model_heap]
    n_models = len(state_dicts)

    # Deep copy so the in-place accumulation below cannot touch the snapshots.
    avg_state_dict = copy.deepcopy(state_dicts[0])
    for key, acc in avg_state_dict.items():
        if not isinstance(acc, torch.Tensor) or not (
            acc.is_floating_point() or acc.is_complex()
        ):
            continue
        for other in state_dicts[1:]:
            acc += other[key].to(acc.device)
        acc /= n_models

    avg_model = copy.deepcopy(current_model)
    avg_model.load_state_dict(avg_state_dict)
    avg_model.to(device)
    return avg_model


def save_best_predictions(
    exp_dir: Path, xs: List[str], true_texts: List[str], pred_texts: List[str]
) -> None:
    """Dump validation predictions to ``<exp_dir>/decode/wa.best.decode``.

    Each output line has four tab-separated fields:

    1. the source text;
    2. the reference;
    3. the hypothesis;
    4. ``✔`` if hypothesis and reference are identical, otherwise ``✘``.

    Lines are newline-joined with no trailing newline. The file is written as
    UTF-8, and the ``decode/`` directory is created if needed. Output stops at
    the shortest of the three input lists.
    """
    out_file = exp_dir / "decode/wa.best.decode"
    rows = []
    for src, ref, hyp in zip(xs, true_texts, pred_texts):
        mark = "✔" if hyp == ref else "✘"
        rows.append(f"{src}\t{ref}\t{hyp}\t{mark}")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    out_file.write_text("\n".join(rows), encoding="utf-8")


Writing src/utils/save.py


## 6. Xlit Trainer


In [23]:
%%writefile src/xlit_task.py
import time
from pathlib import Path
from typing import Optional, List, Tuple

import torch
import yaml
from jiwer import cer, wer
from torch.utils.tensorboard import SummaryWriter

from .tokenizer import prepare_tokenizers
from .data.loader import load_dataloaders
from .models import load_model
from .utils import (
    set_seed,
    setup_logger,
    plot_metrics,
    save_models,
    save_best_predictions,
)


class XlitTask:
    def __init__(
        self, conf_file: str | Path = "train.yaml", ckpt_file: str | Path = None
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_config(Path("conf") / conf_file)
        set_seed(self.conf["seed"])
        self.exp_dir = (
            Path("exp")
            / f"xlit_train_{self.model_name}_{self.token_type}_{self.lang_pair}"
        )
        self.data_dir = self.exp_dir / "data"
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.model = self._build_model(ckpt_file).to(self.device)

    @classmethod
    def from_pretrained(cls, ckpt_file: str | Path, conf_file: str | Path = None):
        ckpt_file = Path(ckpt_file)
        if not ckpt_file.exists():
            raise FileNotFoundError(
                f"Checkpoint file not found: {ckpt_file.as_posix()}"
            )
        return cls(conf_file, ckpt_file)

    def __call__(self, x: str, max_len: Optional[int] = None) -> str:
        return self.infer(x, max_len)

    def _load_config(self, conf_path) -> None:
        with open(conf_path, "r") as f:
            content = yaml.safe_load(f)
        self.model_name, self.conf = content["xlit"], content["xlit_conf"]
        self.token_type = self.conf["token_type"]
        self.langx, self.langy = self.conf["langx"], self.conf["langy"]
        self.lang_pair = f"{self.langx}_{self.langy}"

    def _build_model(self, ckpt_file: Optional[str | Path]) -> torch.nn.Module:
        self.x_tokenizer, self.y_tokenizer, self.xs, self.ys = prepare_tokenizers(
            x_tokens_file=self.data_dir / f"{self.langx}_{self.token_type}_tokens.txt",
            y_tokens_file=self.data_dir / f"{self.langy}_{self.token_type}_tokens.txt",
            db_file=self.conf["db_file"],
        )
        self.conf["idim"], self.conf["odim"] = len(self.x_tokenizer), len(
            self.y_tokenizer
        )
        self.conf["pad_token"] = self.y_tokenizer.tok2idx.get("<pad>", 0)
        self.conf["sos_token"] = self.y_tokenizer.tok2idx.get("<sos>", 1)
        self.conf["eos_token"] = self.y_tokenizer.tok2idx.get("<eos>", 2)
        self.conf["max_len"] = self.conf.get("max_len", 100)
        model = load_model(self.model_name, self.conf, device=self.device)
        if ckpt_file:
            model.load_state_dict(torch.load(ckpt_file, map_location=self.device))
        return model

    def infer(self, x: str, max_len: Optional[int] = None) -> str:
        max_len = max_len or self.conf.get("max_len", 100)
        tokenized_x = self.x_tokenizer.encode(x, max_len=max_len)
        input_tensor = torch.tensor(tokenized_x).unsqueeze(0).to(self.device)
        with torch.no_grad():
            y_pred = self.model(
                input_tensor,
                max_len=max_len,
            )
        predicted_ids = y_pred.argmax(dim=2)
        return self.y_tokenizer.decode(predicted_ids[0].tolist())

    def _prepare_data(self) -> None:
        (
            self.train_loader,
            self.val_loader,
        ) = load_dataloaders(
            self.xs,
            self.ys,
            self.x_tokenizer,
            self.y_tokenizer,
            max_len=self.conf["max_len"],
            batch_size=self.conf["batch_size"],
            val_ratio=self.conf["val_ratio"],
            train_file=self.data_dir / f"train_{self.lang_pair}.txt",
            val_file=self.data_dir / f"val_{self.lang_pair}.txt",
            seed=self.conf["seed"],
        )

    def _compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return self.criterion(
            y_pred[:, 1:].reshape(-1, y_pred.shape[2]),
            y_true[:, 1:].reshape(-1),
        )

    def _train_step(self, batch) -> float:
        x, y = batch["input"].to(self.device), batch["target"].to(self.device)
        self.optimizer.zero_grad()
        y_pred = self.model(x, y)
        loss = self._compute_loss(y_pred, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _val_step(self, batch) -> Tuple[float, List[str], List[str]]:
        x, y = batch["input"].to(self.device), batch["target"].to(self.device)
        with torch.no_grad():
            y_pred = self.model(x, y)
            loss = self._compute_loss(y_pred, y)
            pred_texts = [
                self.y_tokenizer.decode(seq.tolist())
                for seq in y_pred.argmax(dim=-1).cpu()
            ]
            true_texts = [self.y_tokenizer.decode(seq.tolist()) for seq in y]
        return loss.item(), pred_texts, true_texts

    def _init_logs(self):
        self.logger.info("Data Information:")
        self.logger.info(f"Batch size: {self.conf['batch_size']}")
        self.logger.info(
            f"[Training] Data size: {len(self.train_loader.dataset)} to {len(self.train_loader)} batches"
        )
        self.logger.info(
            f"[Validation] Data size: {len(self.val_loader.dataset)} to {len(self.val_loader)} batches"
        )
        self.logger.info("Token information")
        self.logger.info(
            f"[{self.langx}] Tokenizer loaded with {len(self.x_tokenizer)} tokens."
        )
        self.logger.info(
            f"[{self.langy}] Tokenizer loaded with {len(self.y_tokenizer)} tokens."
        )
        total_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        self.logger.info(f"Model information:\n{self.model}")
        self.logger.info(f"Total trainable parameters: {total_params}")
        self.logger.info(f"Experiment directory: {self.exp_dir.as_posix()}")
        self.logger.info(f"Optimizer: {self.optimizer.__class__.__name__}")
        self.logger.info(f"Loss criterion: {self.criterion.__class__.__name__}")

    def train(self):
        start_time = time.time()
        image_dir = self.exp_dir / "images"

        self.tb_writer = SummaryWriter(log_dir=self.exp_dir / "tensorboard")
        self.logger = setup_logger(
            log_file=self.exp_dir / "train.log",
            backup_file=self.exp_dir / "train.old.log",
        )
        self._prepare_data()

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.conf["optim_conf"]["lr"]
        )
        self.criterion = torch.nn.CrossEntropyLoss()
        self._init_logs()

        max_epoch = self.conf.get("max_epoch", 100)
        n_best = self.conf.get("keep_nbest_models", 5)

        best_val_loss = float("inf")
        best_wa = -float("inf")
        best_train_loss = float("inf")

        best_models = []
        saved_epochs = set()

        train_losses, val_losses, val_cers, val_accs = [], [], [], []

        self.logger.info(f"Started training {self.model_name} model on [{self.device}]")
        start_epoch = 1
        try:
            for epoch in range(start_epoch, max_epoch + 1):
                epoch_start_time = time.time()
                self.logger.info(f"Epoch {epoch}/{max_epoch}")

                # Training
                self.model.train()
                avg_train_loss = sum(
                    self._train_step(batch) for batch in self.train_loader
                ) / len(self.train_loader)
                train_losses.append(avg_train_loss)

                self.tb_writer.add_scalar("Loss/Train", avg_train_loss, epoch)
                self.logger.info(f"Train Loss: {avg_train_loss:.4f}")

                # Validation
                self.model.eval()
                val_loss = 0.0
                pred_texts, true_texts = [], []

                val_xs = []
                for batch in self.val_loader:
                    val_xs.extend(
                        [self.x_tokenizer.decode(x.tolist()) for x in batch["input"]]
                    )
                    loss, preds, trues = self._val_step(batch)
                    val_loss += loss
                    pred_texts.extend(preds)
                    true_texts.extend(trues)

                avg_val_loss = val_loss / len(self.val_loader)
                cer_score = cer(true_texts, pred_texts)
                wa_score = 1 - wer(true_texts, pred_texts)
                if best_wa < wa_score:
                    save_best_predictions(
                        self.exp_dir,
                        val_xs,
                        true_texts,
                        pred_texts,
                    )
                    self.logger.info(
                        f"Saved best word accuracy predictions at epoch {epoch}."
                    )

                val_losses.append(avg_val_loss)
                val_cers.append(cer_score)
                val_accs.append(wa_score)

                self.tb_writer.add_scalar("Loss/Val", avg_val_loss, epoch)
                self.tb_writer.add_scalar("CER/Val", cer_score, epoch)
                self.tb_writer.add_scalar("Word Accuracy/Val", wa_score, epoch)
                self.logger.info(
                    f"Val Loss: {avg_val_loss:.4f}, CER: {cer_score:.4f}, Word Accuracy: {wa_score:.4f}"
                )

                epoch_duration = time.time() - epoch_start_time
                elapsed = time.time() - start_time
                remaining_epochs = max_epoch - epoch
                eta = epoch_duration * remaining_epochs
                eta_sec = int(eta)
                eta_h, rem = divmod(eta_sec, 3600)
                eta_m, eta_s = divmod(rem, 60)
                self.logger.info(
                    f"Epoch {epoch} duration: {epoch_duration:.2f} sec | ETA: {eta_h:02d}:{eta_m:02d}:{eta_s:02d}"
                )

                best_train_loss, best_val_loss, best_wa, best_models, saved_epochs = (
                    save_models(
                        self.exp_dir,
                        self.logger,
                        self.model,
                        self.conf,
                        self.device,
                        epoch,
                        max_epoch,
                        avg_train_loss,
                        avg_val_loss,
                        wa_score,
                        best_train_loss,
                        best_val_loss,
                        best_wa,
                        best_models,
                        saved_epochs,
                        n_best,
                    )
                )
                plot_metrics(image_dir, train_losses, val_losses, val_cers, val_accs)
                self.logger.info(
                    f"Saved [loss, wa, cer] curves at {image_dir.as_posix()}"
                )
        finally:
            self.tb_writer.close()

        total_time = time.time() - start_time
        h, rem = divmod(int(total_time), 60)
        m, s = divmod(rem, 60)
        self.logger.info(f"Training completed in {h:02d}:{m:02d}:{s:02d}")


Writing src/xlit_task.py
