[Transformer Interpretation](https://arxiv.org/pdf/1906.02762.pdf)            |  [Macaron Net](https://arxiv.org/pdf/1906.02762.pdf)
:-------------------------:|:-------------------------:
![](./images/particle_dynamics.png)  |  ![](./images/macaron_net.png)

[Conformer](https://arxiv.org/pdf/2005.08100.pdf)            |  [Transformer](https://arxiv.org/pdf/1706.03762.pdf)
:-------------------------:|:-------------------------:
![](./images/conformer.png)  |  ![](./images/transformer.png)

In [None]:
import heapq

import torch
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import sentencepiece
import omegaconf

from src.models import ConformerLAS

# LAS: Listen, Attend and Spell

## dataset

In [None]:
# Load the test manifest: a JSON-lines file, one utterance record per line.
df = pd.read_json(path_or_buf="data/test_opus/crowd/manifest.jsonl", lines=True)

In [None]:
# Peek at the first five manifest rows.
df.head(n=5)

In [None]:
# Twenty most frequent whitespace-separated words across all transcripts.
(
    df["text"]
    .str.split()
    .explode()
    .value_counts()
    .head(n=20)
)

In [None]:
# Histogram of transcript lengths in characters (trailing ';' suppresses the repr).
df["text"].str.len().hist(bins=75);

In [None]:
# Histogram of utterance durations (units presumably seconds — confirm against the manifest).
df["duration"].hist(bins=100);

## tokenizer

In [None]:
# SentencePiece BPE model (1024 pieces, per the file name) that defines BOS/EOS ids.
tokenizer = sentencepiece.SentencePieceProcessor(
    model_file="data/tokenizer/bpe_1024_bos_eos.model",
)

In [None]:
# Show how a sample phrase is split into subword pieces (out_type=str -> pieces).
tokenizer.encode("мама мыла раму", out_type=str)

In [None]:
# The same phrase as integer piece ids (out_type=int -> ids).
tokenizer.encode("мама мыла раму", out_type=int)

In [None]:
# Histogram of transcript lengths measured in SentencePiece ids.
df["text"].map(lambda text: len(tokenizer.encode(text))).hist(bins=75);

In [None]:
# Ids of the begin/end-of-sentence control symbols used by the decoders below.
(tokenizer.bos_id(), tokenizer.eos_id())

## model

In [None]:
# Build the Conformer-LAS model from its YAML config and load trained weights.
conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
# NOTE(review): these override the *train* dataloader, but the cells below pull
# batches from model.val_dataloader() — confirm which settings that loader reads.
conf.train_dataloader.batch_size = 4
conf.train_dataloader.num_workers = 4
# Same SentencePiece model file as the standalone `tokenizer` loaded above.
conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"


model = ConformerLAS(conf=conf)
# The loaded object is passed straight to load_state_dict, so the file is
# presumably a bare state dict rather than a full training checkpoint — confirm.
# torch.load unpickles its input: only load checkpoints from a trusted source.
ckpt = torch.load("data/conformer_las_2epochs.ckpt", map_location="cpu")
model.load_state_dict(ckpt)
model.eval()
# Assumed to disable gradient tracking (LightningModule.freeze semantics) — confirm.
model.freeze()

## features

In [None]:
# Take the first validation batch and plot the input features of every utterance.
batch = next(iter(model.val_dataloader()))

features, features_len, targets, target_len = batch

# One figure per utterance; each slice is presumably a 2-D (feature, time)
# map — TODO confirm the axis order.
for feature in features.unbind(dim=0):
    plt.imshow(feature)
    plt.show()

## loss

In [None]:
# Teacher-forced forward pass on the batch, then visualise the per-token loss.
encoded, encoded_len = model(features, features_len)


# make_pad_mask seems to mark *valid* positions with True: it is inverted with
# `~` before reaching the decoder and is used as a keep-mask on the loss below.
# NOTE(review): confirm against src.models.
encoded_pad_mask = model.make_pad_mask(encoded_len)

# Shift for teacher forcing: the decoder reads the sequence without its last
# token and is scored against the sequence without BOS.
targets_outputs = targets[:, 1:] # without bos
targets_inputs = targets[:, :-1] # without eos / last pad token
# Build a new tensor instead of `target_len -= 1`: the in-place form mutated the
# tensor stored in `batch` and shrank the lengths again on every re-run of this cell.
shifted_target_len = target_len - 1

target_pad_mask = model.make_pad_mask(shifted_target_len)
target_mask = model.make_attention_mask(shifted_target_len)

logits = model.decoder(encoded, ~encoded_pad_mask, targets_inputs, target_mask, ~target_pad_mask)

# transpose(1, 2) presumably moves the vocabulary axis to dim 1, as
# cross-entropy-style losses expect; imshow below additionally assumes
# model.loss is unreduced and returns a 2-D (batch, time) map — confirm.
loss = model.loss(logits.transpose(1, 2), targets_outputs)
plt.imshow(loss)
plt.colorbar(fraction=0.01)
plt.show()
# Same map with padded positions zeroed out.
plt.imshow(loss * target_pad_mask)
plt.colorbar(fraction=0.01)
plt.show()

In [None]:
# Average loss over the positions kept by target_pad_mask (assumed True = real token).
(loss * target_pad_mask).sum().div(target_pad_mask.sum())

## greedy decoding

In [None]:
class GreedyDecoder:
    """Greedy (argmax) autoregressive decoding with a LAS model.

    The ``tokenizer`` argument is accepted but neither stored nor used: BOS/EOS
    ids and the final detokenisation come from ``model.decoder.tokenizer``.

    Args:
        model: LAS model exposing ``model.decoder`` (callable, with a
            ``tokenizer`` attribute) and ``model.make_attention_mask``.
        tokenizer: ignored; kept for call-site compatibility.
        max_steps: maximum number of tokens generated after BOS.
    """

    def __init__(self, model, tokenizer, max_steps=20):
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded):
        """Decode one utterance and return its text.

        ``encoded`` is handed to the decoder without a padding mask; the
        notebook passes ``encoded[[i], :len, :]``, i.e. a single utterance
        with a leading batch axis of 1. Generation stops at EOS (which is
        not appended) or after ``max_steps`` tokens.
        """
        spm = self.model.decoder.tokenizer
        hyp = [spm.bos_id()]
        eos = spm.eos_id()

        for _step in range(self.max_steps):
            prefix = torch.tensor(hyp).unsqueeze(0)
            prefix_mask = self.model.make_attention_mask(
                torch.tensor([prefix.size(-1)])
            )

            step_logits = self.model.decoder(
                encoded=encoded,
                encoded_pad_mask=None,
                target=prefix,
                target_mask=prefix_mask,
                target_pad_mask=None,
            )

            # Only the last position predicts the next token.
            next_id = step_logits[0, -1].argmax()
            if next_id == eos:
                break
            hyp.append(next_id.item())

        return spm.decode(hyp)

In [None]:
# Fetch a fresh validation batch and run the encoder once; the decoding cells
# below slice `encoded` / `encoded_len` per utterance.
features, features_len, targets, target_len = batch = next(iter(model.val_dataloader()))
encoded, encoded_len = model(features, features_len)

In [None]:
decoder = GreedyDecoder(model, tokenizer)

# Greedy-decode each utterance separately and print it under its reference.
for i in range(len(features)):
    # Keep a batch axis of 1 and drop the padded encoder frames.
    encoder_states = encoded[[i], : encoded_len[i], :]
    ref_tokens = targets[i, : target_len[i]].tolist()

    print(f"reference : {tokenizer.decode(ref_tokens)}")
    print(f"hypothesis: {decoder(encoder_states)}")
    print("#" * 100)

## beam search decoding

In [None]:
class BeamSearchDecoder:
    """Best-first beam search over a LAS decoder for a single utterance.

    A hypothesis is scored by the negated sum of its temperature-scaled token
    log-probabilities, so *lower scores are better*. Partial hypotheses live
    in a min-heap that is pruned back to ``beam_size`` entries after every
    expansion. A hypothesis is finished once it emits EOS or reaches
    ``max_steps`` generated tokens (BOS not counted).

    Args:
        model: LAS model exposing ``model.decoder`` (callable, with a
            ``tokenizer`` attribute) and ``model.make_attention_mask``.
        temp: softmax temperature applied to the next-token logits.
        beam_size: tokens expanded per hypothesis, partial hypotheses kept,
            and results returned.
        max_steps: maximum number of generated tokens per hypothesis.
    """

    def __init__(self, model, temp=1.0, beam_size=5, max_steps=20):
        self.model = model
        self.temp = temp
        self.beam_size = beam_size
        self.max_steps = max_steps

    def __call__(self, encoded):
        """Return up to ``beam_size`` ``(score, text)`` pairs, best first.

        ``encoded`` is passed to the decoder without a padding mask; the
        notebook supplies one utterance with a leading batch axis of 1.
        Returns an empty list when no hypothesis finished (only possible
        with ``beam_size == 0``).
        """
        tokenizer = self.model.decoder.tokenizer
        eos_id = tokenizer.eos_id()

        partial_hyps = [(0.0, [tokenizer.bos_id()])]
        final_hyps = []

        while partial_hyps:

            cur_partial_score, cur_partial_hyp = heapq.heappop(partial_hyps)

            tokens_batch = torch.tensor(cur_partial_hyp).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))

            logits = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )

            # Only the last position predicts the next token.
            logprobs = F.log_softmax(logits[0, -1] / self.temp, dim=-1)

            candidates = logprobs.topk(self.beam_size)

            for token_score, token_idx in zip(candidates.values, candidates.indices):

                token_idx = int(token_idx)

                # log-probs are <= 0, so scores never decrease along a path.
                new_score = cur_partial_score - float(token_score)
                new_hyp = cur_partial_hyp + [token_idx]
                new_item = (new_score, new_hyp)

                if token_idx == eos_id or len(new_hyp) - 1 >= self.max_steps:
                    final_hyps.append(new_item)
                else:
                    heapq.heappush(partial_hyps, new_item)

            # Keep only the beam_size best partial hypotheses.
            if len(partial_hyps) > self.beam_size:
                partial_hyps = heapq.nsmallest(self.beam_size, partial_hyps)
                heapq.heapify(partial_hyps)

        if not final_hyps:
            return []

        # Decode each hypothesis on its own: SentencePieceProcessor.decode only
        # batch-decodes a *list* of id lists and rejects a tuple of them (which
        # is what zip(*final_hyps) yields).
        result = [(score, tokenizer.decode(token_ids)) for score, token_ids in final_hyps]
        result.sort()

        return result[:self.beam_size]

In [None]:
decoder = BeamSearchDecoder(model, temp=1, beam_size=10)

# Beam-search each utterance and list the ranked hypotheses under the reference.
for i in range(len(features)):
    # Keep a batch axis of 1 and drop the padded encoder frames.
    encoder_states = encoded[[i], : encoded_len[i], :]
    ref_tokens = targets[i, : target_len[i]].tolist()

    print(f"reference   : {tokenizer.decode(ref_tokens)}")
    for k, (score, hyp) in enumerate(decoder(encoder_states)):
        # score is a negated log-probability sum, so smaller is better.
        print(f"hypothesis {k + 1}: {hyp} {score:.2f}")
    print("#" * 100)