In [1]:
from typing import Iterable, Tuple, List
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
import sentencepiece
import omegaconf
import pytorch_lightning as pl
import pandas as pd
import editdistance
from tqdm.auto import tqdm
from torchsummary import summary

from src.models import ConformerLAS, ConformerCTC
from src.metrics import WER

In [133]:
def init_model(model: pl.LightningModule, ckpt_path: str) -> pl.LightningModule:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt)
    model.eval()
    model.freeze()
    return model


def compute_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    wer = WER()
    wer.update(refs, hyps)
    return wer.compute()[0].item()


class GreedyDecoderLAS:
    def __init__(self, model: ConformerLAS, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor, ent: torch.Tensor = None) -> str:
        
        tokens = [self.model.decoder.tokenizer.bos_id()]
        
        for _ in range(self.max_steps):
            
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))
            
            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )
        
            best_next_token = distribution[0, -1].argmax()

            
            if best_next_token == self.model.decoder.tokenizer.eos_id():
                break

            tokens.append(best_next_token.item())

        return self.model.decoder.tokenizer.decode(tokens)

# Single Model

In [134]:
dataset = 'test_opus/farfield/manifest.jsonl'

## LAS

In [135]:
conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
conf.val_dataloader.dataset.manifest_name = dataset
conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"

conformer_las = init_model(
    model=ConformerLAS(conf=conf),
    ckpt_path="./data/conformer_las_2epochs.ckpt"
)

In [136]:
las_decoder = GreedyDecoderLAS(conformer_las)

refs, hyps_las = [], []

for batch in tqdm(conformer_las.val_dataloader()):

    features, features_len, targets, target_len = batch

    encoded, encoded_len = conformer_las(features, features_len)
    
    for i in range(features.shape[0]):

        encoder_states = encoded[[i], :encoded_len[i], :]

        ref_tokens = targets[i, :target_len[i]].tolist()

        refs.append(
            conformer_las.decoder.tokenizer.decode(ref_tokens)
        )
        hyps_las.append(
            las_decoder(encoder_states)
        )

  0%|          | 0/479 [00:01<?, ?it/s]

In [138]:
compute_wer(refs, hyps_las)

0.42290276288986206

## CTC

In [10]:
# TODO: load models, estimate WER

In [15]:
def decode_ctc_hyps(model: ConformerCTC) -> Tuple[List[str], List[str]]:
    
    refs, hyps_ctc = [], []

    for batch in tqdm(model.val_dataloader()):

        features, features_len, targets, target_len = batch

        encoded, encoded_len, preds = model(features, features_len)

        refs.extend(
            model.decoder.decode(targets, target_len, unique_consecutive=True)
        )
        hyps_ctc.extend(
            model.decoder.decode(preds, encoded_len, unique_consecutive=True)
        )

    return refs, hyps_ctc

In [16]:
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_7epochs_state_dict.ckpt"
)

refs, hyps_ctc = decode_ctc_hyps(conformer_ctc)

  0%|          | 0/479 [00:01<?, ?it/s]

In [17]:
compute_wer(refs, hyps_ctc)

0.431997150182724

In [19]:
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc_wide.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc_wide = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_wide_7epochs_state_dict.ckpt"
)

refs, hyps_ctc_wide = decode_ctc_hyps(conformer_ctc_wide)

  0%|          | 0/479 [00:01<?, ?it/s]

In [20]:
compute_wer(refs, hyps_ctc_wide)

0.38093534111976624

# ROVER: Recognizer Output Voting Error Reduction — 5 points

