# Beam Search Decoder

### Pseudocode

**Given:**
 - transformer decoder
     - inputs:
         - tgt: input sequence (max_seq_len, batch_size, word_embed + lm_dim)
         - memory: encoder output
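         - tgt_mask: attention mask for tgt, shape (max_seq_len, max_seq_len); entry (i, j) says whether position i may attend to position j — 0/False means attend, 1/True means masked out
         - tgt_key_padding_mask: shape (batch_size, max_seq_len); 1/True marks padded positions, which are ignored
         - memory_key_padding_mask: padding mask for the encoder output (memory); 1/True positions are ignored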

In [1]:
# Imports
import torch

In [5]:
# Causal (look-ahead) mask: entry [i, j] is True -- i.e. masked out -- exactly
# when j > i, so position i can only attend to itself and earlier positions.
# Built directly as a bool upper triangle above the diagonal.
torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [6]:
# Attempt to load the WER metric from Hugging Face `datasets`.
# NOTE(review): this cell fails in this environment -- `datasets` is not
# installed (ModuleNotFoundError in the output) -- so the following cells
# compute WER with `jiwer` instead and `metric` is never defined.
# NOTE(review): newer `datasets` releases reportedly moved metrics into the
# separate `evaluate` package; confirm before reviving this cell.
from datasets import load_metric
metric = load_metric("wer")

ModuleNotFoundError: No module named 'datasets'

In [77]:
from jiwer import wer
import jiwer

In [78]:
# Toy reference/hypothesis pair used to sanity-check the jiwer WER setup.
ground_truth = ["hello duck", "i like monthy python"]
hypothesis = ["hello duck", "I like, python"]

# Normalisation applied identically to references and hypotheses before
# scoring. Sentences are split into words first; the per-word cleanups run
# afterwards. RemovePunctuation and RemoveKaldiNonWords can reduce a word to
# "" (e.g. a lone "," or a "<noise>" tag), so RemoveEmptyStrings must run
# *after* them -- otherwise those empty strings survive as tokens and are
# counted as spurious words in the edit distance.
transformation = jiwer.Compose([
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemovePunctuation(),
    jiwer.ToLowerCase(),
    jiwer.RemoveKaldiNonWords(),
    jiwer.RemoveEmptyStrings(),
])

# WER pooled over both sentence pairs: all reference words are scored
# together, not averaged per sentence.
error = wer(ground_truth, hypothesis, truth_transform=transformation, hypothesis_transform=transformation)

In [79]:
# Display the pooled WER from the previous cell: after normalisation the
# hypothesis is missing "monthy", i.e. 1 deletion out of 6 reference words.
error

0.16666666666666666

In [80]:
from transformers import GPT2Tokenizer

In [81]:
# Load the pretrained GPT-2 (slow, Python) tokenizer by its hub name.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [82]:
# Encode one unpadded sentence; the result maps "input_ids" to GPT-2 token ids
# and "attention_mask" to one entry per token (all ones here).
tokenizer(text="hello, my name is")

{'input_ids': [31373, 11, 616, 1438, 318], 'attention_mask': [1, 1, 1, 1, 1]}

In [83]:
# Padding requires a pad token; reuse EOS for it (GPT-2's tokenizer presumably
# has none by default -- confirm). Padded positions get attention_mask == 0,
# which is what lets the padding be stripped again when decoding.
tokenizer.pad_token = tokenizer.eos_token
truth = ["hello, my name is", "hello your name?"]
# Encode `truth` itself rather than a second copy of the literals, so row i
# of `input_ids` is guaranteed to correspond to `truth[i]`. Calling the
# tokenizer on a list batches it (the modern replacement for
# `batch_encode_plus`).
d = tokenizer(
    truth,
    padding="longest",
    return_attention_mask=True,
    return_tensors="pt",
)
input_ids = d["input_ids"]
attention_mask = d["attention_mask"]

In [96]:
# Decode each padded row back to text and score it against its own reference
# with the same jiwer normalisation, collecting one WER per sentence.
if len(truth) != len(input_ids):
    # Rows are matched to references by position; a length mismatch would
    # otherwise raise IndexError or silently skip references.
    raise ValueError(
        f"got {len(truth)} references but {len(input_ids)} encoded rows"
    )
wers = []
for ref, row_ids, row_mask in zip(truth, input_ids, attention_mask):
    # Drop padding by mask, not by token id: the pad token is EOS here, which
    # could also occur as a genuine token.
    tokens = tokenizer.convert_ids_to_tokens(row_ids[row_mask.bool()])
    s = tokenizer.convert_tokens_to_string(tokens)
    print(s)
    wers.append(wer(ref, s, truth_transform=transformation, hypothesis_transform=transformation))

hello, my name is
hello your name?


In [102]:
# Macro-averaged WER: the unweighted mean of the per-sentence scores, so short
# sentences count as much as long ones. This differs from passing the whole
# lists to `wer()`, which pools all words. An empty `wers` reports NaN instead
# of raising ZeroDivisionError.
print(sum(wers) / len(wers) if wers else float("nan"))

0.0
