In [1]:
import os
from typing import Iterable, Tuple, List
import warnings

import editdistance

warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
import sentencepiece
import omegaconf
import pytorch_lightning as pl
import pandas as pd
from tqdm.auto import tqdm
import numpy as np
import torch.nn as nn

from src.data import TokenToIdEncoder
from src.models import ConformerLAS, ConformerCTC
from src.metrics import WER

In [2]:
def init_model(model: pl.LightningModule, ckpt_path: str) -> pl.LightningModule:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt)
    model.eval()
    model.freeze()
    return model


def compute_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    wer = WER()
    wer.update(refs, hyps)
    return wer.compute()[0].item()


class GreedyDecoderLAS:
    def __init__(self, model: ConformerLAS, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor) -> str:
        tokens = [self.model.decoder.tokenizer.bos_id()]

        for _ in range(self.max_steps):
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            sequence_len = tokens_batch.size(-1)
            att_mask = self.model.make_attention_mask(torch.tensor([sequence_len]))
            
            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_attention_mask=att_mask, target_pad_mask=None
            )
        
            best_next_token = distribution[0, -1].argmax()
            
            if best_next_token == self.model.decoder.tokenizer.eos_id():
                break

            tokens.append(best_next_token.item())
        
        return self.model.decoder.tokenizer.decode(tokens)

# Single Model

In [3]:
dataset = os.getcwd() + "/../../week07/asr/data/test_opus/farfield/manifest.jsonl"
dataset

'/mnt/c/Users/Kirill/Documents/repos/speech-tech-mipt/week09/asr/../../week07/asr/data/test_opus/farfield/manifest.jsonl'

## LAS

In [4]:
conformer_las_conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
conformer_las_conf.val_dataloader.dataset.manifest_name = dataset
conformer_las_conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"

conformer_las = init_model(
    model=ConformerLAS(conf=conformer_las_conf),
    ckpt_path="./data/conformer_las_2epochs.ckpt"
)

In [5]:
las_decoder = GreedyDecoderLAS(conformer_las)

refs, hyps_las = [], []

for batch in tqdm(conformer_las.val_dataloader()):
    features, features_len, targets, target_len = batch

    encoded, encoded_len = conformer_las(features, features_len)

    batch_size = features.shape[0]
    for i in range(batch_size):
        encoder_states = encoded[
            [i],
            :encoded_len[i],
            :
        ]

        ref_tokens = targets[i, :target_len[i]].tolist()

        refs.append(
            conformer_las.decoder.tokenizer.decode(ref_tokens)
        )
        hyps_las.append(
            las_decoder(encoder_states)
        )

  0%|          | 0/479 [00:00<?, ?it/s]

In [6]:
wer_las = compute_wer(refs, hyps_las)
wer_las

0.42290276288986206

## CTC

In [7]:
def decode_ctc_hyps(model: ConformerCTC) -> Tuple[List[str], List[str]]:
    refs, hyps = [], []

    for batch in tqdm(model.val_dataloader()):
        features, features_len, targets, target_len = batch
        encoded, encoded_len, preds = model(features, features_len)

        batch_refs = model.decoder.decode(token_ids=targets, token_ids_length=target_len)
        batch_hyps = model.decoder.decode(
            token_ids=preds, token_ids_length=encoded_len, unique_consecutive=True
        )

        refs.extend(
            batch_refs
        )
        hyps.extend(
            batch_hyps
        )

    return refs, hyps

In [8]:
conformer_ctc_conf = omegaconf.OmegaConf.load("./conf/conformer_ctc.yaml")
conformer_ctc_conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc = init_model(
    model=ConformerCTC(conf=conformer_ctc_conf),
    ckpt_path="./data/conformer_7epochs_state_dict.ckpt"
)

In [9]:
refs, hyps_ctc = decode_ctc_hyps(conformer_ctc)

  0%|          | 0/479 [00:00<?, ?it/s]

In [10]:
wer_ctc = compute_wer(refs, hyps_ctc)
wer_ctc

0.4247196316719055

In [11]:
conformer_ctc_wide_conf = omegaconf.OmegaConf.load("./conf/conformer_ctc_wide.yaml")
conformer_ctc_wide_conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc_wide = init_model(
    model=ConformerCTC(conf=conformer_ctc_wide_conf),
    ckpt_path="./data/conformer_wide_7epochs_state_dict.ckpt"
)

In [12]:
refs, hyps_ctc_wide = decode_ctc_hyps(conformer_ctc_wide)

  0%|          | 0/479 [00:00<?, ?it/s]

In [13]:
wer_ctc_wide = compute_wer(refs, hyps_ctc_wide)
wer_ctc_wide

