In [None]:
import os

# Later cells import project modules (constants, models.*) and read ./utils/mel_filters.npz
# relative to the working directory, which is presumably the notebook folder's parent.
# The starting directory is remembered on the first run. Re-executing this cell therefore
# lands in the same parent directory instead of climbing one level further each time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

In [None]:
import torch
import torch.nn.functional as F
import sentencepiece as spm
import librosa
import numpy as np
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from transformers import pipeline

from models.encoder import AudioEncoder
from models.decoder import Decoder
from models.jointer import Jointer

from constants import SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS
from constants import RNNT_BLANK, PAD, VOCAB_SIZE, TOKENIZER_MODEL_PATH, MAX_SYMBOLS
from constants import ATTENTION_CONTEXT_SIZE
from constants import N_STATE, N_LAYER, N_HEAD

In [None]:
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

## Load the trained model

In [None]:
# Location of the trained RNN-T checkpoint on the Hugging Face Hub. This model is FP32.
HF_CHECKPOINT = dict(
    repo_id="hkab/vietnamese-asr-model",
    subfolder="rnnt-whisper-small/80_3",
    filename="rnnt-latest.ckpt",
)
trained_model_path = hf_hub_download(**HF_CHECKPOINT)

In [None]:
checkpoint = torch.load(trained_model_path, map_location="cpu", weights_only=True)

# Split the checkpoint's flat state_dict into one dict per sub-network.
# - Keys containing 'alibi' are skipped.
# - Each remaining key goes to the first tag it contains (encoder, then decoder, then joint),
#   with every '<tag>.' occurrence removed from the key.
# - Keys matching no tag are ignored.
# NOTE(review): routing is by substring, not prefix, so e.g. a joint key containing "encoder"
# would land in encoder_weight. Confirm this against the checkpoint's key names.
encoder_weight, decoder_weight, joint_weight = {}, {}, {}
_weight_routes = (
    ('encoder', encoder_weight),
    ('decoder', decoder_weight),
    ('joint', joint_weight),
)
for param_name, param_value in checkpoint['state_dict'].items():
    if 'alibi' in param_name:
        continue
    for tag, bucket in _weight_routes:
        if tag in param_name:
            bucket[param_name.replace(tag + '.', '')] = param_value
            break

In [None]:
encoder = AudioEncoder(
    N_MELS,
    n_state=N_STATE,
    n_head=N_HEAD,
    n_layer=N_LAYER,
    att_context_size=ATTENTION_CONTEXT_SIZE
)

# Output size is the tokenizer vocabulary plus one extra class (presumably RNNT_BLANK — confirm).
decoder = Decoder(vocab_size=VOCAB_SIZE + 1)

joint = Jointer(vocab_size=VOCAB_SIZE + 1)

# strict=False tolerates key mismatches, e.g. the 'alibi' entries that are skipped when the
# checkpoint is split. It also hides real problems: an unloaded module would silently keep its
# random initialisation. Collect the result for every module and report anything other than
# missing 'alibi' keys.
load_results = {
    'encoder': encoder.load_state_dict(encoder_weight, strict=False),
    'decoder': decoder.load_state_dict(decoder_weight, strict=False),
    'joint': joint.load_state_dict(joint_weight, strict=False),
}
for module_name, result in load_results.items():
    missing = [key for key in result.missing_keys if 'alibi' not in key]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        print(f"{module_name}: missing keys {missing}, unexpected keys {unexpected}")

In [None]:
# Place every RNN-T sub-network on DEVICE, then switch each one to evaluation mode.
encoder, decoder, joint = [net.to(DEVICE) for net in (encoder, decoder, joint)]

for net in (encoder, decoder):
    net.eval()
joint.eval()

In [None]:
# Filterbanks already loaded by mel_filters, keyed by (device string, n_mels).
_MEL_FILTERS_CACHE = {}


