# Phoneme seq2seq with CMUdict (encoder-decoder)
Train a Transformer encoder-decoder with cross-entropy to map noisy phoneme sequences back to clean phoneme sequences using LibriSpeech text.


In [13]:
from pathlib import Path
import math
import random
import re
import pickle
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler


# Rebuild LibriSpeech sentences
Use LibriSpeech text in notebooks/data to regenerate librispeech_sentences.pkl into notebooks/data and data/ before loading.


In [14]:
from pathlib import Path
import re
import pickle

ROOT = Path.cwd()
RAW_DATA = ROOT / "data"
BOOKS_DIR = RAW_DATA / "LibriSpeech" / "books" / "ascii"
OUTPUT_PATHS = [
    RAW_DATA / "librispeech_sentences.pkl",
    ROOT / "data" / "librispeech_sentences.pkl",
]

SENTENCE_RE = re.compile(r"[A-Za-z][^.!?]*[.!?]")
MIN_WORDS, MAX_WORDS = 3, 50

if not BOOKS_DIR.exists():
    raise FileNotFoundError(f"Missing LibriSpeech text at {BOOKS_DIR}")


def extract_sentences(text: str):
    for match in SENTENCE_RE.finditer(text):
        sentence = match.group().strip()
        length = len(sentence.split())
        if MIN_WORDS <= length <= MAX_WORDS:
            yield sentence


sentences = []
text_files = sorted(BOOKS_DIR.rglob("*.txt"))
print(f"Scanning {len(text_files)} text files under {BOOKS_DIR} ...")
for txt in text_files:
    text = txt.read_text(encoding="utf-8", errors="ignore").replace("\n", " ")
    sentences.extend(extract_sentences(text))

sentences = list(dict.fromkeys(sentences))  # preserve order while deduplicating
print(f"Collected {len(sentences):,} sentences after filtering + deduplication.")

for out in OUTPUT_PATHS:
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("wb") as f:
        pickle.dump(sentences, f)
    print(f"Saved sentences to {out}")

ROOT / "data" / "librispeech_sentences.pkl"


Scanning 1443 text files under c:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\data\LibriSpeech\books\ascii ...
Collected 6,078,025 sentences after filtering + deduplication.
Collected 6,078,025 sentences after filtering + deduplication.
Saved sentences to c:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\data\librispeech_sentences.pkl
Saved sentences to c:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\data\librispeech_sentences.pkl
Saved sentences to c:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\data\librispeech_sentences.pkl
Saved sentences to c:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\data\librispeech_sentences.pkl


WindowsPath('c:/Users/johnn/Desktop/EC ENGR C143A/c143a-project/notebooks/data/librispeech_sentences.pkl')

In [15]:
# Hyperparameters / config collected here for quick tuning
DATA_DIR = Path("data")
CMU_PATH = DATA_DIR / "cmudict-0.7b"
SENTENCE_PKL = DATA_DIR / "librispeech_sentences.pkl"

MAX_SENTENCES = 6_000_000
MAX_LEN = 128
TRAIN_SAMPLES = 4_000_000
VAL_SAMPLES = 10_000
BATCH_SIZE = 1024
NUM_EPOCHS = 10
SEED = 0

NOISE = (0.1, 0.1, 0.1)  # (p_sub, p_ins, p_del)
D_MODEL = 256
N_HEAD = 8
BIDIRECTIONAL = False  # False = streaming (causal); True = limited lookahead (current + next phoneme)
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
DIM_FEEDFORWARD = 1024
DROPOUT = 0.2

