In [1]:
!pip install torchcodec jiwer
!sudo apt-get update && sudo apt-get install sox -y

Hit:1 https://cli.github.com/packages stable InRelease
Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:4 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Fetched 129 kB in 2s (58.5 kB/s)
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (so

In [2]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Chunkformer

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Chunkformer


In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `hf auth whoami` to get more information or `hf auth logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) y
Token is valid (permission: write).
The token `as

In [4]:
import os
import sys
import math
import argparse
import shutil
from typing import List, Tuple, Optional

import torch
import torchaudio
import yaml
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import sentencepiece as spm

# Add the current directory to the Python path
sys.path.append('.')

from model.utils.init_model import init_model
from model.utils.checkpoint import load_checkpoint
from model.fixed_tokenizer import TextTokenizer

from train import (
    TextTokenizer,
    AudioDataset,
    collate_fn,
    NoamLR,
    build_or_load_tokenizer,
    build_config_yaml,
    save_vocab_and_config,
    init_model,
)

In [5]:
class AudioDataset(Dataset):
    def __init__(self, tsv_path: str, tokenizer: TextTokenizer, use_speed_perturb: bool = True,
                 return_texts: bool = False):
        df = pd.read_csv(tsv_path, sep="\t")
        assert "wav" in df.columns and "txt" in df.columns, "TSV must contain 'wav' and 'txt' columns"
        self.paths = df["wav"].tolist()
        self.texts = df["txt"].tolist()
        self.use_speed_perturb = use_speed_perturb
        self.tokenizer = tokenizer
        self.return_texts = return_texts

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int):
        path = self.paths[idx]
        text = str(self.texts[idx])
        waveform, sr = torchaudio.load(path)
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if self.use_speed_perturb:
            waveform = maybe_speed_perturb(waveform, sr)

        feats = compute_fbank(waveform, sr)
        token_ids = self.tokenizer.encode(text)
        if self.return_texts:
            return feats, torch.tensor(token_ids, dtype=torch.long), text
        return feats, torch.tensor(token_ids, dtype=torch.long)

In [6]:
def collate_fn(batch, apply_specaug: bool = True) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    xs = []
    ys = []
    for sample in batch:
        feats, tgt = sample[0], sample[1]
        if apply_specaug:
            feats = spec_augment(feats)
        xs.append(feats)
        ys.append(tgt)

    xs_lens = torch.tensor([x.size(0) for x in xs], dtype=torch.long)
    ys_cat = torch.cat([y for y in ys]) if len(ys) > 0 else torch.tensor([], dtype=torch.long)
    ys_lens = torch.tensor([len(y) for y in ys], dtype=torch.long)

    return xs, xs_lens, ys_cat, ys_lens

In [7]:
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam scheduler with explicit peak learning rate at warmup_steps.

    lr(step) = scale * min(step^-0.5, step * warmup^-1.5)
    where scale is chosen so lr(warmup_steps) == peak_lr.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, peak_lr: float, last_epoch: int = -1):
        self.warmup_steps = max(1, warmup_steps)
        self.peak_lr = peak_lr
        # Determine scale so that lr at warmup equals peak_lr
        self.scale = peak_lr * (self.warmup_steps ** 0.5)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1, self.last_epoch + 1)
        factor = min(step ** -0.5, step * (self.warmup_steps ** -1.5))
        lr = self.scale * factor
        return [lr for _ in self.base_lrs]

In [8]:
def build_or_load_tokenizer(train_tsv: str, out_dir: str, vocab_size: int) -> TextTokenizer:
    spm_path = os.path.join(out_dir, 'spm.model')
    if os.path.exists(spm_path):
        return TextTokenizer(spm_path)
    texts = pd.read_csv(train_tsv, sep='\t')['txt'].astype(str).tolist()
    spm_path = TextTokenizer.train_from_corpus(texts, out_dir, vocab_size)
    return TextTokenizer(spm_path)