* [A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER)](https://ieeexplore.ieee.org/document/659110)
* [Improved ROVER using Language Model Information](https://www-tlp.limsi.fr/public/asr00_holger.pdf)

Alignment + Voting

![](./images/rover_table.png)

In [None]:
from crowdkit.aggregation.texts import ROVER

In [56]:
# TODO: aggregate hypotheses, estimate WER

In [57]:
# Поскольку мы не меняли dataloader, а случайности в нём нет, то все 3 refs имеют одинаковый порядок
# Поскольку GroupBy в fit_predict возвращает значения в порядке оригинального ДатаФрейма, то для выравнивания
# достаточно вставлять модели от меньшего wer'a к большему
data = {"task": list(range(len(refs))) * 3, "text": hyps_ctc_wide + hyps_las + hyps_ctc}

In [362]:
df = pd.DataFrame(data=data)

In [363]:
df['text'] = df['text'].apply(lambda s: s.lower())
tokenizer = lambda s: s.split(' ')
detokenizer = lambda tokens: ' '.join(tokens)
result = pd.DataFrame(data=ROVER(tokenizer, detokenizer).fit_predict(df))

In [364]:
list(result.index) == list(range(len(refs)))

In [365]:
hyps_rover = list(result['agg_text'])
compute_wer(refs, hyps_rover)

True

# MBR: Minimum Bayes Risk — 5 points


* [Minimum Bayes Risk Decoding and System
Combination Based on a Recursion for Edit Distance](https://danielpovey.com/files/csl11_consensus.pdf)
* [mbr-decoding blog-post](https://suzyahyah.github.io/bayesian%20inference/machine%20translation/2022/02/15/mbr-decoding.html)
* [Combination of end-to-end and hybrid models for speech recognition](http://www.interspeech2020.org/uploadfile/pdf/Tue-1-8-4.pdf)

![](./images/mbr_scheme.png)

In [None]:
# TODO: retrieve minimum-Distance hypothesis, estimate WER

In [None]:
class EstimationLAS:
    def __init__(self, model: ConformerLAS, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor, est: torch.Tensor) -> str:
        
        tokens = [self.model.decoder.tokenizer.bos_id()]
        
        inds = self.model.decoder.tokenizer.encode(est)
        
        prob = 0

        for i in range(self.max_steps):
            
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))
            
            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )
            if (i == len(inds)):
                break
            prob += distribution[0, -1][inds[i]]
            
            best_next_token = inds[i]
            
            if best_next_token == self.model.decoder.tokenizer.eos_id():
                break

            tokens.append(best_next_token)
        
        return prob

In [139]:
las_estimator = EstimationLAS(conformer_las)

las_estimate_las, las_estimate_ctc, las_estimate_ctc_wide = [], [], []

j = 0

for batch in tqdm(conformer_las.val_dataloader()):

    features, features_len, targets, target_len = batch

    encoded, encoded_len = conformer_las(features, features_len)
    
    for i in range(features.shape[0]):

        encoder_states = encoded[[i], :encoded_len[i], :]

        las_estimate_las.append(las_estimator(encoder_states, hyps_las[j]))
        las_estimate_ctc.append(las_estimator(encoder_states, hyps_ctc[j]))
        las_estimate_ctc_wide.append(las_estimator(encoder_states, hyps_ctc_wide[j]))
        j += 1

In [144]:
for i, st in enumerate(hyps_las):
    for ss in st:
        if ss not in conformer_ctc.decoder.labels:
            print(refs[i])

  0%|          | 0/479 [00:01<?, ?it/s]

In [163]:
tokens_to_inds = {conformer_ctc.decoder.labels[i]: i for i in range(len(conformer_ctc.decoder.labels))}
tokens_to_inds['⁇'] = tokens_to_inds['ъ']

In [166]:
def ctc_estimate_hyp(model: ConformerCTC, hyps_las, hyps_ctc, hyps_ctc_wide) -> Tuple[List[str], List[str]]:
    
    j = 0
    
    model_estimate_ctc, model_estimate_las, model_estimate_ctc_wide = [], [], []
    
    ctc_loss = torch.nn.CTCLoss(blank=model.decoder.blank_id)
    
    for batch in tqdm(model.val_dataloader()):

        features, features_len, targets, target_len = batch

        encoded, encoded_len, preds = model(features, features_len)
        
        for i in range(features.shape[0]):
            
            inds_las = torch.Tensor([tokens_to_inds[char] for char in hyps_las[j]])
            inds_ctc = torch.Tensor([tokens_to_inds[char] for char in hyps_ctc[j]])
            inds_ctc_wide = torch.Tensor([tokens_to_inds[char] for char in hyps_ctc_wide[j]])

            model_estimate_las.append(ctc_loss(encoded[i], inds_las, \
                                               torch.IntTensor([encoded_len[0]]), torch.IntTensor([len(inds_las)])))
            model_estimate_ctc.append(ctc_loss(encoded[i], inds_ctc,\
                                               torch.IntTensor([encoded_len[0]]), torch.IntTensor([len(inds_ctc)])))
            model_estimate_ctc_wide.append(ctc_loss(encoded[i], inds_ctc_wide,\
                                                    torch.IntTensor([encoded_len[0]]), \
                                                    torch.IntTensor([len(inds_ctc_wide)])))
            
            j += 1

    return model_estimate_las, model_estimate_ctc, model_estimate_ctc_wide

0.20093945720250522

In [288]:
ctc_estimate_las, ctc_estimate_ctc, ctc_estimate_ctc_wide = \
ctc_estimate_hyp(conformer_ctc, hyps_las, hyps_ctc, hyps_ctc_wide)

ctc_wide_estimate_las, ctc_wide_estimate_ctc, ctc_wide_estimate_ctc_wide = \
ctc_estimate_hyp(conformer_ctc_wide, hyps_las, hyps_ctc, hyps_ctc_wide)

In [276]:
hyps_mbr = []
softmax = torch.nn.Softmax(dim=0)

las_prob = F.normalize((torch.Tensor([las_estimate_las, las_estimate_ctc, las_estimate_ctc_wide])), p=1.0, dim=0)
ctc_prob = softmax(-torch.Tensor([ctc_estimate_las, ctc_estimate_ctc, ctc_estimate_ctc_wide]))
ctc_wide_prob = softmax(-torch.Tensor([ctc_wide_estimate_las, ctc_wide_estimate_ctc, ctc_wide_estimate_ctc_wide]))

for i in range(len(refs)):
    w = [hyps_las[i].split(' '), hyps_ctc[i].split(' '), hyps_ctc_wide[i].split(' ')]
    w_las = sum([editdistance.eval(w[0], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[0], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[0], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    w_ctc = sum([editdistance.eval(w[1], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[1], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[1], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    w_ctc_wide = sum([editdistance.eval(w[2], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[2], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[2], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    hyps_mbr.append(' '.join(w[torch.argmin(torch.Tensor([w_las, w_ctc, w_ctc_wide]))]))

In [299]:
compute_wer(refs, hyps_mbr)

0.27244258872651356

In [378]:
hyps_mbr = []
softmax = torch.nn.Softmax(dim=0)

las_prob = torch.nn.functional.normalize(torch.Tensor([las_estimate_las, las_estimate_ctc, las_estimate_ctc_wide]), p=1.0, dim=0)
ctc_prob = softmax(-torch.Tensor([ctc_estimate_las, ctc_estimate_ctc, ctc_estimate_ctc_wide]))
ctc_wide_prob = softmax(-torch.Tensor([ctc_wide_estimate_las, ctc_wide_estimate_ctc, ctc_wide_estimate_ctc_wide]))

for i in range(len(refs)):
    w = [hyps_las[i].split(' '), hyps_ctc[i].split(' '), hyps_ctc_wide[i].split(' ')]
    w_las = sum([editdistance.eval(w[0], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[0], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[0], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    w_ctc = sum([editdistance.eval(w[1], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[1], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[1], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    w_ctc_wide = sum([editdistance.eval(w[2], w[j]) * las_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[2], w[j]) * ctc_prob[j, i] for j in range(3)]) + \
        sum([editdistance.eval(w[2], w[j]) * ctc_wide_prob[j, i] for j in range(3)])
    hyps_mbr.append(' '.join(w[torch.argmin(torch.Tensor([w_las, w_ctc, w_ctc_wide]))]))

In [379]:
compute_wer(refs, hyps_mbr)

0.3568875193595886

In [383]:
pd.DataFrame({"LAS": compute_wer(refs, hyps_las), "CTC": compute_wer(refs, hyps_ctc), \
             "CTC Wide": compute_wer(refs, hyps_ctc_wide), "ROVER": compute_wer(refs, hyps_rover), \
             "MBR": compute_wer(refs, hyps_mbr)}, index = [0])

Unnamed: 0,LAS,CTC,CTC Wide,ROVER,MBR
0,0.422903,0.425405,0.370949,0.359509,0.356888