BASE_LR = 1e-3
END_LR = 1e-4
WEIGHT_DECAY = 1e-4
WARMUP_PCT = 0.1
GRAD_ACCUM_STEPS = 1
MIXED_PRECISION = True
USE_TF32 = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rng = random.Random(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x1f13221ddd0>

In [16]:
# Load CMUdict and LibriSpeech sentences
if not CMU_PATH.exists():
    raise FileNotFoundError(f"Missing CMUdict at {CMU_PATH}")
if not SENTENCE_PKL.exists():
    raise FileNotFoundError(f"Missing LibriSpeech sentences at {SENTENCE_PKL}")

stress_digits = "0123456789"
cmu_map: Dict[str, List[List[str]]] = {}
with CMU_PATH.open("r", encoding="utf-8", errors="ignore") as f:
    for line in f:
        line = line.strip()
        if not line or line.startswith(";;;"):
            continue
        parts = line.split()
        word, phones = parts[0], parts[1:]
        clean = [ph.translate({ord(d): None for d in stress_digits}) for ph in phones]
        cmu_map.setdefault(word.lower(), []).append(clean)

with SENTENCE_PKL.open("rb") as f:
    raw_sentences: List[str] = pickle.load(f)

len(cmu_map), len(raw_sentences)


(134373, 6078025)

In [17]:
# Build phoneme vocabulary (pad=0, sos=1, eos=2)
ALL_PHONEMES = sorted(
    set(ph for word_prons in cmu_map.values() for phoneme_seq in word_prons for ph in phoneme_seq)
)
if "<sil>" not in ALL_PHONEMES:
    ALL_PHONEMES.append("<sil>")
else:
    ALL_PHONEMES = [ph for ph in ALL_PHONEMES if ph != "<sil>"] + ["<sil>"]
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2
phoneme_to_idx: Dict[str, int] = {ph: i + 3 for i, ph in enumerate(ALL_PHONEMES)}
phoneme_to_idx["<pad>"] = PAD_IDX
phoneme_to_idx["<sos>"] = SOS_IDX
phoneme_to_idx["<eos>"] = EOS_IDX
idx_to_phoneme = {i: p for p, i in phoneme_to_idx.items()}

print(f"Phoneme vocab ({len(ALL_PHONEMES)}): {', '.join(ALL_PHONEMES)}")

vocab_size = len(phoneme_to_idx)

CONFUSABLE_GROUPS = [
    ["DH", "TH"],
    ["IH", "IY"],
    ["AH", "AA", "AE"],
    ["EH", "AE"],
    ["S", "Z"],
    ["SH", "ZH"],
    ["P", "B"],
    ["T", "D"],
    ["K", "G"],
    ["F", "V"],
    ["CH", "JH"],
]
neighbor_map: Dict[int, List[int]] = {}
for group in CONFUSABLE_GROUPS:
    ids = [phoneme_to_idx[p] for p in group if p in phoneme_to_idx]
    for pid in ids:
        neighbor_map[pid] = [q for q in ids if q != pid]

vocab_size, list(phoneme_to_idx.items())[:5]


Phoneme vocab (40): AA, AE, AH, AO, AW, AY, B, CH, D, DH, EH, ER, EY, F, G, HH, IH, IY, JH, K, L, M, N, NG, OW, OY, P, R, S, SH, T, TH, UH, UW, V, W, Y, Z, ZH, <sil>


(43, [('AA', 3), ('AE', 4), ('AH', 5), ('AO', 6), ('AW', 7)])

In [18]:
WORD_RE = re.compile(r"[A-Za-z']+")

def sentence_to_phonemes(text: str) -> Optional[List[str]]:
    words = WORD_RE.findall(text.lower())
    if not words:
        return None
    phonemes: List[str] = []
    for i, w in enumerate(words):
        prons = cmu_map.get(w)
        if not prons:
            return None
        phones = rng.choice(prons)
        phonemes.extend(phones)
        if i < len(words) - 1:
            phonemes.append("<sil>")
    return phonemes

if MAX_SENTENCES and len(raw_sentences) > MAX_SENTENCES:
    sampled_sentences = rng.sample(raw_sentences, MAX_SENTENCES)
else:
    sampled_sentences = list(raw_sentences)

sentence_records: List[Dict[str, object]] = []
for sent in sampled_sentences:
    phonemes = sentence_to_phonemes(sent)
    if phonemes:
        sentence_records.append({"text": sent, "phonemes": phonemes})

if not sentence_records:
    raise RuntimeError("No sentences could be converted with CMUdict coverage.")
print(f"Prepared {len(sentence_records):,} sentence records (from {len(sampled_sentences):,} samples).")


Prepared 4,315,739 sentence records (from 6,000,000 samples).


In [19]:
def add_noise(
    seq: List[int],
    vocab_size: int,
    p_sub: float = 0.1,
    p_ins: float = 0.05,
    p_del: float = 0.05,
) -> List[int]:
    noisy: List[int] = []
    for token in seq:
        if random.random() < p_del:
            continue
        if random.random() < p_sub:
            neighbors = neighbor_map.get(token, [])
            if neighbors and random.random() < 0.8:
                token = random.choice(neighbors)
            else:
                token = random.randint(3, vocab_size - 1)
        noisy.append(token)
        if random.random() < p_ins:
            noisy.append(random.randint(3, vocab_size - 1))
    return noisy


def encode_phonemes(phonemes: List[str], max_len: int) -> List[int]:
    tokens = [phoneme_to_idx[p] for p in phonemes if p in phoneme_to_idx]
    tokens = tokens[: max_len - 2]
    tokens.append(EOS_IDX)
    return tokens


def pad_sequence(seq: List[int], max_len: int, pad_value: int = PAD_IDX) -> List[int]:
    return seq + [pad_value] * (max_len - len(seq))


class PhonemeCorrectionDataset(Dataset):
    def __init__(
        self,
        records: List[Dict[str, object]],
        max_len: int = MAX_LEN,
        num_samples: Optional[int] = None,
        noise: Tuple[float, float, float] = NOISE,
    ) -> None:
        self.records = records
        self.max_len = max_len
        self.num_samples = num_samples or len(records)
        self.noise = noise

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        clean_ph: List[str] = self.records[idx % len(self.records)]["phonemes"]  # type: ignore[index]
        tgt_tokens = encode_phonemes(clean_ph, self.max_len)
        noisy_tokens = add_noise(
            tgt_tokens,
            vocab_size,
            p_sub=self.noise[0],
            p_ins=self.noise[1],
            p_del=self.noise[2],
        )
        noisy_tokens = noisy_tokens[: self.max_len]
        tgt_tokens = tgt_tokens[: self.max_len]
        return torch.tensor(noisy_tokens), torch.tensor(tgt_tokens)


def collate_batch(batch):
    srcs, tgts = zip(*batch)
    max_src = min(max(len(s) for s in srcs), MAX_LEN)
    max_tgt = min(max(len(t) for t in tgts) + 2, MAX_LEN)  # room for <sos>/<eos>

    padded_src = [pad_sequence(s.tolist()[:max_src], max_src) for s in srcs]
    tgt_in, tgt_out = [], []
    for tgt in tgts:
        tokens = tgt.tolist()[: max_tgt - 2]
        seq = [SOS_IDX] + tokens + [EOS_IDX]
        seq = seq[:max_tgt]
        seq += [PAD_IDX] * (max_tgt - len(seq))
        tgt_in.append(seq[:-1])
        tgt_out.append(seq[1:])

    return torch.tensor(padded_src), torch.tensor(tgt_in), torch.tensor(tgt_out)


# quick sanity check
_ds = PhonemeCorrectionDataset(sentence_records, max_len=32, num_samples=5)
for noisy, clean in _ds:
    print("noisy: ", [idx_to_phoneme[i] for i in noisy.tolist() if i != PAD_IDX])
    print("clean: ", [idx_to_phoneme[i] for i in clean.tolist() if i != PAD_IDX])
    break


noisy:  ['AA', 'P', 'D', 'EY', 'NG', '<sil>', 'D', 'EH', 'S', '<sil>', 'AA', 'R', 'T', 'S', '<eos>']
clean:  ['AA', 'N', 'D', 'R', 'EY', '<sil>', 'D', 'EH', 'S', '<sil>', 'AA', 'R', 'T', 'S', '<eos>']


In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)].to(x.device)