In [9]:
def build_config_yaml(output_dir: str, vocab_size: int, d_model: int, num_blocks: int,
                      attention_heads: int, linear_units: int, dropout_rate: float) -> str:
    config = {
        "cmvn_file": None,
        "is_json_cmvn": False,
        "input_dim": 80,
        "output_dim": vocab_size,
        "encoder_conf": {
            "output_size": d_model,
            "attention_heads": attention_heads,
            "linear_units": linear_units,
            "num_blocks": num_blocks,
            "dropout_rate": dropout_rate,
            "positional_dropout_rate": dropout_rate,
            "attention_dropout_rate": 0.0,
            "input_layer": "conv2d",
            "pos_enc_layer_type": "abs_pos",
            "normalize_before": True,
            "static_chunk_size": 0,
            "use_dynamic_chunk": False,
            "positionwise_conv_kernel_size": 1,
            "macaron_style": True,
            "selfattention_layer_type": "rel_selfattn",
            "activation_type": "swish",
            "use_cnn_module": True,
            "cnn_module_kernel": 15,
            "causal": False,
            "cnn_module_norm": "batch_norm",
            "use_limited_chunk": False,
            "limited_decoding_chunk_sizes": [],
            "limited_left_chunk_sizes": [],
            "use_dynamic_conv": False,
            "use_context_hint_chunk": False,
            "right_context_sizes": [],
            "right_context_probs": [],
            "freeze_subsampling_layer": False,
        }
    }
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(config, f, sort_keys=False)
    return config_path

In [10]:
import os
import shutil
from sentencepiece import SentencePieceProcessor

def save_vocab_and_config(tokenizer: TextTokenizer, output_dir: str, config_path: str):
    os.makedirs(output_dir, exist_ok=True)
    vocab_txt = os.path.join(output_dir, "vocab.txt")
    tokenizer.save_vocab_txt(vocab_txt)
    # Save SentencePiece model next to vocab for consistent inference
    spm_path = None
    try:
        spm_path = tokenizer.sp.model_file()
    except Exception:
        spm_path = None
    if spm_path and isinstance(spm_path, str) and os.path.exists(spm_path):
        shutil.copy2(spm_path, os.path.join(output_dir, "spm.model"))
    else:
        # Fallback: write serialized model bytes if available
        try:
            blob = tokenizer.sp.serialized_model_proto()
        except Exception:
            blob = None
        if blob:
            with open(os.path.join(output_dir, "spm.model"), "wb") as f:
                f.write(blob)
    return vocab_txt

In [11]:
def set_seed(seed: int):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [12]:
import os
import argparse
import yaml
import random
import torch
from torch.utils.data import DataLoader

In [13]:
def sample_dynamic_context():
    # Dynamic limited-context policy: randomize chunk and contexts
    # Example ranges adapted from paper spirit; adjust as needed
    chunk_size = random.choice([32, 48])
    left_ctx = random.choice([64, 128])
    right_ctx = random.choice([32, 64])
    return chunk_size, left_ctx, right_ctx

In [14]:
def train_one_epoch_chunk(model, dataloader, optimizer, scheduler, device, scaler,
                          grad_accum_steps: int = 1, max_grad_norm: float = 5.0):
    model.train()
    total_loss = 0.0
    total_tokens = 0
    optimizer.zero_grad(set_to_none=True)

    for step, (xs, xs_lens, ys_cat, ys_lens) in enumerate(dataloader):
        xs = [x.to(device) for x in xs]
        ys_cat = ys_cat.to(device)
        ys_lens = ys_lens.to(device)

        xs_origin_lens = xs_lens.to(device)
        c, l, r = sample_dynamic_context()

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            offset = torch.zeros(len(xs), dtype=torch.int, device=device)
            encoder_outs, encoder_lens, n_chunks, _, _, _ = model.encoder.forward_parallel_chunk(
                xs=xs,
                xs_origin_lens=xs_origin_lens,
                chunk_size=c,
                left_context_size=l,
                right_context_size=r,
                offset=offset,
            )
            enc_padded, enc_masks = model.encoder.rearrange(encoder_outs, xs_origin_lens, n_chunks)
            input_lengths = enc_masks.squeeze(1).sum(dim=1).to(torch.int)
            log_probs = model.ctc.log_softmax(enc_padded).transpose(0, 1)

            loss = torch.nn.functional.ctc_loss(
                log_probs,
                ys_cat,
                input_lengths,
                ys_lens,
                blank=0,
                reduction="sum",
                zero_infinity=True,
            )
            norm = ys_lens.sum().clamp_min(1)
            loss = loss / norm

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % grad_accum_steps == 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        total_loss += loss.detach().item() * norm.item()
        total_tokens += norm.item()

        if device.type == "cuda":
            torch.cuda.empty_cache()

    avg_loss = total_loss / max(1, total_tokens)
    return avg_loss, total_tokens