def mel_filters(device, n_mels: int) -> torch.Tensor:
    """Load the mel filterbank matrix for projecting an STFT power spectrum onto the mel scale.

    The matrices come from ``./utils/mel_filters.npz``, resolved against the current working
    directory. Storing them in a file decouples this step from librosa; the file was saved with::

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=SAMPLE_RATE, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=SAMPLE_RATE, n_fft=400, n_mels=128),
        )

    Results are cached per ``(device, n_mels)``. Repeated calls (once per chunk during streaming)
    therefore neither re-read the file nor copy it to the device again. The cached tensor is shared
    between callers and must not be modified in place.

    Args:
        device: torch device (or device string) to place the matrix on.
        n_mels: number of mel bins; only 80 and 128 are stored in the file.

    Returns:
        The ``mel_{n_mels}`` array from the file as a tensor on ``device``.

    Raises:
        AssertionError: if ``n_mels`` is not 80 or 128.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    cache_key = (str(device), n_mels)
    filters = _MEL_FILTERS_CACHE.get(cache_key)
    if filters is None:
        with np.load("./utils/mel_filters.npz", allow_pickle=False) as f:
            filters = torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
        _MEL_FILTERS_CACHE[cache_key] = filters
    return filters

def log_mel_spectrogram(
    audio, n_mels, padding, streaming, device
):
    """Compute the normalised log-mel spectrogram used as model input.

    Args:
        audio: waveform tensor; moved to ``device`` first unless ``device`` is None.
        n_mels: number of mel bins, forwarded to ``mel_filters``.
        padding: when > 0, that many zero samples are appended to the last dimension.
        streaming: when true the STFT uses ``center=False``; otherwise ``center=True``.
        device: target device, or None to leave ``audio`` where it is.

    Returns:
        ``(log10(clamp(filters @ power, min=1e-10)) + 4) / 4``. ``power`` is the squared STFT
        magnitude with its final frame removed.
    """
    signal = audio if device is None else audio.to(device)
    if padding > 0:
        signal = F.pad(signal, (0, padding))

    spectrum = torch.stft(
        signal,
        N_FFT,
        HOP_LENGTH,
        window=torch.hann_window(N_FFT).to(signal.device),
        center=not streaming,
        return_complex=True,
    )
    # Squared magnitude, dropping the last frame.
    power = spectrum[..., :-1].abs() ** 2
    mel_power = mel_filters(signal.device, n_mels) @ power

    # No (max - 8.0) dynamic-range clamp is applied before the fixed affine normalisation.
    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    return (log_mel + 4.0) / 4.0

## Offline and online (streaming) inference

In [None]:
# SentencePiece model that turns decoded token ids back into text.
tokenizer = spm.SentencePieceProcessor()
tokenizer.load(TOKENIZER_MODEL_PATH);

In [None]:
@torch.no_grad()
def offline_transcribe(audio, encoder, decoder, joint, tokenizer, max_symbols=3):
    """Transcribe a complete waveform in one pass using RNN-T greedy decoding.

    The log-mel spectrogram of the whole input is computed on the CPU with the non-streaming
    (centred) STFT, moved to DEVICE, and encoded once. Each encoder output sequence is then
    decoded frame by frame. At every frame the decoder/joint network is queried repeatedly,
    emitting non-blank tokens until it predicts ``RNNT_BLANK`` or ``max_symbols`` tokens have
    been emitted for that frame.

    Runs under ``torch.no_grad()``. Nothing here needs gradients, and without it every decoder
    state kept for the next step would also carry its autograd history.

    Args:
        audio: waveform tensor, passed to ``log_mel_spectrogram``.
        encoder, decoder, joint: the RNN-T sub-networks.
        tokenizer: object whose ``decode(list_of_ids)`` maps token ids to text.
        max_symbols: maximum number of non-blank tokens emitted per encoder frame.

    Returns:
        A list with one decoded string per entry in the encoder output's batch dimension.
    """
    mels = log_mel_spectrogram(audio=audio, n_mels=N_MELS, padding=0, streaming=False, device="cpu")
    x = mels.reshape(1, *mels.shape).to(DEVICE)
    x_len = torch.tensor([x.shape[2]], device=DEVICE)

    enc_out, _ = encoder(x, x_len)

    all_sentences = []
    # Greedy decoding; every sequence in the batch is handled independently.
    for batch_idx in range(enc_out.shape[0]):
        seq_enc_out = enc_out[batch_idx].unsqueeze(0)  # [1, T, D]
        blank = torch.tensor([[RNNT_BLANK]], dtype=torch.long, device=seq_enc_out.device)
        # Last emitted token (shape [1, 1]) and the decoder state produced with it. Before the
        # first emission the decoder is primed with RNNT_BLANK and no state.
        last_token, last_state = None, None
        seq_ids = []

        for time_idx in range(seq_enc_out.shape[1]):
            frame = seq_enc_out[:, time_idx:time_idx + 1, :]  # [1, 1, D]

            for _ in range(max_symbols):
                if last_state is None:
                    dec_out, state = decoder(blank if last_token is None else last_token)
                else:
                    dec_out, state = decoder(last_token, last_state)
                logits = joint(frame, dec_out)[0, 0, 0, :]  # (B, T=1, U=1, V + 1) -> (V + 1,)
                token_id = logits.argmax().item()

                if token_id == RNNT_BLANK:
                    break  # blank: move on to the next encoder frame
                # Non-blank: the token and its decoder state become the next decoder input.
                last_token = torch.tensor([[token_id]], dtype=torch.long, device=frame.device)
                last_state = state
                seq_ids.append(token_id)
        all_sentences.append(tokenizer.decode(seq_ids))
    return all_sentences

In [None]:
# Replace the placeholder path with a real recording. The clip is resampled to SAMPLE_RATE,
# moved to DEVICE, and the transcript of the first (only) batch entry is displayed.
waveform, _sample_rate = librosa.load("/path/to/audio.wav", sr=SAMPLE_RATE)
audio = torch.from_numpy(waveform).to(DEVICE)
offline_transcripts = offline_transcribe(audio, encoder, decoder, joint, tokenizer, max_symbols=3)
offline_transcripts[0]

In [None]:
@torch.no_grad()
def online_transcribe(audio, encoder, decoder, joint, tokenizer, max_symbols=3):
    """Transcribe ``audio`` by feeding it to the encoder chunk by chunk (simulated streaming).

    Each step takes ``HOP_LENGTH * 32`` new samples plus the ``N_FFT - HOP_LENGTH``-sample tail
    of the previous chunk. It computes their log-mel frames with the non-centred (streaming)
    STFT, then runs them through ``encoder.conv1``/``conv2``/``conv3`` (each prefixed with one
    cached frame from the previous chunk) and the transformer blocks with per-layer key/value
    caches. The encoder output is greedily decoded as in ``offline_transcribe``. Decoder state
    is carried across chunks, so one hypothesis covers the whole input.

    Runs under ``torch.no_grad()``. Decoder state and the key/value caches are fed from one step
    into the next, so with autograd enabled their history would keep growing with the audio
    length.

    Args:
        audio: 1-D waveform tensor (sliced as ``audio[start:end]``).
        encoder, decoder, joint: the RNN-T sub-networks.
        tokenizer: object whose ``decode(list_of_ids)`` maps token ids to text.
        max_symbols: maximum number of non-blank tokens emitted per encoder frame.

    Returns:
        The decoded transcript as a single string.
    """
    frames_per_chunk = 32                                          # mel frames per streaming step
    chunk_samples = HOP_LENGTH * (frames_per_chunk - 1) + N_FFT    # samples given to the STFT
    overlap = N_FFT - HOP_LENGTH                                   # tail reused by the next chunk
    step = chunk_samples - overlap                                 # new samples per step
    left_context = ATTENTION_CONTEXT_SIZE[0]

    audio_cache = torch.zeros(overlap, device=DEVICE)
    conv_layers = (encoder.conv1, encoder.conv2, encoder.conv3)
    # Last input frame of each conv layer from the previous chunk.
    conv_caches = [
        torch.zeros(1, N_MELS, 1, device=DEVICE),
        torch.zeros(1, N_STATE, 1, device=DEVICE),
        torch.zeros(1, N_STATE, 1, device=DEVICE),
    ]
    k_cache = torch.zeros(N_LAYER, 1, left_context, N_STATE, device=DEVICE)
    v_cache = torch.zeros(N_LAYER, 1, left_context, N_STATE, device=DEVICE)
    cache_len = torch.zeros(1, dtype=torch.int, device=DEVICE)

    blank = torch.tensor([[RNNT_BLANK]], dtype=torch.long, device=DEVICE)
    # Last emitted token (shape [1, 1]) and its decoder state; both None before the first emission.
    last_token, last_state = None, None
    seq_ids = []

    # NOTE(review): the first chunk is zeros(overlap) followed by audio[overlap:...], so
    # audio[:overlap] never reaches the model. Seeding audio_cache with audio[:overlap]
    # (or starting at 0) would keep it. Confirm which is intended.
    for start in tqdm(range(overlap, audio.shape[0], step)):
        audio_chunk = torch.cat([audio_cache, audio[start:start + step]])
        if audio_chunk.shape[0] < chunk_samples:
            audio_chunk = F.pad(audio_chunk, (0, chunk_samples - audio_chunk.shape[0]))
        audio_cache = audio_chunk[-overlap:]

        x_chunk = log_mel_spectrogram(audio=audio_chunk, n_mels=N_MELS, padding=0, streaming=True, device=DEVICE)
        x_chunk = x_chunk.reshape(1, *x_chunk.shape)
        # log_mel_spectrogram drops the last STFT frame (stft[..., :-1]). A full chunk therefore
        # yields frames_per_chunk - 1 frames, and this pad appends a zero-valued frame.
        # NOTE(review): that frame's audio is not covered by the next chunk either, so 1 of every
        # frames_per_chunk frames is zeros. Confirm this matches how the model was trained.
        if x_chunk.shape[-1] < frames_per_chunk:
            x_chunk = F.pad(x_chunk, (0, frames_per_chunk - x_chunk.shape[-1]))

        # Convolution front-end: prefix each layer's input with its cached frame.
        for stage, conv in enumerate(conv_layers):
            x_chunk = torch.cat([conv_caches[stage], x_chunk], dim=2)
            conv_caches[stage] = x_chunk[:, :, -1:]
            x_chunk = F.gelu(conv(x_chunk))

        x_chunk = x_chunk.permute(0, 2, 1)  # (1, channels, time) -> (1, time, channels)

        # The mask covers left_context cached positions plus the new chunk. Only the rows after
        # the first left_context are kept (presumably the new chunk's query rows).
        x_len = torch.tensor([x_chunk.shape[1] + left_context], device=DEVICE)
        offset = left_context - cache_len
        attn_mask = encoder.form_attention_mask_for_streaming(encoder.att_context_size, x_len, offset, DEVICE)
        attn_mask = attn_mask[:, :, left_context:, :]

        new_k_cache, new_v_cache = [], []
        for layer_idx, block in enumerate(encoder.blocks):
            x_chunk, layer_k_cache, layer_v_cache = block(
                x_chunk, mask=attn_mask, k_cache=k_cache[layer_idx], v_cache=v_cache[layer_idx]
            )
            new_k_cache.append(layer_k_cache)
            new_v_cache.append(layer_v_cache)

        enc_out = encoder.ln_post(x_chunk)

        k_cache = torch.stack(new_k_cache, dim=0)
        v_cache = torch.stack(new_v_cache, dim=0)
        cache_len = torch.clamp(cache_len + ATTENTION_CONTEXT_SIZE[-1] + 1, max=left_context)

        # Greedy decoding of this chunk's encoder frames (batch entry 0).
        for time_idx in range(enc_out.shape[1]):
            frame = enc_out[:1, time_idx:time_idx + 1, :]  # [1, 1, D]

            for _ in range(max_symbols):
                if last_state is None:
                    dec_out, state = decoder(blank if last_token is None else last_token)
                else:
                    dec_out, state = decoder(last_token, last_state)
                logits = joint(frame, dec_out)[0, 0, 0, :]  # (B, T=1, U=1, V + 1) -> (V + 1,)
                token_id = logits.argmax().item()

                if token_id == RNNT_BLANK:
                    break  # blank: move on to the next encoder frame
                last_token = torch.tensor([[token_id]], dtype=torch.long, device=frame.device)
                last_state = state
                seq_ids.append(token_id)
    return tokenizer.decode(seq_ids)

In [None]:
# Replace the placeholder path with a real recording. The clip is resampled to SAMPLE_RATE,
# moved to DEVICE, and streamed through the chunked encoder.
samples, _sample_rate = librosa.load("/path/to/audio.wav", sr=SAMPLE_RATE)
audio = torch.from_numpy(samples).to(DEVICE)
online_transcribe(audio, encoder, decoder, joint, tokenizer)