0.3700787425041199

# ROVER: Recognizer Output Voting Error Reduction — 5 points

* [A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER)](https://ieeexplore.ieee.org/document/659110)
* [Improved ROVER using Language Model Information](https://www-tlp.limsi.fr/public/asr00_holger.pdf)

Alignment + Voting

![](./images/rover_table.png)

In [14]:
from crowdkit.aggregation.texts import ROVER

Create a DataFrame with columns "task" (one example from the dataset), "worker" (model), "text" (predicted text), "priority" (better model has lower priority):

In [15]:
dataframe_input_las = [(i, "las", hyp, 1) for i, hyp in enumerate(hyps_las)]
dataframe_input_ctc = [(i, "ctc", hyp, 2) for i, hyp in enumerate(hyps_ctc)]
dataframe_input_ctc_wide = [(i, "ctc_wide", hyp, 0) for i, hyp in enumerate(hyps_ctc_wide)]
dataframe_input = [*dataframe_input_las, *dataframe_input_ctc, *dataframe_input_ctc_wide]
aggregate_hyps = pd.DataFrame.from_records(dataframe_input, columns=["task", "worker", "text", "priority"]).sort_values(by=["task", "priority"])
aggregate_hyps.head(15)

Unnamed: 0,task,worker,text,priority
3832,0,ctc_wide,джой хватит,0
0,0,las,джой хватит,1
1916,0,ctc,джой хлатит,2
3833,1,ctc_wide,салют вызов светлане васильевне виколенко,0
1,1,las,салют вызов светлане васильевневый колемка,1
1917,1,ctc,салют вызов светлане васильевне воколенко,2
3834,2,ctc_wide,салют латит,0
2,2,las,салют играть в латис,1
1918,2,ctc,салютсватив,2
3835,3,ctc_wide,джой звонок юрию ивановичу царьклову,0


In [16]:
tokenizer = lambda s: s.split(' ')
detokenizer = lambda tokens: ' '.join(tokens)
hyps_rover = ROVER(tokenizer=tokenizer, detokenizer=detokenizer, silent=False).fit_predict(aggregate_hyps)
hyps_rover.head(5)

  0%|          | 0/1916 [00:00<?, ?it/s]

task
0                                  джой хватит
1    салют вызов светлане васильевне воколенко
2                                  салют латит
3         джой звонок юрью ивановичу царьклову
4                      джой десть десятьценари
Name: agg_text, dtype: object

In [17]:
wer_rover = compute_wer(refs, hyps_rover)
wer_rover

0.35874491930007935

In [18]:
print(f"WER improvement by ROVER: {min(wer_las, wer_ctc, wer_ctc_wide) - wer_rover}")

WER improvement by ROVER: 0.011333823204040527


Aggregated hypotheses by ROVER have lower WER than hypotheses by any of single models.

# MBR: Minimum Bayes Risk — 5 points


* [Minimum Bayes Risk Decoding and System
Combination Based on a Recursion for Edit Distance](https://danielpovey.com/files/csl11_consensus.pdf)
* [mbr-decoding blog-post](https://suzyahyah.github.io/bayesian%20inference/machine%20translation/2022/02/15/mbr-decoding.html)
* [Combination of end-to-end and hybrid models for speech recognition](http://www.interspeech2020.org/uploadfile/pdf/Tue-1-8-4.pdf)

![](./images/mbr_scheme.png)

In [60]:
hyps_by_all_models = np.array(list(zip(hyps_las, hyps_ctc, hyps_ctc_wide)))

In [61]:
batch_size = conformer_las_conf.val_dataloader.batch_size

In [62]:
def calculate_mwer(edit_distances: np.ndarray, hyp_probabilities: np.ndarray) -> float:
    return np.dot(edit_distances, hyp_probabilities)

In [63]:
def get_edit_distance_matrices_for_batch(hyps_batch: List[List[str]]) -> torch.Tensor:
    batch_size = len(hyps_batch)
    models_count = len(hyps_batch[0])
    distance_matrix = torch.empty((batch_size, models_count, models_count))
    for i, hyps_for_single_sample in enumerate(hyps_batch):
        for j in range(len(hyps_for_single_sample)):
            for k in range(len(hyps_for_single_sample)):
                if j < k:
                    hyp1_tokens = hyps_for_single_sample[j].split()
                    hyp2_tokens = hyps_for_single_sample[k].split()
                    dist = editdistance.eval(hyp1_tokens, hyp2_tokens)
                    distance_matrix[i, j, k] = dist
                elif j == k:
                    distance_matrix[i, j, k] = 0
                else:
                    distance_matrix[i, j, k] = distance_matrix[i, k, j]

    return distance_matrix

In [64]:
def normalize_probs(hyps_probs: torch.Tensor):
    return hyps_probs / hyps_probs.sum(dim=-1).unsqueeze(-1)

In [65]:
class LASLikelihoodEstimator:
    def __init__(self, model: ConformerLAS):
        self.model = model
        self.tokenizer: sentencepiece.SentencePieceProcessor = self.model.decoder.tokenizer

    def __call__(self, encoder_result: Tuple[torch.Tensor, torch.Tensor], hyps: List[str]) -> torch.Tensor:
        encoder_state, encoder_state_len = encoder_result

        hyps_tokenized = [
                torch.tensor([self.tokenizer.bos_id()] + self.tokenizer.encode(hyp) + [self.tokenizer.eos_id()], dtype=torch.long)
                for hyp in hyps
        ]
        hyps_lengths = torch.tensor([len(hyp) - 1 for hyp in hyps_tokenized], dtype=torch.long)
        hyps_padded = torch.nn.utils.rnn.pad_sequence(hyps_tokenized, batch_first=True).long()
        hyps_padded_length = hyps_padded.size(-1) - 1

        encoded_pad_mask = self.model.make_pad_mask(encoder_state_len)
        target_pad_mask = self.model.make_pad_mask(hyps_lengths)
        target_att_mask = self.model.make_attention_mask(torch.tensor([hyps_padded_length]))

        logits = self.model.decoder(
            encoded=encoder_state, encoded_pad_mask=~encoded_pad_mask,
            target=hyps_padded[:, :-1], target_attention_mask=target_att_mask, target_pad_mask=~target_pad_mask
        )

        # select logits corresponding to tokens in hyps
        hyps_tokens_logits = torch.gather(input=logits, dim=2, index=hyps_padded[:, 1:].unsqueeze(-1)).squeeze(-1)
        hyps_probs = torch.prod(torch.exp(hyps_tokens_logits) / torch.sum(torch.exp(logits), dim=2) + ~target_pad_mask, dim=1)

        return hyps_probs

In [66]:
class CTCLikelihoodEstimator:
    def __init__(self, model_config, model: ConformerCTC):
        self.model = model
        self.tokenizer = lambda sentence: list(sentence)
        self.token_to_id_encoder = TokenToIdEncoder.from_model_config(model_config.model)

    def __call__(self, encoder_result: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], hyps: List[str]) -> torch.Tensor:
        logprobs, encoded_len, preds = encoder_result

        hyps_tokenized = [
            torch.tensor([self.token_to_id_encoder.encode(token) for token in self.tokenizer(hyp)], dtype=torch.long)
            for hyp in hyps
        ]
        hyps_lengths = torch.tensor([len(hyp) for hyp in hyps_tokenized], dtype=torch.long)
        hyps_padded = torch.nn.utils.rnn.pad_sequence(hyps_tokenized, batch_first=True).long()

        loss = self.model.ctc_loss(
            logprobs.transpose(1, 0), hyps_padded, encoded_len, hyps_lengths
        )
        hyps_probs = torch.exp(-loss)

        return hyps_probs

In [None]:
model_weights = torch.tensor([wer_las, wer_ctc, wer_ctc_wide]) / (wer_las + wer_ctc + wer_ctc_wide)

In [67]:
las_likelihood_estimator = LASLikelihoodEstimator(conformer_las)
ctc_likelihood_estimator = CTCLikelihoodEstimator(conformer_ctc_conf, conformer_ctc)
ctc_wide_likelihood_estimator = CTCLikelihoodEstimator(conformer_ctc_wide_conf, conformer_ctc_wide)

mbr_dataloader = conformer_las.val_dataloader()
mbr_tokenizer = conformer_las.decoder.tokenizer

ref, hyps_mbr = [], []

for batch_index, batch in enumerate(tqdm(mbr_dataloader)):
    features, features_len, target, target_len = batch

    hyps_batch = hyps_by_all_models[batch_index * batch_size: (batch_index + 1) * batch_size]
    hyps_las_batch = hyps_batch[:, 0]
    hyps_ctc_batch = hyps_batch[:, 1]
    hyps_ctc_wide_batch = hyps_batch[:, 2]

    las_encoder_result = conformer_las(features, features_len)
    ctc_encoder_result = conformer_ctc(features, features_len)
    ctc_wide_encoder_result = conformer_ctc_wide(features, features_len)

    distance_matrices = get_edit_distance_matrices_for_batch(hyps_batch)

    las_las_likelihood = las_likelihood_estimator(las_encoder_result, hyps_las_batch)
    las_ctc_likelihood = las_likelihood_estimator(las_encoder_result, hyps_ctc_batch)
    las_ctc_wide_likelihood = las_likelihood_estimator(las_encoder_result, hyps_ctc_wide_batch)

    ctc_las_likelihood = ctc_likelihood_estimator(ctc_encoder_result, hyps_las_batch)
    ctc_ctc_likelihood = ctc_likelihood_estimator(ctc_encoder_result, hyps_ctc_batch)
    ctc_ctc_wide_likelihood = ctc_likelihood_estimator(ctc_encoder_result, hyps_ctc_wide_batch)

    ctc_wide_las_likelihood = ctc_wide_likelihood_estimator(ctc_wide_encoder_result, hyps_las_batch)
    ctc_wide_ctc_likelihood = ctc_wide_likelihood_estimator(ctc_wide_encoder_result, hyps_ctc_batch)
    ctc_wide_ctc_wide_likelihood = ctc_wide_likelihood_estimator(ctc_wide_encoder_result, hyps_ctc_wide_batch)

    las_hyp_likelihood = normalize_probs(torch.concat((las_las_likelihood.unsqueeze(-1), las_ctc_likelihood.unsqueeze(-1), las_ctc_wide_likelihood.unsqueeze(-1)), dim=-1))
    ctc_hyp_likelihood = normalize_probs(torch.concat((ctc_las_likelihood.unsqueeze(-1), ctc_ctc_likelihood.unsqueeze(-1), ctc_ctc_wide_likelihood.unsqueeze(-1)), dim=-1))
    ctc_wide_hyp_likelihood = normalize_probs(torch.concat((ctc_wide_las_likelihood.unsqueeze(-1), ctc_wide_ctc_likelihood.unsqueeze(-1), ctc_wide_ctc_wide_likelihood.unsqueeze(-1)), dim=-1))

    hyps_likelihood = torch.concat((las_hyp_likelihood, ctc_hyp_likelihood, ctc_wide_hyp_likelihood), dim=-1).view(-1, 3, 3)

    mwer_scores = []
    for element_in_batch_index in range(len(batch)):
        hyp_distances_for_sample = distance_matrices[element_in_batch_index]
        hyps_likelihood_for_sample = hyps_likelihood[element_in_batch_index]

        mwer_score_las = ((hyp_distances_for_sample[0] * hyps_likelihood_for_sample).sum(dim=-1) * model_weights).sum()
        mwer_score_ctc = ((hyp_distances_for_sample[1] * hyps_likelihood_for_sample).sum(dim=-1) * model_weights).sum()
        mwer_score_ctc_wide = ((hyp_distances_for_sample[2] * hyps_likelihood_for_sample).sum(dim=-1) * model_weights).sum()
        mwer_scores.append(torch.tensor([mwer_score_las, mwer_score_ctc, mwer_score_ctc_wide]).view(1, -1))

    mwer_scores_tensor = torch.concat(mwer_scores)
    top_hyp_per_input = mwer_scores_tensor.argmin(dim=-1)

    refs.extend(
        [mbr_tokenizer.decode(target_tokens[:target_len[j]].tolist()) for j, target_tokens in enumerate(target)]
    )
    hyps_mbr.extend(
        [hyps_batch[j][top_hyp_per_input[j]] for j in range(len(hyps_batch))]
    )

  0%|          | 0/479 [00:00<?, ?it/s]

In [68]:
wer_mbr = compute_wer(refs, hyps_mbr)
wer_mbr

0.34156525135040283

In [69]:
print(f"WER improvement by MBR: {min(wer_las, wer_ctc, wer_ctc_wide) - wer_mbr}")

WER improvement by MBR: 0.02851349115371704


Aggregated hypotheses by MBR have lower WER than hypotheses by any of the single models and hypotheses obtained by applying ROVER.

In [70]:
pd.DataFrame.from_records(list(zip(hyps_mbr, refs))[:15], columns=["Hypothesis from MBR", "Reference"])

Unnamed: 0,Hypothesis from MBR,Reference
0,джой хватит,джой хватит
1,салют вызов светлане васильевне виколенко,салют вызов светлане васильевне николенко
2,салют латит,салют хватит
3,джой звонок юрью ивановичу царькову,джой звонок юрию ивановичу царькову
4,джой быдь десятьценари,джой выйти из сценария
5,салют вэйте,салют выйти
6,салют закройся,салют закройся
7,салют набери данилова,салют набери данилова
8,сбер мне нравится,сбер мне нравится
9,салют прекрати,салют прекрати
