In [1]:
import logging
import math
from collections import namedtuple
from typing import List, Tuple, Optional

import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.models import emformer_rnnt_model, Hypothesis, RNNTBeamSearch
from torchaudio.models import Conformer, RNNT
from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber

In [2]:
def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.
    Args:
        num_symbols (int): The size of target token lexicon.
    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    return emformer_rnnt_model(
        input_dim=80,
        encoding_dim=1024,
        num_symbols=num_symbols,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=1024,
        transformer_num_layers=16,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
        symbol_embedding_dim=512,
        num_lstm_layers=2,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )

In [3]:
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])

def post_process_hypos(
    hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
    tokens_idx = 0
    score_idx = 3
    post_process_remove_list = [
        sp_model.unk_id(),
        sp_model.eos_id(),
        sp_model.pad_id(),
    ]
    filtered_hypo_tokens = [
        [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ids = [h[tokens_idx][1:] for h in hypos]
    hypos_score = [[math.exp(h[score_idx])] for h in hypos]

    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))

    return nbest_batch


class ConformerRNNTModule(LightningModule):
    def __init__(self, sp_model):
        super().__init__()

        self.sp_model = sp_model
        spm_vocab_size = self.sp_model.get_piece_size()
        self.blank_idx = spm_vocab_size

        # ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration.
        # For greater customizability, please refer to ``conformer_rnnt_model``.
        self.model = emformer_rnnt_base(num_symbols=1024)

    def forward(self, batch: Batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch: Batch, batch_idx = None):
        """Custom training step.
        By default, DDP does the following on each train step:
        - For each GPU, compute loss and gradient on shard of training data.
        - Sync and average gradients across all GPUs. The final gradient
          is (sum of gradients across all GPUs) / N, where N is the world
          size (total number of GPUs).
        - Update parameters on each GPU.
        Here, we do the following:
        - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
          the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
        - Sync and average gradients across all GPUs. The final gradient
          is (sum of gradients across all GPUs) / B_total.
        - Update parameters on each GPU.
        Doing so allows us to account for the variability in batch sizes that
        variable-length sequential data yield.
        """
        loss = self._step(batch, batch_idx, "train")
        batch_size = batch.features.size(0)
        batch_sizes = self.all_gather(batch_size)
        self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
        return loss

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "test")

In [4]:
import malaya_speech
from malaya_speech.utils import torch_featurization

`pyaudio` is not available, `malaya_speech.streaming.pyaudio` is not able to use.


In [5]:
y, _ = malaya_speech.load('speech/example-speaker/husein-zolkepli.wav')

In [6]:
sp_model = spm.SentencePieceProcessor(model_file='/home/husein/malaya-speech/malay-tts.model')
global_stats = torch_featurization.GlobalStatsNormalization('malay-stats.json')

In [7]:
model = ConformerRNNTModule.load_from_checkpoint('emformer-base-32/model-epoch=19-step=2040000.ckpt',
                                                 sp_model=sp_model).eval()

In [8]:
torch.save(model.model.state_dict(), 'emformer-base.pt')

In [9]:
mel = torch_featurization.melspectrogram(y)
mel = torch_featurization.piecewise_linear_log(mel)
mel = global_stats(mel)
mel = torch.nn.functional.pad(mel, pad=(0,0,0,4))
mel.shape

torch.Size([568, 80])

In [10]:
decoder = RNNTBeamSearch(model.model, model.blank_idx)
hypotheses = decoder(mel, torch.Tensor((len(mel),)), 20)
post_process_hypos(hypotheses, model.sp_model)[0][0]

'testing nama saya husin bin zulkafli'

In [11]:
import json

with open('/home/husein/ssd1/speech-bahasa/malay-asr-test.json') as fopen:
    test_set = json.load(fopen)

In [12]:
def calculate_cer(actual, hyp):
    """
    Calculate CER using `python-Levenshtein`.
    """
    import Levenshtein as Lev

    actual = actual.replace(' ', '')
    hyp = hyp.replace(' ', '')
    return Lev.distance(actual, hyp) / len(actual)


def calculate_wer(actual, hyp):
    """
    Calculate WER using `python-Levenshtein`.
    """
    import Levenshtein as Lev

    b = set(actual.split() + hyp.split())
    word2char = dict(zip(b, range(len(b))))

    w1 = [chr(word2char[w]) for w in actual.split()]
    w2 = [chr(word2char[w]) for w in hyp.split()]

    return Lev.distance(''.join(w1), ''.join(w2)) / len(actual.split())

In [13]:
from tqdm import tqdm

wer, cer = [], []

for i in tqdm(range(len(test_set['X']))):
    batch_y = [test_set['Y'][i]]
    y = malaya_speech.load(test_set['X'][i])[0]
    mel = torch_featurization.melspectrogram(y)
    mel = torch_featurization.piecewise_linear_log(mel)
    mel = global_stats(mel)
    
    hypotheses = decoder(mel, torch.Tensor((len(mel),)), 20)
    pred = post_process_hypos(hypotheses, model.sp_model)[0][0]
    
    wer.append(calculate_wer(test_set['Y'][i], pred))
    cer.append(calculate_cer(test_set['Y'][i], pred))

100%|█████████████████████████████████████████| 739/739 [13:40<00:00,  1.11s/it]


In [14]:
import numpy as np

np.mean(wer), np.mean(cer)

(0.18303839134234529, 0.07738533622881417)

In [16]:
with open('/home/husein/malaya-speech/postprocess-malaya-malay-test-set.json') as fopen:
    malaya_malay = json.load(fopen)

In [17]:
wer, cer = [], []

for i in tqdm(range(len(malaya_malay))):
    if not malaya_malay[i]['accept']:
        continue
    
    y = malaya_speech.load(f'/home/husein/malaya-speech/malay-test/{i}.wav')[0]
    mel = torch_featurization.melspectrogram(y)
    mel = torch_featurization.piecewise_linear_log(mel)
    mel = global_stats(mel)
    
    hypotheses = decoder(mel, torch.Tensor((len(mel),)), 20)
    pred = post_process_hypos(hypotheses, model.sp_model)[0][0]
    
    wer.append(calculate_wer(malaya_malay[i]['cleaned'], pred))
    cer.append(calculate_cer(malaya_malay[i]['cleaned'], pred))

100%|█████████████████████████████████████████| 765/765 [05:33<00:00,  2.29it/s]


In [18]:
np.mean(wer), np.mean(cer)

(0.1757624237861392, 0.062339190005373434)

In [19]:
from malaya_boilerplate.huggingface import upload_dict

In [20]:
files_mapping = {'emformer-base.pt': 'model.pt',
                 '/home/husein/malaya-speech/malay-tts.model': 'malay-stt.model',
                'malay-stats.json': 'malay-stats.json'}
upload_dict(model = 'emformer-base', files_mapping = files_mapping, username = 'mesolitica')

<class 'requests.exceptions.HTTPError'> (Request ID: Root=1-63edd481-327bf3d869f5fb1c6a02b535)

You already created this model repo - You already created this model repo