def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    mask = torch.full((sz, sz), float('-inf'), device=device)
    mask = torch.triu(mask, diagonal=1)
    mask.fill_diagonal_(0.0)
    return mask


def generate_limited_future_mask(sz: int, device: torch.device, lookahead: int = 1) -> torch.Tensor:
    # Allow attention to current position and up to `lookahead` future steps; disallow further lookahead.
    i = torch.arange(sz, device=device).unsqueeze(1)
    j = torch.arange(sz, device=device)S
    return torch.where(j - i > lookahead, float('-inf'), 0.0)


class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = D_MODEL,
        nhead: int = N_HEAD,
        num_encoder_layers: int = NUM_ENCODER_LAYERS,
        num_decoder_layers: int = NUM_DECODER_LAYERS,
        dim_feedforward: int = DIM_FEEDFORWARD,
        dropout: float = DROPOUT,
        bidirectional: bool = BIDIRECTIONAL,
    ) -> None:
        super().__init__()
        self.bidirectional = bidirectional
        self.src_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.tgt_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.pos_enc = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.proj = nn.Linear(d_model, vocab_size)

    def encode(self, src: torch.Tensor) -> torch.Tensor:
        src_key_padding_mask = src == PAD_IDX
        if self.bidirectional:
            src_mask = generate_limited_future_mask(src.size(1), src.device, lookahead=1)
        else:
            src_mask = generate_square_subsequent_mask(src.size(1), src.device)
        return self.encoder(
            self.pos_enc(self.src_emb(src)),
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

    def forward(self, src: torch.Tensor, tgt_in: torch.Tensor) -> torch.Tensor:
        src_key_padding_mask = src == PAD_IDX
        tgt_key_padding_mask = tgt_in == PAD_IDX
        tgt_mask = generate_square_subsequent_mask(tgt_in.size(1), tgt_in.device)

        src_h = self.encode(src)
        tgt_h = self.decoder(
            self.pos_enc(self.tgt_emb(tgt_in)),
            src_h,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.proj(tgt_h)


In [21]:
# Training setup
train_samples = min(TRAIN_SAMPLES, len(sentence_records))
val_samples = min(VAL_SAMPLES, max(0, len(sentence_records) - train_samples))

# explicit train/val split with deterministic shuffle
split_records = list(sentence_records)
rng.shuffle(split_records)
train_records = split_records[:train_samples]
val_records = split_records[train_samples: train_samples + val_samples] or split_records[-val_samples:]

train_ds = PhonemeCorrectionDataset(train_records, max_len=MAX_LEN, num_samples=len(train_records), noise=NOISE)
val_ds = PhonemeCorrectionDataset(val_records, max_len=MAX_LEN, num_samples=len(val_records), noise=NOISE)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
model = Seq2SeqTransformer(
    vocab_size=vocab_size,
    d_model=D_MODEL,
    nhead=N_HEAD,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
    bidirectional=BIDIRECTIONAL,
)

model = Seq2SeqTransformer(
    vocab_size=vocab_size,
    d_model=D_MODEL,
    nhead=N_HEAD,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
)
model = model.to(DEVICE)

if DEVICE.type == "cuda" and USE_TF32:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
total_train_iters = max(1, NUM_EPOCHS * len(train_loader))
warmup_iters = max(1, int(WARMUP_PCT * total_train_iters))
min_lr_factor = END_LR / BASE_LR


def lr_lambda(step: int) -> float:
    if step < warmup_iters:
        return (step + 1) / warmup_iters
    remaining = max(total_train_iters - warmup_iters, 1)
    decay_step = step - warmup_iters
    decay_frac = max(0.0, 1.0 - decay_step / remaining)
    return max(min_lr_factor, decay_frac)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
scaler = GradScaler(enabled=MIXED_PRECISION and DEVICE.type == "cuda")


  scaler = GradScaler(enabled=MIXED_PRECISION and DEVICE.type == "cuda")


In [22]:
def tokens_to_phonemes(tokens: List[int]) -> List[str]:
    return [idx_to_phoneme.get(int(t), f"<{t}>") for t in tokens if int(t) not in (PAD_IDX, SOS_IDX, EOS_IDX)]


def greedy_decode(src: torch.Tensor, max_len: int = MAX_LEN) -> List[int]:
    model.eval()
    src = src.unsqueeze(0).to(DEVICE)
    src_key_padding_mask = src == PAD_IDX
    memory = model.encode(src)
    ys = torch.tensor([[SOS_IDX]], device=DEVICE)
    for _ in range(max_len):
        tgt_mask = generate_square_subsequent_mask(ys.size(1), ys.device)
        out = model.decoder(
            model.pos_enc(model.tgt_emb(ys)),
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=src_key_padding_mask,
        )
        logits = model.proj(out[:, -1, :])
        next_word = logits.argmax(dim=-1).unsqueeze(0)
        ys = torch.cat([ys, next_word], dim=1)
        if next_word.item() == EOS_IDX:
            break
    return ys.squeeze(0).tolist()


def log_val_example(example_idx: int = 0) -> None:
    model.eval()
    noisy, tgt = val_ds[example_idx % len(val_ds)]
    rec = val_ds.records[example_idx % len(val_ds)]
    text = rec.get('text', '') if isinstance(rec, dict) else ''
    pred_tokens = greedy_decode(noisy, max_len=MAX_LEN)

    noisy_ph = tokens_to_phonemes(noisy.tolist())
    tgt_ph = tokens_to_phonemes(tgt.tolist())
    pred_ph = tokens_to_phonemes(pred_tokens)

    print(f"[val example] noisy:        {' '.join(noisy_ph)}")
    print(f"[val example] target:       {' '.join(tgt_ph)}")
    print(f"[val example] corrected:    {' '.join(pred_ph) if pred_ph else '<empty>'}")
    if text:
        print(f"[val example] plain text:   {text}")


def run_epoch(loader, train: bool = True):
    model.train(train)
    total_loss = 0.0
    accum = GRAD_ACCUM_STEPS if train else 1
    if train:
        optimizer.zero_grad(set_to_none=True)
    num_batches = len(loader)

    for step, (src, tgt_in, tgt_out) in enumerate(loader):
        src, tgt_in, tgt_out = src.to(DEVICE), tgt_in.to(DEVICE), tgt_out.to(DEVICE)

        # autocast may not accept a device_type kwarg on all torch versions;
        # use enabled=scaler.is_enabled() which already reflects MIXED_PRECISION & CUDA
        with autocast(enabled=scaler.is_enabled()):
            logits = model(src, tgt_in)
            loss = criterion(logits.reshape(-1, vocab_size), tgt_out.reshape(-1))

        total_loss += loss.item()

        if train:
            loss = loss / accum
            scaler.scale(loss).backward()
            do_step = (step + 1) % accum == 0 or (step + 1) == num_batches
            if do_step:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
    return total_loss / len(loader)


In [None]:
for epoch in range(1, NUM_EPOCHS):
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    print(f"epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | lr={scheduler.get_last_lr()[0]:.6f}")
    log_val_example(example_idx=epoch)

    # checkpoint every epoch
    ckpt_path = Path("checkpoints") / f"phoneme_seq2seq_epoch_{epoch+1}.pt"
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    torch.save(
        {
            "state_dict": model.state_dict(),
            "phoneme_to_idx": phoneme_to_idx,
            "idx_to_phoneme": idx_to_phoneme,
            "pad_idx": PAD_IDX,
            "sos_idx": SOS_IDX,
            "eos_idx": EOS_IDX,
            "max_len": MAX_LEN,
            "model_kwargs": {
                "vocab_size": vocab_size,
                "d_model": D_MODEL,
                "nhead": N_HEAD,
                "num_encoder_layers": NUM_ENCODER_LAYERS,
                "num_decoder_layers": NUM_DECODER_LAYERS,
                "dim_feedforward": DIM_FEEDFORWARD,
                "dropout": DROPOUT,
            },
        },
        ckpt_path,
    )
    print(f"Saved epoch {epoch+1} checkpoint to {ckpt_path.resolve()}")

In [24]:
for epoch in range(5, NUM_EPOCHS):
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    print(f"epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | lr={scheduler.get_last_lr()[0]:.6f}")
    log_val_example(example_idx=epoch)

    # checkpoint every epoch
    ckpt_path = Path("checkpoints") / f"phoneme_seq2seq_epoch_{epoch+1}.pt"
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    torch.save(
        {
            "state_dict": model.state_dict(),
            "phoneme_to_idx": phoneme_to_idx,
            "idx_to_phoneme": idx_to_phoneme,
            "pad_idx": PAD_IDX,
            "sos_idx": SOS_IDX,
            "eos_idx": EOS_IDX,
            "max_len": MAX_LEN,
            "model_kwargs": {
                "vocab_size": vocab_size,
                "d_model": D_MODEL,
                "nhead": N_HEAD,
                "num_encoder_layers": NUM_ENCODER_LAYERS,
                "num_decoder_layers": NUM_DECODER_LAYERS,
                "dim_feedforward": DIM_FEEDFORWARD,
                "dropout": DROPOUT,
            },
        },
        ckpt_path,
    )
    print(f"Saved epoch {epoch+1} checkpoint to {ckpt_path.resolve()}")

  with autocast(enabled=scaler.is_enabled()):


epoch 5: train_loss=0.2997 | val_loss=0.2427 | lr=0.001000
[val example] noisy:        SH UW <sil> UW <sil> D N T <sil> S EY UH S M OW
[val example] target:       SH UW <sil> Y UW <sil> D OW N T <sil> S EY <sil> S OW
[val example] corrected:    Y UW <sil> D UW <sil> N AA T <sil> S EY <sil> S OW
[val example] plain text:   Shoo--you don't say so!
Saved epoch 6 checkpoint to C:\Users\johnn\Desktop\EC ENGR C143A\c143a-project\notebooks\checkpoints\phoneme_seq2seq_epoch_6.pt
epoch 6: train_loss=0.2913 | val_loss=0.2358 | lr=0.000889
[val example] noisy:        DH AH <sil> S AY L AH N S <sil> W AA K Z <sil> B IH K AH M IH NG <sil> AH N B EH JH B AH L <sil> SH IY OY <sil> IH S T R AH G AH L D <sil> P UW <sil> TH IH NG K <sil> AH V <sil> S AH M TH NG <sil> T UW <sil> S EY <sil> B AH T <sil> AH IH NG <sil> K M <sil> AH N D <sil> IY <sil> R OW Z <sil> AH AH P L IY
[val example] target:       DH AH <sil> S AY L AH N S <sil> W AA Z <sil> B IH K AH M IH NG <sil> AH N B EH R AH B AH L <sil> SH IY <

In [23]:
# load checkpoint 5
ckpt_path = Path("checkpoints") / "phoneme_seq2seq_epoch_5.pt"
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])
BATCH_SIZE = 1024
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [None]:
# Decode a few examples (greedy autoregressive)
model.eval()
with torch.no_grad():
    src_batch, tgt_batch, _ = next(iter(val_loader))
    src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)

