In [32]:
import os

# Work from the ASR project directory: the relative paths used later in this
# notebook (./conf/..., ./data/...) are resolved against the working directory.
os.chdir('../asr/')

In [33]:
from typing import Iterable, Tuple, List
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
import sentencepiece
import omegaconf
import pytorch_lightning as pl
import pandas as pd
from tqdm.auto import tqdm

from src.models import ConformerLAS, ConformerCTC
from src.metrics import WER

In [34]:
# Show what conf/conformer_ctc_wide.yaml changes relative to conf/conformer_ctc.yaml
!git diff --no-index conf/conformer_ctc.yaml conf/conformer_ctc_wide.yaml

[1mdiff --git a/conf/conformer_ctc.yaml b/conf/conformer_ctc_wide.yaml[m
[1mindex ddc568d..702a599 100755[m
[1m--- a/conf/conformer_ctc.yaml[m
[1m+++ b/conf/conformer_ctc_wide.yaml[m
[36m@@ -7,11 +7,11 @@[m [mmodel:[m
     dropout: 0.0[m
     feat_in: 64[m
     stride: 4[m
[31m-    d_model: 256[m
[31m-    n_layers: 10[m
[32m+[m[32m    d_model: 320[m
[32m+[m[32m    n_layers: 8[m
     n_heads: 8[m
     ff_exp_factor: 2[m
[31m-    kernel_size: 7[m
[32m+[m[32m    kernel_size: 15[m
     [m
 [m
   decoder:[m


In [36]:
def init_model(model: pl.LightningModule, ckpt_path: str) -> pl.LightningModule:
    """Load weights from ``ckpt_path`` into ``model`` and prepare it for inference.

    The file may hold either a bare state dict or a full Lightning checkpoint,
    in which the weights are nested under the ``"state_dict"`` key.

    Args:
        model: Already constructed module whose parameters match the checkpoint.
        ckpt_path: Path to a file readable by ``torch.load``.

    Returns:
        The same ``model`` instance, switched to eval mode and frozen.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when keys or shapes
            do not match the model.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # A Lightning checkpoint wraps the weights together with trainer state;
    # a bare state dict maps names to tensors, so a nested dict under
    # "state_dict" can only mean the wrapped format.
    nested = ckpt.get("state_dict") if isinstance(ckpt, dict) else None
    state_dict = nested if isinstance(nested, dict) else ckpt

    model.load_state_dict(state_dict)
    model.eval()
    model.freeze()
    return model


def compute_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    """Score hypotheses against references with the project ``WER`` metric.

    Args:
        refs: Reference transcriptions.
        hyps: Hypothesis transcriptions, passed to the metric alongside ``refs``.

    Returns:
        The first element of the metric's ``compute()`` result as a Python number.
    """
    metric = WER()
    metric.update(refs, hyps)
    result = metric.compute()
    # Only the first entry of the result is reported.
    return result[0].item()


class GreedyDecoderLAS:
    """Greedy (argmax) autoregressive decoding on top of a ``ConformerLAS`` model.

    Args:
        model: Model exposing ``decoder``, ``decoder.tokenizer`` and
            ``make_attention_mask``.
        max_steps: Upper bound on the number of decoder calls per utterance;
            decoding stops earlier once the EOS id is predicted.
    """

    def __init__(self, model: ConformerLAS, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor) -> str:
        """Decode a single utterance.

        Args:
            encoded: Encoder output for one utterance; only batch index 0 of
                the decoder output is read.

        Returns:
            The tokenizer's decoding of the BOS id followed by the generated ids
            (the EOS id itself is never appended).
        """
        # These lookups do not change between steps, so resolve them once.
        tokenizer = self.model.decoder.tokenizer
        eos_id = tokenizer.eos_id()
        # Build decoder inputs next to the encoder states so the decoder is not
        # handed tensors living on different devices.
        device = encoded.device

        tokens = [tokenizer.bos_id()]

        # Pure inference: never record an autograd graph, even for an unfrozen model.
        with torch.no_grad():
            for _ in range(self.max_steps):
                tokens_batch = torch.tensor(tokens, device=device).unsqueeze(0)
                prefix_len = torch.tensor([tokens_batch.size(-1)], device=device)
                att_mask = self.model.make_attention_mask(prefix_len)

                distribution = self.model.decoder(
                    encoded=encoded, encoded_pad_mask=None,
                    target=tokens_batch, target_mask=att_mask, target_pad_mask=None
                )

                # Greedy choice: best-scoring token at the last decoded position.
                next_token = distribution[0, -1].argmax().item()

                if next_token == eos_id:
                    break

                tokens.append(next_token)

        return tokenizer.decode(tokens)

# Single Model

In [37]:
# Evaluation manifest; assigned to val_dataloader.dataset.manifest_name of every
# model config loaded below, so all models are decoded on the same data.
dataset = 'test_opus/farfield/manifest.jsonl'

## LAS

In [38]:
# Load the LAS config and point its validation dataloader at the chosen manifest.
conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
conf.val_dataloader.dataset.manifest_name = dataset
# Override the tokenizer model file used by the LAS decoder.
conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"

# Build the model from the config, load its weights and switch it to inference mode.
conformer_las = init_model(
    model=ConformerLAS(conf=conf),
    ckpt_path="./data/conformer_las_2epochs.ckpt"
)

In [39]:
# Total number of parameters in the LAS model.
print(sum(p.numel() for p in conformer_las.parameters()))

9841920


In [40]:
las_decoder = GreedyDecoderLAS(conformer_las)

refs, hyps_las = [], []

# Encode a whole batch at once, then decode its utterances one by one.
for features, features_len, targets, target_len in tqdm(conformer_las.val_dataloader()):
    encoded, encoded_len = conformer_las(features, features_len)

    for i in range(features.shape[0]):
        # Keep a batch dimension of size 1 and cut the sequence at encoded_len[i]
        # (presumably the number of non-padded frames - TODO confirm).
        encoder_states = encoded[i:i + 1, :encoded_len[i], :]
        ref_tokens = targets[i, :target_len[i]].tolist()

        refs.append(conformer_las.decoder.tokenizer.decode(ref_tokens))
        hyps_las.append(las_decoder(encoder_states))

  0%|          | 0/60 [00:01<?, ?it/s]

In [27]:
# WER of the greedy LAS hypotheses against the references collected in the decoding loop.
compute_wer(refs, hyps_las)

0.42278361320495605

## CTC

In [28]:
# TODO: load models, estimate WER

In [29]:
def decode_ctc_hyps(model: ConformerCTC) -> Tuple[List[str], List[str]]:
    """Return ``(refs, hyps)`` for a CTC model.

    Placeholder: ``model`` is not used yet and both returned lists are empty.
    TODO: run the model on its validation data and fill both lists.
    """
    refs: List[str] = []
    hyps: List[str] = []
    return refs, hyps

In [41]:
# Load the CTC config and evaluate on the same manifest as the other models.
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_7epochs_state_dict.ckpt"
)

# NOTE(review): this rebinds `refs`, replacing the reference list built in the
# LAS decoding cell; while decode_ctc_hyps returns empty lists, `refs` becomes [].
refs, hyps_ctc = decode_ctc_hyps(conformer_ctc)
# Total number of parameters in this CTC model.
print(sum(p.numel() for p in conformer_ctc.parameters()))

9989410


In [42]:
# Load the "wide" CTC config and evaluate on the same manifest as the other models.
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc_wide.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc_wide = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_wide_7epochs_state_dict.ckpt"
)

# NOTE(review): this rebinds `refs` again; while decode_ctc_hyps returns empty
# lists, `refs` is [] after this cell.
refs, hyps_ctc_wide = decode_ctc_hyps(conformer_ctc_wide)
# Total number of parameters in the wide CTC model.
print(sum(p.numel() for p in conformer_ctc_wide.parameters()))

12486114


In [None]:
# NOTE(review): bare string expression - it only echoes the path as cell output;
# looks like a leftover from locating the image used in the markdown below.
'../week07/images/rover_table.png'

# ROVER: Recognizer Output Voting Error Reduction — 5 points

* [A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER)](https://ieeexplore.ieee.org/document/659110)
* [Improved ROVER using Language Model Information](https://www-tlp.limsi.fr/public/asr00_holger.pdf)

Alignment + Voting

![](./images/rover_table.png)

In [None]:
from crowdkit.aggregation.texts import ROVER

In [None]:
# TODO: aggregate hypotheses, estimate WER

# MBR: Minimum Bayes Risk — 5 points


* [Minimum Bayes Risk Decoding and System Combination Based on a Recursion for Edit Distance](https://danielpovey.com/files/csl11_consensus.pdf)
* [mbr-decoding blog-post](https://suzyahyah.github.io/bayesian%20inference/machine%20translation/2022/02/15/mbr-decoding.html)
* [Combination of end-to-end and hybrid models for speech recognition](http://www.interspeech2020.org/uploadfile/pdf/Tue-1-8-4.pdf)


![](./images/mbr_scheme.png)

In [None]:
# TODO: retrieve the minimum-distance hypothesis, estimate WER