
# English → Japanese Transformer from Scratch

This notebook trains a small sequence-to-sequence **Transformer** to translate **English → Japanese** from a large database of japanese sentences and their respective english translations.

**What this notebook does:**
1. Loads & cleans a .TSV file.
2. Trains a shared **SentencePiece** tokenizer suited for Japanese.
3. Builds a PyTorch **Transformer** using `nn.Transformer` for seq2seq.
4. Trains with teacher forcing, validates each epoch, and reports **BLEU**.
5. Saves artifacts (tokenizer + model) and provides a `translate()` helper.

In [None]:

# If running locally, uncomment to install dependencies.
# !pip install --upgrade torch sentencepiece sacrebleu matplotlib tqdm

In [7]:

import os, io, csv, math, random, unicodedata, re, time
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import sentencepiece as spm
import sacrebleu
from tqdm import tqdm

# ============== SETTINGS ==============
# Path to .TSV file
DATA_TSV = "Sentence pairs in Japanese-English - 2025-08-27.tsv"

# Training / model hyperparameters
VOCAB_SIZE      = 16000
MAX_LEN         = 128
BATCH_SIZE      = 64
EMBED_DIM       = 512
FF_DIM          = 2048
NHEAD           = 8
ENC_LAYERS      = 4
DEC_LAYERS      = 4
DROPOUT         = 0.1
EPOCHS          = 20
LR              = 3e-4
LABEL_SMOOTH    = 0.1

VAL_FRACTION    = 0.05
RANDOM_SEED     = 42

# Output directory
ARTIFACTS_DIR   = Path("artifacts")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Determinism
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Print whether CPU or CUDA is being used
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda


In [8]:

def looks_japanese(s: str) -> bool:
    return bool(re.search(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]", s))

def looks_english(s: str) -> bool:
    return (bool(re.search(r"[A-Za-z]", s)) and not looks_japanese(s))

def clean_text(s: str, lang: str) -> str:
    s = unicodedata.normalize("NFKC", s or "").strip()
    s = re.sub(r"\s+", " ", s)
    return s

def pick_text(cols, predicate):
    best = ""
    for c in cols:
        c = c.strip()
        if not c or c.isdigit():
            continue
        if predicate(c):
            if len(c) > len(best):
                best = c
    return best

def read_tsv_guess_columns(path: str) -> List[Tuple[str, str]]:
    pairs = []
    with io.open(path, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            if not row: 
                continue
            # Clean each column
            cols = [clean_text(x, "en") for x in row]

            # Try to pick Japanese and English from any column
            ja = pick_text(cols, looks_japanese)
            en = pick_text(cols, looks_english)

            # Fallback if detection is inconclusive
            if not ja or not en:
                if len(cols) >= 2:
                    c1, c2 = cols[0], cols[1]
                    if len(cols) >= 3:
                        c3 = cols[2]
                    else:
                        c3 = ""
                    if looks_japanese(c1) and (looks_english(c2) or looks_english(c3)):
                        ja = ja or c1
                        en = en or (c2 if looks_english(c2) else c3)
                    elif looks_english(c1) and (looks_japanese(c2) or looks_japanese(c3)):
                        en = en or c1
                        ja = ja or (c2 if looks_japanese(c2) else c3)

            if ja and en:
                pairs.append((en, ja))
    return pairs

pairs = read_tsv_guess_columns(DATA_TSV)
print(f"Loaded {len(pairs):,} sentence pairs (after picking EN/JA columns).")
if not pairs:
    raise SystemExit("No usable EN–JA pairs found; check the file format/path.")

Loaded 277,280 sentence pairs (after picking EN/JA columns).


In [9]:

# Basic filters: drop duplicates, overly long pairs, extreme length ratios
def tokenize_whitespace(s: str) -> List[str]:
    return s.split()

def filter_pairs(pairs, max_len=256, ratio=3.0):
    seen = set()
    out = []
    for en, ja in pairs:
        key = (en, ja)
        if key in seen: 
            continue
        seen.add(key)
        # crude length measures before subwording
        en_len = len(tokenize_whitespace(en))
        ja_len = len(ja)
        if en_len == 0 or ja_len == 0:
            continue
        if en_len > max_len or ja_len > max_len*2:  # allow more chars for JA
            continue
        if max(en_len, ja_len) / max(1, min(en_len, ja_len)) > ratio:
            continue
        out.append((en, ja))
    return out

pairs = filter_pairs(pairs, max_len=MAX_LEN*2, ratio=4.0)
print(f"After filtering: {len(pairs):,} pairs")

# Train/val split
random.shuffle(pairs)
n_val = max(1, int(len(pairs) * VAL_FRACTION))
val_pairs = pairs[:n_val]
train_pairs = pairs[n_val:]
print(f"Train: {len(train_pairs):,}  |  Val: {len(val_pairs):,}")

After filtering: 270,573 pairs
Train: 257,045  |  Val: 13,528


In [10]:

# Write combined text file for SentencePiece
spm_input_path = ARTIFACTS_DIR / "spm_train.txt"
with io.open(spm_input_path, "w", encoding="utf-8") as f:
    for en, ja in train_pairs:
        f.write(en + "\n")
        f.write(ja + "\n")

spm_prefix = str(ARTIFACTS_DIR / "spm_enja")
spm_cmd = f"--input={spm_input_path} --model_prefix={spm_prefix} --vocab_size={VOCAB_SIZE} " \
          f"--character_coverage=0.9995 --model_type=unigram " \
          f"--pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3"
print("Training SentencePiece...")
spm.SentencePieceTrainer.Train(spm_cmd)

sp = spm.SentencePieceProcessor()
sp.load(f"{spm_prefix}.model")

PAD_ID = sp.pad_id()
UNK_ID = sp.unk_id()
BOS_ID = sp.bos_id()
EOS_ID = sp.eos_id()

PAD_ID, UNK_ID, BOS_ID, EOS_ID

Training SentencePiece...


(0, 1, 2, 3)

In [11]:

def encode_sentence(text: str, add_bos: bool, add_eos: bool) -> List[int]:
    ids = sp.encode(text, out_type=int)
    if add_bos:
        ids = [BOS_ID] + ids
    if add_eos:
        ids = ids + [EOS_ID]
    return ids

def clip(ids: List[int], max_len: int) -> List[int]:
    if len(ids) > max_len:
        return ids[:max_len]
    return ids

class EnJaDataset(Dataset):
    def __init__(self, pairs, max_len=MAX_LEN):
        self.data = []
        for en, ja in pairs:
            src = clip(encode_sentence(en, add_bos=False, add_eos=True), max_len)      # encoder gets EN + <eos>
            tgt = clip(encode_sentence(ja, add_bos=True, add_eos=True), max_len)       # decoder: <bos> JA ... <eos>
            # tgt_inp  = <bos> ... last token before eos
            # tgt_out  = next tokens after bos (i.e., shifted left), includes eos
            tgt_inp = tgt[:-1]
            tgt_out = tgt[1:]
            self.data.append((src, tgt_inp, tgt_out))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def pad_batch(seqs, pad_id=PAD_ID):
    maxlen = max(len(s) for s in seqs)
    return [s + [pad_id]*(maxlen - len(s)) for s in seqs]

def collate(batch):
    srcs, tgts_inp, tgts_out = zip(*batch)
    srcs = pad_batch(srcs, PAD_ID)
    tgts_inp = pad_batch(tgts_inp, PAD_ID)
    tgts_out = pad_batch(tgts_out, PAD_ID)
    srcs = torch.tensor(srcs, dtype=torch.long)
    tgts_inp = torch.tensor(tgts_inp, dtype=torch.long)
    tgts_out = torch.tensor(tgts_out, dtype=torch.long)
    # masks: True where PAD
    src_pad_mask = (srcs == PAD_ID)  # [B, S]
    tgt_pad_mask = (tgts_inp == PAD_ID)  # [B, T]
    return srcs, tgts_inp, tgts_out, src_pad_mask, tgt_pad_mask

train_ds = EnJaDataset(train_pairs, MAX_LEN)
val_ds   = EnJaDataset(val_pairs, MAX_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

len(train_ds), len(val_ds)

(257045, 13528)

In [12]:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, L, D]
        L = x.size(1)
        return x + self.pe[:, :L, :]

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # (T, T)
    return torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)

class Seq2SeqTransformer(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, nhead: int, num_encoder_layers: int,
                 num_decoder_layers: int, ff_dim: int, dropout: float):
        super().__init__()
        self.src_tok_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_ID)
        self.tgt_tok_emb = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_ID)
        self.pos_enc = PositionalEncoding(embed_dim)

        self.transformer = nn.Transformer(
            d_model=embed_dim, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=ff_dim, dropout=dropout,
            batch_first=True  # use [B, L, D]
        )
        self.generator = nn.Linear(embed_dim, vocab_size)

        # initialize
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt_inp, src_key_padding_mask, tgt_key_padding_mask):
        # src: [B,S], tgt_inp: [B,T]
        src_emb = self.pos_enc(self.src_tok_emb(src))  # [B,S,D]
        tgt_emb = self.pos_enc(self.tgt_tok_emb(tgt_inp))  # [B,T,D]

        tgt_mask = generate_square_subsequent_mask(tgt_inp.size(1)).to(src.device)  # [T,T]

        out = self.transformer(
            src=src_emb, tgt=tgt_emb,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
            tgt_mask=tgt_mask
        )  # [B,T,D]
        logits = self.generator(out)  # [B,T,V]
        return logits