for i in range(3):
    noisy_tokens = [idx_to_phoneme[t.item()] for t in src_batch[i] if t.item() != PAD_IDX]
    clean_tokens = [idx_to_phoneme[t.item()] for t in tgt_batch[i] if t.item() not in (PAD_IDX, SOS_IDX, EOS_IDX)]
    pred_tokens = greedy_decode(src_batch[i].cpu(), max_len=MAX_LEN)
    pred = [idx_to_phoneme.get(int(t), f"<{t}>") for t in pred_tokens if int(t) not in (PAD_IDX, SOS_IDX, EOS_IDX)]
    print(f"Noisy    : {' '.join(noisy_tokens)}")
    print(f"Target   : {' '.join(clean_tokens)}")
    print(f"Pred     : {' '.join(pred)}")


In [None]:
export_path = Path("/phoneme_seq2seq.pt")
export_path.parent.mkdir(parents=True, exist_ok=True)

checkpoint = {
    "state_dict": model.state_dict(),
    "phoneme_to_idx": phoneme_to_idx,
    "idx_to_phoneme": idx_to_phoneme,
    "pad_idx": PAD_IDX,
    "sos_idx": SOS_IDX,
    "eos_idx": EOS_IDX,
    "max_len": MAX_LEN,
    "model_kwargs": {
        "vocab_size": vocab_size,
        "d_model": D_MODEL,
        "nhead": N_HEAD,
        "num_encoder_layers": NUM_ENCODER_LAYERS,
        "num_decoder_layers": NUM_DECODER_LAYERS,
        "dim_feedforward": DIM_FEEDFORWARD,
        "dropout": DROPOUT,
    },
}
torch.save(checkpoint, export_path)
print(f"Saved phoneme seq2seq checkpoint to {export_path.resolve()}")