In [16]:
from datasets import load_dataset
import os, torch, torchaudio, pandas as pd
import soundfile as sf

ds = load_dataset("leduckhai/VietMed")

train_data = ds['test'].select(range(1500))
output_dir = "train"
os.makedirs(output_dir, exist_ok=True)

wav_paths = []
texts = []

for idx, item in enumerate(train_data):
    try:
        audio_data = item['audio']['array']
        sample_rate = item['audio']['sampling_rate']
        text = item['text']

        wav_path = os.path.join(output_dir, f"audio_{idx}.wav")
        sf.write(wav_path, audio_data, sample_rate)

        wav_paths.append(wav_path)
        texts.append(text.replace('\t', ' ').replace('\n', ' '))
    except Exception as e:
        print(f"{idx}: {e}")
        continue

df = pd.DataFrame({
    'wav': wav_paths,
    'txt': texts
})

df.to_csv('train_finetune.tsv', sep='\t', index=False, encoding='utf-8')

In [29]:
# Define configuration parameters
train_tsv = "train_finetune.tsv"
valid_tsv = None
init_model_dir = "output"
output_dir = "output2"
epochs = 30
batch_size = 4
num_workers = 2
peak_lr = 1e-5
warmup_steps = 15000
weight_decay = 0.0
seed = 42
device = "cuda" if torch.cuda.is_available() else "cpu"
amp = True
vocab_size = 5000
d_model = 192
num_blocks = 6
attention_heads = 4
linear_units = 1024
dropout = 0.1
disable_specaug = True
disable_speed_perturb = True

# Create a simple object to hold the configuration
class Args:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

args = Args(
    train_tsv=train_tsv,
    valid_tsv=valid_tsv,
    init_model_dir=init_model_dir,
    output_dir=output_dir,
    epochs=epochs,
    batch_size=batch_size,
    num_workers=num_workers,
    peak_lr=peak_lr,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    seed=seed,
    device=device,
    amp=amp,
    vocab_size=vocab_size,
    d_model=d_model,
    num_blocks=num_blocks,
    attention_heads=attention_heads,
    linear_units=linear_units,
    dropout=dropout,
    disable_specaug=disable_specaug,
    disable_speed_perturb=disable_speed_perturb
)

In [31]:
os.makedirs(args.output_dir, exist_ok=True)
set_seed(args.seed)

# Use same tokenizer as full-context training
init_model_dir = args.init_model_dir
init_spm_path = os.path.join(init_model_dir, "spm.model")
init_vocab_path = os.path.join(init_model_dir, "vocab.txt")

tokenizer = TextTokenizer(init_spm_path)


vocab_size = tokenizer.vocab_size()
config_path = build_config_yaml(
    args.output_dir,
    vocab_size=vocab_size,
    d_model=args.d_model,
    num_blocks=args.num_blocks,
    attention_heads=args.attention_heads,
    linear_units=args.linear_units,
    dropout_rate=args.dropout,
)
save_vocab_and_config(tokenizer, args.output_dir, config_path)

'output2/vocab.txt'

In [32]:
with open(config_path, "r") as f:
    configs = yaml.load(f, Loader=yaml.FullLoader)
model = init_model(configs, config_path)

# Load pre-trained full-context weights
init_ckpt = os.path.join(args.init_model_dir, "pytorch_model.bin")
state = torch.load(init_ckpt, map_location="cpu")
_ = model.load_state_dict(state, strict=False)

In [34]:
device = torch.device(args.device)
model = model.to(device)