In [13]:

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = Seq2SeqTransformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    nhead=NHEAD,
    num_encoder_layers=ENC_LAYERS,
    num_decoder_layers=DEC_LAYERS,
    ff_dim=FF_DIM,
    dropout=DROPOUT
).to(DEVICE)

print(f"Model parameters: {count_parameters(model):,}")

optimizer = torch.optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=LABEL_SMOOTH)

def train_one_epoch(epoch):
    model.train()
    total_loss = 0.0
    steps = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch} [train]", leave=False)
    for src, tgt_inp, tgt_out, src_pad, tgt_pad in pbar:
        src, tgt_inp, tgt_out = src.to(DEVICE), tgt_inp.to(DEVICE), tgt_out.to(DEVICE)
        src_pad, tgt_pad = src_pad.to(DEVICE), tgt_pad.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        logits = model(src, tgt_inp, src_pad, tgt_pad)  # [B,T,V]
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        steps += 1
        pbar.set_postfix({"loss": f"{total_loss/steps:.3f}"})
    return total_loss / max(1, steps)

@torch.no_grad()
def evaluate_loss():
    model.eval()
    total_loss = 0.0
    steps = 0
    for src, tgt_inp, tgt_out, src_pad, tgt_pad in val_loader:
        src, tgt_inp, tgt_out = src.to(DEVICE), tgt_inp.to(DEVICE), tgt_out.to(DEVICE)
        src_pad, tgt_pad = src_pad.to(DEVICE), tgt_pad.to(DEVICE)
        logits = model(src, tgt_inp, src_pad, tgt_pad)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        total_loss += loss.item()
        steps += 1
    return total_loss / max(1, steps)

def ids_to_text(ids: List[int]) -> str:
    # Strip padding/BOS/EOS
    out = []
    for i in ids:
        if i in (PAD_ID, BOS_ID, EOS_ID):
            continue
        out.append(i)
    return sp.decode(out)

@torch.no_grad()
def greedy_decode(en_text: str, max_len=MAX_LEN) -> str:
    model.eval()
    # Encode source
    src = [clip(encode_sentence(en_text, add_bos=False, add_eos=True), max_len)]
    src = torch.tensor(src, dtype=torch.long).to(DEVICE)  # [1,S]
    src_pad = (src == PAD_ID)

    # Start target with <bos>
    tgt = torch.tensor([[BOS_ID]], dtype=torch.long).to(DEVICE)  # [1,1]

    for _ in range(max_len-1):
        tgt_pad = (tgt == PAD_ID)
        logits = model(src, tgt, src_pad, tgt_pad)  # [1,T,V]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1,1]
        tgt = torch.cat([tgt, next_token], dim=1)  # [1,T+1]
        if next_token.item() == EOS_ID:
            break

    return ids_to_text(tgt.squeeze(0).tolist())

@torch.no_grad()
def compute_bleu(samples=256):
    # Evaluate BLEU on a random subset of the validation set
    if len(val_ds) == 0:
        return 0.0
    idxs = random.sample(range(len(val_ds)), k=min(samples, len(val_ds)))
    refs = []
    hyps = []
    for i in idxs:
        en, ja = val_pairs[i]
        hyp = greedy_decode(en, max_len=MAX_LEN)
        refs.append([ja])
        hyps.append(hyp)
    bleu = sacrebleu.corpus_bleu(hyps, list(zip(*refs)))
    return float(bleu.score)

Model parameters: 54,019,712


In [14]:
best_val = float("inf")
best_bleu = 0.0
ckpt_path = ARTIFACTS_DIR / "enja_transformer.pt"
spm_model_path = ARTIFACTS_DIR / "spm_enja.model"

