In [1]:
import torch
from chunkformer_vpb.training.finetune_utils import (
    get_default_args,
    prepare_input_file,
    load_model_only,
    GreedyTokenizer,
    compute_chunkformer_loss,
)

def main(
    model_checkpoint="../../../chunkformer-large-vie",
    audio_path="../../../debug_wavs/utt_000664.wav",
    label_text="một giọng nói du dương không thể lẫn với ai khác cất lên",
    ctc_weight=0.3,
):
    """Compute and print the ChunkFormer loss for one (audio, transcript) pair.

    Loads the checkpoint and extracts input features from ``audio_path``. It
    then scores them against ``label_text`` with ``compute_chunkformer_loss``
    and prints the total, CTC and attention-decoder (AED) losses.

    Every argument defaults to the value that used to be hard-coded here, so a
    bare ``main()`` behaves as before. Another debug sample can be scored with,
    for example::

        main(audio_path="../../../debug_wavs/sample_19.wav",
             label_text="cần nắm bắt xu hướng phát triển công nghệ mới")

    Args:
        model_checkpoint: Checkpoint folder; ``vocab.txt`` is read from it too.
        audio_path: Audio file to score.
        label_text: Ground-truth transcript for ``audio_path``.
        ctc_weight: Value assigned to ``model.ctc_weight`` before the loss is
            computed.
    """
    # 1. Start from the default args, override the per-run fields, pick a device.
    args = get_default_args()
    args.model_checkpoint = model_checkpoint
    args.audio_path = audio_path
    args.label_text = label_text

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 2. Load the model and the tokenizer.
    model, _ = load_model_only(args.model_checkpoint, device)
    model.ctc_weight = ctc_weight
    # reverse_weight is deliberately left untouched: per the original author,
    # setting it (e.g. to 0.3) does not work because there is no right_decoder.
    tokenizer = GreedyTokenizer(vocab_path=f"{args.model_checkpoint}/vocab.txt")

    # 3. Extract input features, noted by the author as [1, T_raw, 80].
    xs = prepare_input_file(args.audio_path, device)

    # 4. Compute the loss.
    loss_dict = compute_chunkformer_loss(
        model=model,
        tokenizer=tokenizer,
        xs=xs,
        args=args,
        label_text=args.label_text,
        device=device,
    )

    # 5. Print the results (same "<Title>: <value>" lines as before).
    for title, key in (
        ("Loss", "loss"),
        ("CTC Loss", "loss_ctc"),
        ("AED Loss", "loss_att"),
    ):
        value = loss_dict.get(key)
        # NOTE(review): presumably a term can be None when its branch is
        # disabled (e.g. ctc_weight of 0.0 or 1.0) — confirm against
        # compute_chunkformer_loss; print N/A rather than crash on .item().
        print(f"{title}: {'N/A' if value is None else value.item()}")

# Run only when executed directly. In a notebook __name__ is "__main__", so the
# cell still runs; importing an exported .py no longer triggers the model load.
if __name__ == "__main__":
    main()



🧾 Loaded checkpoint from: ../../../chunkformer-large-vie/pytorch_model.bin
📦 Checkpoint keys: ['encoder.global_cmvn.mean', 'encoder.global_cmvn.istd', 'encoder.embed.out.weight', 'encoder.embed.out.bias', 'encoder.embed.conv.0.weight'] ... (total 813)
🔍 AED decoder head included in checkpoint? ✅ YES
📊 Model total params: 113,852,240, trainable: 113,852,240

📥 Input shape: torch.Size([1, 423, 80]), xs_origin_lens: [423]
⚙️ chunk_size=64, left_context=128, right_context=128, truncated_context_size=11200
📏 Subsampling: 8, Chunk frame size: 519, Step: 512, Conv lorder: 7

🧱 Total chunked xs shape: torch.Size([1, 519, 80])
📐 xs_lens (post chunk): torch.Size([1]), total_chunks: 1
🎛️ Embedded xs shape: torch.Size([1, 64, 512]), PosEmb shape: torch.Size([1, 383, 512])
🧮 att_mask shape: torch.Size([1, 1, 320]), mask_pad shape: torch.Size([1, 1, 78])
🧩 Layer 0: xs shape after layer = torch.Size([1, 64, 512])
🧩 Layer 1: xs shape after layer = torch.Size([1, 64, 512])
🧩 Layer 2: xs shape after la