train_ds = AudioDataset(
    args.train_tsv,
    tokenizer,
    use_speed_perturb=(not args.disable_speed_perturb),
)
train_loader = DataLoader(
    train_ds,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    collate_fn=lambda b: collate_fn(b, apply_specaug=(not args.disable_specaug)),
    pin_memory=(device.type == "cuda"),
)

optimizer = torch.optim.Adam(model.parameters(), lr=args.peak_lr, betas=(0.9, 0.98), weight_decay=args.weight_decay)
scheduler = NoamLR(optimizer, warmup_steps=args.warmup_steps, peak_lr=args.peak_lr)
scaler = torch.cuda.amp.GradScaler() if (args.amp and device.type == "cuda") else None


  scaler = torch.cuda.amp.GradScaler() if (args.amp and device.type == "cuda") else None


In [38]:
import torchaudio.compliance.kaldi as kaldi

def compute_fbank(waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
    if waveform.dtype != torch.float32:
        waveform = waveform.to(torch.float32)
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000
    feats = kaldi.fbank(
        waveform,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        dither=0.0,
        energy_floor=0.0,
        sample_frequency=sample_rate,
    )
    return feats

In [39]:
for epoch in range(1, args.epochs + 1):
    print(f"Epoch {epoch}/{args.epochs}")
    train_loss, _ = train_one_epoch_chunk(
        model,
        train_loader,
        optimizer,
        scheduler,
        device,
        scaler,
        grad_accum_steps=1,
    )
    print(f"Train loss (CTC): {train_loss:.4f}")

torch.save(model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
print("Saved chunk fine-tuned checkpoint.")

Epoch 1/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast(enabled=scaler is not None):


Train loss (CTC): 7.8539
Epoch 2/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.8059
Epoch 3/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.7385
Epoch 4/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.6221
Epoch 5/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.5028
Epoch 6/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.3570
Epoch 7/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.2305
Epoch 8/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 7.0812
Epoch 9/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.9702
Epoch 10/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.8260
Epoch 11/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.6667
Epoch 12/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.5359
Epoch 13/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.4020
Epoch 14/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.2783
Epoch 15/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.1447
Epoch 16/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 6.0092
Epoch 17/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.8984
Epoch 18/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.7583
Epoch 19/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.6512
Epoch 20/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.5289
Epoch 21/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.4011
Epoch 22/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.2807
Epoch 23/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.1728
Epoch 24/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 5.0579
Epoch 25/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.9451
Epoch 26/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.8488
Epoch 27/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.7327
Epoch 28/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.6221
Epoch 29/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.5254
Epoch 30/30


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Train loss (CTC): 4.4154
Saved chunk fine-tuned checkpoint.


In [None]:
# test_quality.py
import os, yaml, torch, torchaudio
import torchaudio.compliance.kaldi as kaldi
from datasets import Audio
import jiwer

from model.utils.init_model import init_model
from model.utils.ctc_utils import remove_duplicates_and_blank, class2str

ds = ds.cast_column("audio", Audio(sampling_rate=16000))

model_dir = "/content/drive/MyDrive/Chunkformer/output2"   
ckpt_name = "pytorch_model.bin" 
config_path = os.path.join(model_dir, "config.yaml")
vocab_path = os.path.join(model_dir, "vocab.txt")
ckpt_path = os.path.join(model_dir, ckpt_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
with open(config_path, "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
model = init_model(cfg, config_path).to(device)
state = torch.load(ckpt_path, map_location=device)
ret = model.load_state_dict(state, strict=False)
print("missing:", ret.missing_keys)
print("unexpected:", ret.unexpected_keys)
model.eval()

# 4) id->token
idx2tok = {}
with open(vocab_path, "r", encoding="utf-8") as f:
    for line in f:
        tok, idx = line.strip().rsplit(" ", 1)
        idx2tok[int(idx)] = tok

missing: []
unexpected: []


In [None]:
chunk_size, left_ctx, right_ctx = 48, 128, 64

def compute_fbank(waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
    if waveform.dtype != torch.float32:
        waveform = waveform.to(torch.float32)
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000
    return kaldi.fbank(
        waveform,
        num_mel_bins=80, frame_length=25, frame_shift=10,
        dither=0.0, energy_floor=0.0, sample_frequency=sample_rate,
    )

@torch.no_grad()
def decode_sample(sample, diag=False) -> str:
    audio = sample["audio"]
    wav = torch.tensor(audio["array"]).float()
    if wav.dim() == 1: wav = wav.unsqueeze(0)
    elif wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True)

    feats = compute_fbank(wav, audio["sampling_rate"])  # (T,80)
    xs = [feats.to(device)]
    xs_lens = torch.tensor([feats.size(0)], dtype=torch.int, device=device)
    offset = torch.zeros(1, dtype=torch.int, device=device)

    enc_outs, enc_lens, n_chunks, _, _, _ = model.encoder.forward_parallel_chunk(
        xs=xs,
        xs_origin_lens=xs_lens,
        chunk_size=chunk_size,
        left_context_size=left_ctx,
        right_context_size=right_ctx,
        offset=offset,
    )
    enc_padded, enc_masks = model.encoder.rearrange(enc_outs, xs_lens, n_chunks)  # (1,T,D)
    input_len = enc_masks.squeeze(1).sum(dim=1).to(torch.int)[0].item()

    log_probs = model.ctc.log_softmax(enc_padded)  # (1,T,V)
    ids = log_probs.argmax(dim=-1)[0, :input_len].tolist()
    if diag:
        blank_ratio = ids.count(0) / max(1, len(ids))
        mean_max_prob = log_probs.exp().max(-1).values[0, :input_len].mean().item()
        print(f"diag → T={input_len}, blank_ratio={blank_ratio:.2%}, mean_max_prob={mean_max_prob:.3f}, first30={ids[:30]}")

    ids = remove_duplicates_and_blank(ids)
    text = class2str(ids, idx2tok).replace("▁", " ").strip()
    return text

In [None]:
print("Quick check (5 samples):")
preds = []
refs = []
for s in ds["test"].select([1005, 1006, 1007, 1008, 1009]):
    txt = decode_sample(s, diag=True)
    preds.append(txt)
    refs.append(s.get("text", ""))
    print(f"pred: {txt}")

try:
    print("WER@5:", jiwer.wer(refs, preds))
    print("CER@5:", jiwer.cer(refs, preds))
except Exception as e:
    print("Skip WER/CER@5:", e)

Quick check (5 samples):
diag → T=86, blank_ratio=67.44%, mean_max_prob=0.831, first30=[0, 299, 0, 0, 0, 0, 0, 564, 0, 0, 0, 50, 0, 369, 0, 0, 725, 0, 0, 0, 0, 0, 295, 0, 1058, 0, 297, 0, 0, 0]
pred: phần thêm những kinh thức hay quen định rất là vấn đây là sẽ giống như cái trạng thuốc giảm đó hạn xuất như trên nào bây là
diag → T=86, blank_ratio=88.37%, mean_max_prob=0.863, first30=[0, 0, 0, 0, 0, 0, 389, 0, 0, 0, 236, 0, 0, 0, 0, 0, 0, 0, 0, 940, 0, 0, 308, 0, 0, 0, 728, 0, 0, 96]
pred: sự xương lắng giá tháng rất của rồi r ra
diag → T=61, blank_ratio=72.13%, mean_max_prob=0.768, first30=[0, 0, 0, 0, 0, 307, 0, 0, 0, 0, 0, 201, 0, 483, 0, 0, 676, 0, 0, 0, 0, 0, 29, 85, 0, 0, 463, 0, 488, 0]
pred: số đau kết hoàn thì để phòng cao giờ cái xương có hiện<unk> điện các biện
diag → T=74, blank_ratio=72.97%, mean_max_prob=0.870, first30=[0, 0, 0, 0, 0, 616, 0, 0, 450, 0, 518, 0, 0, 918, 296, 0, 0, 181, 0, 457, 0, 0, 0, 0, 0, 39, 0, 124, 0, 1171]
pred: soát sức khỏe ban đầu cơ bản mà bị dự đ