print("Starting training… ")
for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    tr_loss = train_one_epoch(epoch)
    val_loss = evaluate_loss()
    took = time.time() - t0

    msg = f"Epoch {epoch:02d} | train loss {tr_loss:.3f} | val loss {val_loss:.3f} | {took:.1f}s"

    # Optionally compute BLEU every few epochs
    if epoch % max(1, EPOCHS // 6) == 0 or epoch == EPOCHS:
        bleu = compute_bleu(samples=128)
        msg += f" | BLEU {bleu:.1f}"
        if bleu > best_bleu:
            best_bleu = bleu
            torch.save({
                "model_state": model.state_dict(),
                "config": {
                    "VOCAB_SIZE": VOCAB_SIZE, "EMBED_DIM": EMBED_DIM, "FF_DIM": FF_DIM,
                    "NHEAD": NHEAD, "ENC_LAYERS": ENC_LAYERS, "DEC_LAYERS": DEC_LAYERS,
                    "DROPOUT": DROPOUT, "MAX_LEN": MAX_LEN
                }
            }, ckpt_path)

    print(msg)

print("\nTraining complete.")
print("Best BLEU (val):", best_bleu)
print("Saved best model to:", ckpt_path)
print("Tokenizer model:", spm_model_path)

Starting training… 


  output = torch._nested_tensor_from_mask(


Epoch 01 | train loss 6.094 | val loss 5.741 | 206.1s


                                                                                

Epoch 02 | train loss 5.622 | val loss 5.453 | 209.2s


                                                                                

Epoch 03 | train loss 5.337 | val loss 5.170 | 211.0s | BLEU 0.1


                                                                                

Epoch 04 | train loss 5.121 | val loss 4.986 | 210.8s


                                                                                

Epoch 05 | train loss 4.962 | val loss 4.826 | 212.2s


                                                                                

Epoch 06 | train loss 4.841 | val loss 4.750 | 212.3s | BLEU 2.6


                                                                                

Epoch 07 | train loss 4.728 | val loss 4.623 | 209.0s


                                                                                

Epoch 08 | train loss 4.623 | val loss 4.513 | 208.3s


                                                                                

Epoch 09 | train loss 4.534 | val loss 4.431 | 212.3s | BLEU 1.2


                                                                                 

Epoch 10 | train loss 4.463 | val loss 4.384 | 208.9s


                                                                                 

Epoch 11 | train loss 4.404 | val loss 4.336 | 208.0s


                                                                                 

Epoch 12 | train loss 4.356 | val loss 4.297 | 209.1s | BLEU 4.2


                                                                                 

Epoch 13 | train loss 4.307 | val loss 4.270 | 209.6s


                                                                                 

Epoch 14 | train loss 4.271 | val loss 4.224 | 209.7s


                                                                                 

Epoch 15 | train loss 4.233 | val loss 4.190 | 213.3s | BLEU 5.0


                                                                                 

Epoch 16 | train loss 4.205 | val loss 4.180 | 213.3s


                                                                                 

Epoch 17 | train loss 4.173 | val loss 4.144 | 213.4s


                                                                                 

Epoch 18 | train loss 4.149 | val loss 4.122 | 213.1s | BLEU 0.6


                                                                                 

Epoch 19 | train loss 4.127 | val loss 4.131 | 213.4s


                                                                                 

Epoch 20 | train loss 4.105 | val loss 4.088 | 213.4s | BLEU 1.8

Training complete.
Best BLEU (val): 4.955349584429408
Saved best model to: artifacts\enja_transformer.pt
Tokenizer model: artifacts\spm_enja.model


In [None]:

# Training results:
examples = [
    "hello, how are you?",
    "what time is the last train?",
    "my dog likes to bark.",
    "what did you eat for breakfast?"
]

for s in examples:
    print("\nEN:", s)
    print("JA:", greedy_decode(s))



EN: hello, how are you?
JA: 元気?

EN: what time is the last train?
JA: 電車はどのくらいですか。

EN: my dog likes to bark.
JA: 犬が吠えで吠えそう。

EN: what did you eat for breakfast?
JA: 朝食は何を食べましたか。


In [16]:
# To save training artifacts
final_ckpt = ARTIFACTS_DIR / "enja_transformer_final.pt"
final_spm = ARTIFACTS_DIR / "spm_enja.model"

# To save model checkpoint (weights + config)
torch.save({
    "model_state": model.state_dict(),
    "config": {
        "VOCAB_SIZE": VOCAB_SIZE, "EMBED_DIM": EMBED_DIM, "FF_DIM": FF_DIM,
        "NHEAD": NHEAD, "ENC_LAYERS": ENC_LAYERS, "DEC_LAYERS": DEC_LAYERS,
        "DROPOUT": DROPOUT, "MAX_LEN": MAX_LEN
    }
}, final_ckpt)

print("Saved final model to:", final_ckpt)
print("Tokenizer model already at:", final_spm)

Saved final model to: artifacts\enja_transformer_final.pt
Tokenizer model already at: artifacts\spm_enja.model


In [None]:
# Load SentencePiece tokenizer
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load(str(ARTIFACTS_DIR / "spm_enja.model"))

# Reload model
ckpt = torch.load(ARTIFACTS_DIR / "enja_transformer_final.pt", map_location=DEVICE)
cfg = ckpt["config"]

model = Seq2SeqTransformer(
    vocab_size=cfg["VOCAB_SIZE"],
    embed_dim=cfg["EMBED_DIM"],
    nhead=cfg["NHEAD"],
    num_encoder_layers=cfg["ENC_LAYERS"],
    num_decoder_layers=cfg["DEC_LAYERS"],
    ff_dim=cfg["FF_DIM"],
    dropout=cfg["DROPOUT"]
).to(DEVICE)

model.load_state_dict(ckpt["model_state"])
model.eval()

print("Model + tokenizer loaded successfully!")