# Goal of this notebook

Develop a training loop for fine-tuning ASR models with a TTS loss, by recreating the RL training loop found in `RL4LMs/rl4lms/envs/text_generation/training_utils.py`.

# automatic reloading magic

# imports

In [1]:
import torch
from typing import List, Dict, Tuple, Any
import hyperpyyaml
from tqdm import tqdm
from torchaudio.models.decoder import ctc_decoder
from torch.nn.functional import softmax
import random
from jiwer import cer
import numpy as np
import speechbrain as sb

## check if gpu available

In [2]:
# Print the hostname to confirm the notebook is running on the intended compute node.
import socket
print(socket.gethostname())

strickland.inf.ed.ac.uk


In [3]:
# Display whether PyTorch can see a CUDA device on this node.
torch.cuda.is_available()

True

In [4]:
# Display the current working directory.
# NOTE(review): presumably the later `templates.*` imports resolve relative to it — confirm.
import os
os.getcwd()

'/disk/nfs/ostrom/s1785140/rlspeller'

# HPARAMS

In [5]:
# Notebook-level hyperparameters: soft-DTW settings plus the paths of the
# SentencePiece model and the SpeechBrain inference YAML that later cells load.
hparams = dict(
    softdtw_temp=0.01,
    softdtw_bandwidth=120,
    dist_func="l1",
    sentencepiece_model_path="/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/0_char.model",
    speechbrain_hparams_file="/home/s1785140/rlspeller/infer_speechbrain.yaml",
)

# TOKENIZER

In [6]:
# Load the pretrained SentencePiece model that was used to tokenize the ASR
# training inputs, then print its vocabulary size as a sanity check.
import sentencepiece as spm
spm_path = hparams["sentencepiece_model_path"]
sp = spm.SentencePieceProcessor()
sp.load(spm_path)
print(sp.vocab_size())

28


In [7]:
# Sanity-check the SentencePiece tokenizer on a short sentence. The assertion
# rejects id 0, which the message below treats as the unknown-character id.
s = "hello world my name is jason"
# TODO pass string through text cleaners?
encoded = sp.EncodeAsIds(s)
assert 0 not in encoded, "tried to encode an unknown character"
print(" ".join(map(str, encoded)))

1 10 2 12 12 4 1 17 4 9 12 11 1 16 20 1 6 5 16 2 1 7 8 1 26 5 8 4 6


In [8]:
# Decode the ids back to text to confirm the encode/decode round trip.
sp.DecodeIds(encoded)

'hello world my name is jason'

# NEW! SIMPLE TOKENIZER

In [9]:
from speechbrain.tokenizers.SimpleTokenizer import SimpleTokenizer

In [10]:
# Instantiate the character-level SimpleTokenizer used by the decoding smoke test below.
tokenizer = SimpleTokenizer()

In [11]:
# Encode a sample sentence with the SimpleTokenizer, using '|' in place of
# spaces as the word separator.
text = "hello my name is jason".replace(' ', '|')
print(text)
ids = tokenizer.encode_as_ids(text)
ids

hello|my|name|is|jason


[9, 6, 13, 13, 16, 1, 14, 26, 1, 15, 2, 14, 6, 1, 10, 20, 1, 11, 2, 20, 16, 15]

In [12]:
# Decode the ids back to text to confirm the SimpleTokenizer round trip.
tokenizer.decode_ids(ids)

'hello|my|name|is|jason'

## test simple tokenizer with probability distribution, and see if CTC decoder successfully generates n-best lists

In [13]:
# Choose the dimensions of a fake batch of CTC emissions: `bsz` sequences
# padded to `max_len` frames, each with a random valid length drawn from
# [min_len, max_len).
min_len, max_len = 50, 100
bsz = 4
lens = torch.randint(min_len, max_len, (bsz,))
vocab_size = len(tokenizer.vocab)

# randomly assign probability distribution to each timestep

# try to decode

In [14]:
# Random scores of shape (bsz, max_len, vocab_size) standing in for acoustic-model outputs.
randn = torch.randn(bsz, max_len, vocab_size)

In [15]:
# Normalise over the vocabulary axis (last dim of the (bsz, max_len, vocab_size)
# tensor) so that every timestep holds a probability distribution over tokens.
# Normalising over dim=1 would instead make each token's scores sum to 1
# across time, which is not a per-timestep distribution.
ctc_probs = softmax(randn, dim=-1)

In [16]:
# Lexicon-free beam-search decoder over the SimpleTokenizer vocabulary, used
# only to check that n-best lists can be produced from the random emissions.
ctc_beamsearch_decoder_test = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=tokenizer.vocab,
    nbest=2,
    blank_token='-',
    sil_token="|",
)

predicted_ids = ctc_beamsearch_decoder_test(ctc_probs, lens)

# Collect and print one (label, decoded words) tuple per hypothesis.
predicted_words = []
for i, hyps in enumerate(predicted_ids):
    n_hyps = len(hyps)
    for j, hyp in enumerate(hyps):
        token_ids = hyp.tokens.tolist()
        words = tokenizer.decode_ids(token_ids).split(" ")
        label = f"sample {i+1}, hyp {j+1}/{n_hyps}"
        tup = (label, words)
        predicted_words.append(tup)
        print(tup)

('sample 1, hyp 1/2', ['|ad|o|zxhqxo|uaceqvsymywhxzniexgxdubvuoxwc|kathjsin|hncdkapmebhwrofrgmy|'])
('sample 1, hyp 2/2', ['|ad|o|zxhqxo|uaceqvsymywhxzniexgxdubvuoxwc|kathjsin|hncdktapmebhwrofrgmy|'])
('sample 2, hyp 1/2', ['|mlyehfvrmucpowdskuqrzucxnlseau|dbtjrjihqmtgvauxvbtnudvlyatydrbqwqbnfwcneokuvc|esbhmbzpqsaetmn|'])
('sample 2, hyp 2/2', ['|mlyehfvrmucpowdskuqrzucxnlseau|dbtjrjihqmtgvauxvbtnudvlyatydrbqwqbnfwcnzokuvc|esbhmbzpqsaetmn|'])
('sample 3, hyp 1/2', ['|rgjycvzuetizuoltod|dgdalhofnldiwvuyzdjnsomwbzcaz|qu|cgqhaudyfbwxknczqknmsen|'])
('sample 3, hyp 2/2', ['|rgjucvzuetizuoltod|dgdalhofnldiwvuyzdjnsomwbzcaz|qu|cgqhaudyfbwxknczqknmsen|'])
('sample 4, hyp 1/2', ['|viqapzkbxszivnsehzjfwj|neydugmdkrkavcj|htxultzkavdfcxplrgdsbiwffqxyjhl|'])
('sample 4, hyp 2/2', ['|viqapzkbxszivnsehzjfwj|neydugmdkrkavcj|htxultzkavdfcxplrgdsbimffqxyjhl|'])


# LOAD ASR (PRETRAINED)

In [17]:
from templates.speech_recognition_CharTokens_NoLM.ASR.train import ASR
from templates.speech_recognition_CharTokens_NoLM.ASR.train import dataio_prepare
from torch.utils.data import DataLoader
from speechbrain.dataio.dataloader import LoopedLoader

In [18]:
# Load the SpeechBrain hyperparameters YAML (no overrides are passed here).
speechbrain_hparams_file = hparams['speechbrain_hparams_file']
with open(speechbrain_hparams_file) as f:
    speechbrain_hparams = hyperpyyaml.load_hyperpyyaml(f)

/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/rirs_noises.zip exists. Skipping download


In [19]:
# Display the save folder configured in the loaded hyperparameters.
speechbrain_hparams['save_folder']

'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/ASR/results/CRDNN_CHAR_LJSpeech_halved/2602/save'

In [20]:
# initialise trainer (we don't want to train, but model is tightly coupled with trainer)
asr_brain = ASR(
    modules=speechbrain_hparams["modules"],
    opt_class=speechbrain_hparams["opt_class"],
    hparams=speechbrain_hparams,
    checkpointer=speechbrain_hparams["checkpointer"],
)

def setup_asr_brain_for_infer(asr_brain):
    """Put ``asr_brain`` into inference mode.

    Calls ``on_evaluate_start(min_key="WER")`` and then switches
    ``asr_brain.modules`` to eval mode. Mutates ``asr_brain`` in place and
    returns nothing.
    """
    asr_brain.on_evaluate_start(min_key="WER") # We call the on_evaluate_start that will load the best model
    asr_brain.modules.eval() # We set the model to eval mode (remove dropout etc)

# Printed before the call so the hint is visible if on_evaluate_start() raises.
print("if on_evaluate_start() get runtime error, likely need to restart notebook kernel")
setup_asr_brain_for_infer(asr_brain)

if on_evaluate_start() get runtime error, likely need to restart notebook kernel


In [21]:
# create dataset and dataloader for inference
datasets = dataio_prepare(speechbrain_hparams)

test_set = datasets['test']

# Wrap the test split in a dataloader only when it is not one already.
# The check must be `not (A or B)`: written as `not A or B`, an existing
# LoopedLoader would satisfy the condition and be passed to make_dataloader again.
if not isinstance(test_set, (DataLoader, LoopedLoader)):
    test_loader_kwargs = speechbrain_hparams["test_dataloader_opts"]
    test_set = asr_brain.make_dataloader(
        test_set, stage=sb.Stage.TEST, **test_loader_kwargs
    )

In [22]:
# Build the token list for CTC decoding by decoding every id of the ASR
# tokenizer on its own.
vocab_size = len(asr_brain.hparams.tokenizer)
vocab = [asr_brain.hparams.tokenizer.decode_ids([idx]) for idx in range(vocab_size)]
print(vocab)

# Rename entries 0 and 1 to the blank ('-') and silence ('|') symbols that the
# CTC decoder below is configured with.
vocab[0], vocab[1] = '-', "|"

print(vocab)

[' ⁇ ', '', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']
['-', '|', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']


In [23]:
# Lexicon-free CTC beam-search decoder over the ASR vocabulary, returning up
# to 100 hypotheses per utterance. `blank_token` and `sil_token` must match
# the symbols placed in `vocab`.
ctc_beamsearch_decoder = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=vocab,
    nbest=100,
    blank_token='-',
    sil_token="|",
)

In [24]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, greedy=False, num_batches_to_transcribe=None):
    """Transcribe batches of ``dataset`` with ``asr_brain``.

    In beam-search mode (``greedy=False``) every hypothesis of every sample is
    printed together with its character error rate against the reference.

    Arguments
    ---------
    asr_brain : object
        Must provide ``compute_forward(batch, stage=...)`` returning a dict
        with key ``"ctc_logprobs"``, plus ``feat_lens``,
        ``hparams.tokenizer`` and ``hparams.blank_index``.
        NOTE(review): ``feat_lens`` is assumed to hold per-sample relative
        lengths set by ``compute_forward`` — confirm.
    dataset : iterable
        Yields batches exposing ``words`` and ``utt_id``.
    greedy : bool
        Use greedy CTC decoding instead of the module-level
        ``ctc_beamsearch_decoder``.
    num_batches_to_transcribe : int or None
        Maximum number of batches to process; ``None`` processes all of them.

    Returns
    -------
    transcripts : list
        One entry per processed batch. Beam search: a list of
        ``(label, text)`` tuples, one per hypothesis. Greedy: a list of word
        lists, one per sample.
    ctc_probs : torch.Tensor or None
        ``ctc_logprobs`` of the last processed batch (debug aid); ``None`` if
        no batch was processed.
    """
    from itertools import islice

    asr_tokenizer = asr_brain.hparams.tokenizer

    # Pull only the requested number of batches instead of materialising the
    # whole dataset with list(dataset) before slicing.
    if num_batches_to_transcribe is None:
        batches = tqdm(dataset, dynamic_ncols=True)
    else:
        batches = tqdm(
            islice(dataset, num_batches_to_transcribe),
            total=num_batches_to_transcribe,
            dynamic_ncols=True,
        )

    transcripts = []
    ctc_probs = None  # stays None when the loop body never runs
    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        for batch in batches:
            orig_transcriptions = batch.words

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied
            # in compute_forward().
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)

            ctc_probs = predictions['ctc_logprobs'] # FOR DEBUG

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    ctc_probs, asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                # Use the same tokenizer attribute as the beam-search branch.
                predicted_words = [
                    asr_tokenizer.decode_ids(ids).split(" ")
                    for ids in predicted_ids
                ]
            else:
                # get frame lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = ctc_probs.size(1)
                frame_lens = torch.tensor(
                    [int(torch.round(len_ratio * batch_max_len)) for len_ratio in asr_brain.feat_lens],
                    dtype=torch.int64,  # integer frame counts rather than floats
                )

                # The torchaudio decoder expects CPU tensors; .cpu() is a no-op if already there.
                predicted_ids = ctc_beamsearch_decoder(
                    ctc_probs.cpu(), lengths=frame_lens
                )

                predicted_words = []
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, predicted_ids)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, hyp in enumerate(hyps):
                        words = asr_tokenizer.decode_ids(hyp.tokens.tolist())
                        hyp_cer = 100 * cer(orig_text, words)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{words}'")
                        predicted_words.append((f"sample {i+1}, hyp {j+1}/{len(hyps)}", words))

                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            transcripts.append(predicted_words)

    return transcripts, ctc_probs

# Beam-search transcribe a single test batch; also keeps that batch's CTC log-probs for inspection.
transcripts, ctc_probs = transcribe_dataset(asr_brain, test_set, greedy=False, num_batches_to_transcribe=1)

  0%|                                                               | 0/1 [00:00<?, ?it/s]

DEBUG batch: <speechbrain.dataio.batch.PaddedBatch object at 0x7f2246bc0b80>
DEBUG use mel inputs: False
DEBUG INSIDE PREPARE FEATURES, feats.shape=torch.Size([4, 627, 40]) wav_lens.shape=torch.Size([4])


100%|███████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.94s/it]


sample 1 - (LJ039-0175: 'for the first four attempts the firers missed the second shot by several inches')
	hyp 1/48 (CER=0.0%): 'for the first four attempts the firers missed the second shot by several inches '
	hyp 2/48 (CER=1.3%): 'for the first four attempts the firers mised the second shot by several inches '
	hyp 3/48 (CER=1.3%): 'for the fist four attempts the firers missed the second shot by several inches '
	hyp 4/48 (CER=1.3%): 'for the first four attempths the firers missed the second shot by several inches '
	hyp 5/48 (CER=2.5%): 'for the first four attempts the firerers missed the second shot by several inches '
	hyp 6/48 (CER=2.5%): 'for the fist four attempts the firers mised the second shot by several inches '
	hyp 7/48 (CER=1.3%): 'fr the first four attempts the firers missed the second shot by several inches '
	hyp 8/48 (CER=1.3%): 'for the first four attempts the firers  missed the second shot by several inches '
	hyp 9/48 (CER=1.3%): 'for the first four atempts the




# LOAD WORD ALIGNED WAVS into dataset

In [36]:
# imitate CLAs
import sys
import argparse
import math
import glob
from tqdm import tqdm

In [37]:
# Paths of the word-token annotation manifests for each split.
# TODO set these in the yaml config instead of hardcoding them here.
train_annotation_path = '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/respeller_train_wordtoken_annotation.json'
valid_annotation_path = '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/respeller_valid_wordtoken_annotation.json'
test_annotation_path = '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/respeller_test_wordtoken_annotation.json'

In [38]:
# Point the SpeechBrain hparams at the word-token annotation manifests.
# Reuse the *_annotation_path variables defined in the previous cell instead
# of repeating the literals, so the two cannot drift apart.
speechbrain_hparams['train_annotation'] = train_annotation_path
speechbrain_hparams['valid_annotation'] = valid_annotation_path
speechbrain_hparams['test_annotation'] = test_annotation_path

In [45]:
def dataio_prepare(hparams):
    """Build the train/valid/test datasets of word-aligned wavs for the brain class.

    This notebook-local definition replaces the ``dataio_prepare`` imported
    from the template's ``train.py`` earlier in the notebook. It also defines
    the data processing pipeline through user-defined functions.

    Arguments
    ---------
    hparams : dict
        Must contain ``train_annotation``, ``valid_annotation`` and
        ``test_annotation`` (json manifest paths) and the matching
        ``{split}_dataloader_opts`` dicts, which are modified in place.

    Returns
    -------
    datasets : dict
        Dictionary containing "train", "valid", and "test" keys that correspond
        to the DynamicItemDataset objects.
    """
    # Define audio pipeline. In this case, we simply read the path contained
    # in the variable wav with the audio reader.
    # wav path like: data/ljspeech_wavs_16khz_word_aligned/differs/differs__LJ001-0001__occ1__len8320.wav
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig", "wav_path", "utt_id")
    def audio_pipeline(wav_path):
        """Yield the loaded signal, the wav path, and the utterance id, in that order."""
        sig = sb.dataio.dataio.read_audio(wav_path)
        yield sig

        yield wav_path

        # utt_id is the file name up to its first "."
        utt_id = wav_path.split("/")[-1].split(".")[0]
        yield utt_id

    # The three pipelines below pass a manifest field through unchanged.
    @sb.utils.data_pipeline.takes("samples_to_graphemes_ratio")
    @sb.utils.data_pipeline.provides("samples_to_graphemes_ratio")
    def ratio_pipeline(samples_to_graphemes_ratio):
        yield samples_to_graphemes_ratio

    @sb.utils.data_pipeline.takes("length")
    @sb.utils.data_pipeline.provides("length")
    def length_pipeline(length):
        yield length

    @sb.utils.data_pipeline.takes("words")
    @sb.utils.data_pipeline.provides("words")
    def text_pipeline(words):
        """Yield the transcription unchanged.

        NB Make sure that you yield exactly what is defined above in @sb.utils.data_pipeline.provides()"""
        yield words

        # TODO also yield mel for calculating fastpitch softdtw loss

    # Define datasets from json data manifest file
    # NOTE(review): no sorting is applied here; the length-sorting code below is commented out.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }

    for split in data_info:
        datasets[split] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[split],
            dynamic_items=[audio_pipeline, ratio_pipeline, length_pipeline, text_pipeline],
            output_keys=[
                "id",
                "sig",
                "wav_path",
                "utt_id",
                "samples_to_graphemes_ratio",
                "length",
                "words",
            ],
        )
        # NOTE(review): this turns shuffling on for every split, valid and test
        # included, and mutates the caller's hparams — confirm this is intended.
        hparams[f"{split}_dataloader_opts"]["shuffle"] = True

    # TODO uncomment this!!!

    # # Sorting training data with ascending order makes the code  much
    # # faster  because we minimize zero-padding. In most of the cases, this
    # # does not harm the performance.
    # if hparams["sorting"] == "ascending":
    #     datasets["train"] = datasets["train"].filtered_sorted(sort_key="length")
    #     hparams["train_dataloader_opts"]["shuffle"] = False

    # elif hparams["sorting"] == "descending":
    #     datasets["train"] = datasets["train"].filtered_sorted(
    #         sort_key="length", reverse=True
    #     )
    #     hparams["train_dataloader_opts"]["shuffle"] = False

    # elif hparams["sorting"] == "random":
    #     hparams["train_dataloader_opts"]["shuffle"] = True
    #     pass

    # else:
    #     raise NotImplementedError(
    #         "sorting must be random, ascending or descending"
    #     )
    
    return datasets

# Rebuild the datasets from the word-token annotation manifests set in speechbrain_hparams.
datasets = dataio_prepare(speechbrain_hparams)

In [46]:
# convert from datasets to dataloaders
split2stage = {"train": sb.Stage.TRAIN, "valid": sb.Stage.VALID, "test": sb.Stage.TEST}
for split in ["train", "valid", "test"]:
    # Skip splits that are already loaders, which also makes the cell safe to re-run.
    # The check must be `not (A or B)`: written as `not A or B`, an existing
    # LoopedLoader would satisfy the condition and be wrapped a second time.
    if not isinstance(datasets[split], (DataLoader, LoopedLoader)):
        dataloader_kwargs = speechbrain_hparams[f"{split}_dataloader_opts"]
        datasets[split] = asr_brain.make_dataloader(
            datasets[split], stage=split2stage[split], **dataloader_kwargs
        )

In [61]:
def set_whitespace_to_0_probability(ctc_probs, vocab, log_probs=True, whitespace_symbol="|"):
    """Give the whitespace token zero probability at every timestep.

    ``ctc_probs`` has shape [bsz, max_seq_len, vocab_size] and is modified in
    place; the same tensor is returned. The token's entry is set to ``-inf``
    when ``log_probs`` is true and to ``0.0`` otherwise. ``vocab.index``
    raises ``ValueError`` if ``whitespace_symbol`` is not in ``vocab``.
    """
    if log_probs:
        fill_value = -math.inf
    else:
        fill_value = 0.0
    column = vocab.index(whitespace_symbol)
    # select() returns a view on dim 2, so fill_() writes into ctc_probs itself.
    ctc_probs.select(2, column).fill_(fill_value)
    return ctc_probs

In [67]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, 
                       greedy=False, num_batches_to_transcribe=None,
                       hack_whitespace_probs=False, collapse_whitespace=True):
    """Transcribe batches of ``dataset`` and collect the hypotheses per sample.

    Arguments
    ---------
    asr_brain : object
        Must provide ``compute_forward(batch, stage=...)`` returning a dict
        with key ``"ctc_logprobs"``, plus ``feat_lens``,
        ``hparams.tokenizer`` and ``hparams.blank_index``.
        NOTE(review): ``feat_lens`` is assumed to hold per-sample relative
        lengths set by ``compute_forward`` — confirm.
    dataset : iterable
        Yields batches exposing ``words`` and ``utt_id`` (and, optionally,
        ``wav_path``).
    greedy : bool
        Use greedy CTC decoding instead of the module-level
        ``ctc_beamsearch_decoder``. Nothing is printed in this mode.
    num_batches_to_transcribe : int or None
        Maximum number of batches to process; ``None`` processes all of them.
    hack_whitespace_probs : bool
        Set the log-probability of the "|" token to ``-inf`` before decoding,
        using the module-level ``vocab``.
    collapse_whitespace : bool
        Remove every space from each decoded hypothesis.

    Returns
    -------
    output_dict : dict
        ``"words"``: one list of hypothesis strings per transcribed sample
        (a single string per sample in greedy mode).
        ``"wavs"``: the parallel list of ``batch.wav_path`` entries, or
        ``None`` per sample when the batch has no ``wav_path``.
        NOTE(review): the original referenced an undefined ``wavs`` name;
        collecting the wav paths is a best guess at the intent — confirm.
    """
    asr_tokenizer = asr_brain.hparams.tokenizer
    all_words = []
    all_wavs = []

    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        for n, batch in enumerate(tqdm(dataset, dynamic_ncols=True)):
            # break out of loop if we have transcribed enough batches
            # (None means "no limit"; comparing an int against None would raise)
            if num_batches_to_transcribe is not None and n >= num_batches_to_transcribe:
                break

            orig_transcriptions = batch.words
            batch_wavs = getattr(batch, "wav_path", None)
            if batch_wavs is None:
                batch_wavs = [None] * len(orig_transcriptions)

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied
            # in compute_forward().
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)
            ctc_logprobs = predictions["ctc_logprobs"]

            # hack probabilities to set all probs to 0 for whitespace
            if hack_whitespace_probs:
                ctc_logprobs = set_whitespace_to_0_probability(
                    ctc_logprobs, vocab, log_probs=True, whitespace_symbol="|"
                )

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    ctc_logprobs, asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                batch_hyps = [[asr_tokenizer.decode_ids(ids)] for ids in predicted_ids]
            else:
                # get frame lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = ctc_logprobs.size(1)
                frame_lens = torch.tensor(
                    [int(torch.round(len_ratio * batch_max_len)) for len_ratio in asr_brain.feat_lens],
                    dtype=torch.int64,  # integer frame counts rather than floats
                )

                # The torchaudio decoder expects CPU tensors; .cpu() is a no-op if already there.
                nbest_lists = ctc_beamsearch_decoder(ctc_logprobs.cpu(), lengths=frame_lens)
                batch_hyps = [
                    [asr_tokenizer.decode_ids(hyp.tokens.tolist()) for hyp in hyps]
                    for hyps in nbest_lists
                ]

            if collapse_whitespace:
                batch_hyps = [[text.replace(" ", "") for text in hyps] for hyps in batch_hyps]

            if not greedy:
                # iterate over samples in batch
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, batch_hyps)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, text in enumerate(hyps):
                        hyp_cer = 100 * cer(orig_text, text)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{text}'")

                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            all_words.extend(batch_hyps)
            all_wavs.extend(batch_wavs)

    output_dict = {
        "words": all_words,
        "wavs": all_wavs,
    }

    return output_dict

# Beam-search transcribe the first 10 batches yielded by the word-aligned train loader.
transcription_output_dict = transcribe_dataset(asr_brain, datasets["train"], 
                                            greedy=False, num_batches_to_transcribe=10)

  0%|                                                            | 0/2864 [00:00<?, ?it/s]

DEBUG batch: <speechbrain.dataio.batch.PaddedBatch object at 0x7f2251149ee0>
DEBUG use mel inputs: False
DEBUG INSIDE PREPARE FEATURES, feats.shape=torch.Size([4, 62, 40]) wav_lens.shape=torch.Size([4])


  0%|                                                    | 1/2864 [00:00<19:19,  2.47it/s]


sample 1 - (offering__LJ044-0102__occ1__len8000: 'offering')
	hyp 1/49 (CER=12.5%): 'offerin'
	hyp 2/49 (CER=25.0%): 'oferin'
	hyp 3/49 (CER=25.0%): 'ofterin'
	hyp 4/49 (CER=25.0%): 'offeren'
	hyp 5/49 (CER=25.0%): 'offein'
	hyp 6/49 (CER=25.0%): 'offern'
	hyp 7/49 (CER=25.0%): 'offeryn'
	hyp 8/49 (CER=25.0%): 'fferin'
	hyp 9/49 (CER=25.0%): 'offrin'
	hyp 10/49 (CER=37.5%): 'oferen'
	hyp 11/49 (CER=25.0%): 'oftferin'
	hyp 12/49 (CER=37.5%): 'ofein'
	hyp 13/49 (CER=37.5%): 'ofern'
	hyp 14/49 (CER=37.5%): 'oferyn'
	hyp 15/49 (CER=25.0%): 'offeri'
	hyp 16/49 (CER=25.0%): 'afferin'
	hyp 17/49 (CER=12.5%): 'offerig'
	hyp 18/49 (CER=37.5%): 'ferin'
	hyp 19/49 (CER=25.0%): 'efferin'
	hyp 20/49 (CER=37.5%): 'oerin'
	hyp 21/49 (CER=25.0%): 'offeriy'
	hyp 22/49 (CER=37.5%): 'ofteren'
	hyp 23/49 (CER=37.5%): 'oftein'
	hyp 24/49 (CER=37.5%): 'oftern'
	hyp 25/49 (CER=25.0%): 'offterin'
	hyp 26/49 (CER=37.5%): 'oferi'
	hyp 27/49 (CER=25.0%): 'offerie'
	hyp 28/49 (CER=37.5%): 'ofteryn'
	hyp 29/49 (C

  0%|                                                    | 2/2864 [00:00<22:00,  2.17it/s]


sample 1 - (youths__LJ003-0094__occ1__len8320: 'youths')
	hyp 1/47 (CER=16.7%): 'yuths'
	hyp 2/47 (CER=33.3%): 'uths'
	hyp 3/47 (CER=0.0%): 'youths'
	hyp 4/47 (CER=16.7%): 'ouths'
	hyp 5/47 (CER=16.7%): 'uouths'
	hyp 6/47 (CER=33.3%): 'yutes'
	hyp 7/47 (CER=33.3%): 'yuts'
	hyp 8/47 (CER=33.3%): 'yutshs'
	hyp 9/47 (CER=33.3%): 'yuth'
	hyp 10/47 (CER=50.0%): 'utes'
	hyp 11/47 (CER=33.3%): 'yutehs'
	hyp 12/47 (CER=16.7%): 'youtes'
	hyp 13/47 (CER=50.0%): 'uts'
	hyp 14/47 (CER=16.7%): 'youts'
	hyp 15/47 (CER=16.7%): 'yuths'
	hyp 16/47 (CER=50.0%): 'utshs'
	hyp 17/47 (CER=16.7%): 'youtshs'
	hyp 18/47 (CER=33.3%): 'ayuths'
	hyp 19/47 (CER=50.0%): 'uth'
	hyp 20/47 (CER=50.0%): 'utehs'
	hyp 21/47 (CER=16.7%): 'youth'
	hyp 22/47 (CER=33.3%): 'outes'
	hyp 23/47 (CER=33.3%): 'tyuths'
	hyp 24/47 (CER=16.7%): 'youtehs'
	hyp 25/47 (CER=33.3%): 'uths'
	hyp 26/47 (CER=33.3%): 'nyuths'
	hyp 27/47 (CER=33.3%): 'outs'
	hyp 28/47 (CER=33.3%): 'uoutes'
	hyp 29/47 (CER=33.3%): 'yuthes'
	hyp 30/47 (CER=0.0%

  0%|                                                    | 3/2864 [00:01<23:48,  2.00it/s]


sample 1 - (handbills__LJ045-0009__occ1__len12640: 'handbills')
	hyp 1/37 (CER=0.0%): 'handbills'
	hyp 2/37 (CER=11.1%): 'nhandbills'
	hyp 3/37 (CER=11.1%): 'thandbills'
	hyp 4/37 (CER=11.1%): 'hanbills'
	hyp 5/37 (CER=11.1%): 'handbulls'
	hyp 6/37 (CER=11.1%): 'ahandbills'
	hyp 7/37 (CER=11.1%): 'dhandbills'
	hyp 8/37 (CER=11.1%): 'handbells'
	hyp 9/37 (CER=11.1%): 'ihandbills'
	hyp 10/37 (CER=11.1%): 'handbils'
	hyp 11/37 (CER=11.1%): 'andbills'
	hyp 12/37 (CER=22.2%): 'nhanbills'
	hyp 13/37 (CER=11.1%): 'hanbills'
	hyp 14/37 (CER=11.1%): 'handbilles'
	hyp 15/37 (CER=22.2%): 'thanbills'
	hyp 16/37 (CER=11.1%): 'handblls'
	hyp 17/37 (CER=11.1%): 'ahandbills'
	hyp 18/37 (CER=22.2%): 'nhandbulls'
	hyp 19/37 (CER=0.0%): 'handbills'
	hyp 20/37 (CER=11.1%): 'thandbills'
	hyp 21/37 (CER=22.2%): 'anhandbills'
	hyp 22/37 (CER=11.1%): 'hhandbills'
	hyp 23/37 (CER=22.2%): 'thandbulls'
	hyp 24/37 (CER=22.2%): 'hanbulls'
	hyp 25/37 (CER=22.2%): 'athandbills'
	hyp 26/37 (CER=22.2%): 'ahanbills'
	

  0%|                                                    | 4/2864 [00:02<25:39,  1.86it/s]


sample 1 - (wechsler__LJ040-0186__occ1__len8480: 'wechsler')
	hyp 1/50 (CER=37.5%): 'wexsle'
	hyp 2/50 (CER=37.5%): 'wexslea'
	hyp 3/50 (CER=37.5%): 'wexksle'
	hyp 4/50 (CER=37.5%): 'wextsle'
	hyp 5/50 (CER=37.5%): 'wexsle'
	hyp 6/50 (CER=37.5%): 'wexkslea'
	hyp 7/50 (CER=50.0%): 'wexle'
	hyp 8/50 (CER=37.5%): 'wexesle'
	hyp 9/50 (CER=37.5%): 'wextslea'
	hyp 10/50 (CER=37.5%): 'wexslea'
	hyp 11/50 (CER=37.5%): 'wekxsle'
	hyp 12/50 (CER=50.0%): 'wexlea'
	hyp 13/50 (CER=37.5%): 'wexksle'
	hyp 14/50 (CER=37.5%): 'wexeslea'
	hyp 15/50 (CER=50.0%): 'wexkle'
	hyp 16/50 (CER=37.5%): 'wextsle'
	hyp 17/50 (CER=50.0%): 'whexsle'
	hyp 18/50 (CER=37.5%): 'wekxslea'
	hyp 19/50 (CER=37.5%): 'wetxsle'
	hyp 20/50 (CER=50.0%): 'wextle'
	hyp 21/50 (CER=37.5%): 'wexkslea'
	hyp 22/50 (CER=37.5%): 'weksle'
	hyp 23/50 (CER=25.0%): 'wexsler'
	hyp 24/50 (CER=50.0%): 'wexklea'
	hyp 25/50 (CER=37.5%): 'wexesle'
	hyp 26/50 (CER=37.5%): 'wextslea'
	hyp 27/50 (CER=50.0%): 'whexslea'
	hyp 28/50 (CER=37.5%): 'wekts

  0%|                                                    | 5/2864 [00:02<24:05,  1.98it/s]


sample 1 - (centurys__LJ004-0020__occ1__len11200: 'centurys')
	hyp 1/43 (CER=25.0%): 'acentury'
	hyp 2/43 (CER=37.5%): 'acentur'
	hyp 3/43 (CER=12.5%): 'century'
	hyp 4/43 (CER=25.0%): 'ocentury'
	hyp 5/43 (CER=12.5%): 'acenturys'
	hyp 6/43 (CER=37.5%): 'nacentury'
	hyp 7/43 (CER=25.0%): 'centur'
	hyp 8/43 (CER=37.5%): 'tacentury'
	hyp 9/43 (CER=37.5%): 'nacentury'
	hyp 10/43 (CER=37.5%): 'ocentur'
	hyp 11/43 (CER=37.5%): 'tacentury'
	hyp 12/43 (CER=25.0%): 'acenturs'
	hyp 13/43 (CER=50.0%): 'nacentur'
	hyp 14/43 (CER=37.5%): 'sacentury'
	hyp 15/43 (CER=37.5%): 'acenture'
	hyp 16/43 (CER=50.0%): 'tacentur'
	hyp 17/43 (CER=37.5%): 'acentery'
	hyp 18/43 (CER=50.0%): 'nacentur'
	hyp 19/43 (CER=37.5%): 'asentury'
	hyp 20/43 (CER=25.0%): 'acentury'
	hyp 21/43 (CER=0.0%): 'centurys'
	hyp 22/43 (CER=50.0%): 'tacentur'
	hyp 23/43 (CER=25.0%): 'ncentury'
	hyp 24/43 (CER=37.5%): 'oacentury'
	hyp 25/43 (CER=37.5%): 'acentory'
	hyp 26/43 (CER=37.5%): 'aecentury'
	hyp 27/43 (CER=50.0%): 'sacentur'

  0%|                                                    | 6/2864 [00:03<24:14,  1.97it/s]


sample 1 - (noyes__LJ018-0279__occ1__len8960: 'noyes')
	hyp 1/42 (CER=40.0%): 'nois'
	hyp 2/42 (CER=40.0%): 'noises'
	hyp 3/42 (CER=20.0%): 'noyis'
	hyp 4/42 (CER=20.0%): 'noies'
	hyp 5/42 (CER=60.0%): 'knois'
	hyp 6/42 (CER=40.0%): 'noiss'
	hyp 7/42 (CER=40.0%): 'noiys'
	hyp 8/42 (CER=60.0%): 'noise'
	hyp 9/42 (CER=40.0%): 'noizs'
	hyp 10/42 (CER=40.0%): 'nois'
	hyp 11/42 (CER=40.0%): 'noyises'
	hyp 12/42 (CER=60.0%): 'enois'
	hyp 13/42 (CER=60.0%): 'noieses'
	hyp 14/42 (CER=60.0%): 'knoises'
	hyp 15/42 (CER=40.0%): 'noois'
	hyp 16/42 (CER=40.0%): 'noiis'
	hyp 17/42 (CER=40.0%): 'nowis'
	hyp 18/42 (CER=60.0%): 'tnois'
	hyp 19/42 (CER=40.0%): 'nos'
	hyp 20/42 (CER=60.0%): 'noisses'
	hyp 21/42 (CER=60.0%): 'kois'
	hyp 22/42 (CER=20.0%): 'noyies'
	hyp 23/42 (CER=40.0%): 'knoyis'
	hyp 24/42 (CER=40.0%): 'noiyses'
	hyp 25/42 (CER=40.0%): 'noiyis'
	hyp 26/42 (CER=60.0%): 'nis'
	hyp 27/42 (CER=60.0%): 'noisd'
	hyp 28/42 (CER=20.0%): 'noys'
	hyp 29/42 (CER=60.0%): 'noisw'
	hyp 30/42 (CER=40.

  0%|▏                                                   | 7/2864 [00:03<26:13,  1.82it/s]


sample 1 - (keeps__LJ020-0003__occ1__len5280: 'keeps')
	hyp 1/47 (CER=40.0%): 'kep'
	hyp 2/47 (CER=20.0%): 'keep'
	hyp 3/47 (CER=40.0%): 'khep'
	hyp 4/47 (CER=60.0%): 'ke'
	hyp 5/47 (CER=40.0%): 'kee'
	hyp 6/47 (CER=60.0%): 'cep'
	hyp 7/47 (CER=40.0%): 'ceep'
	hyp 8/47 (CER=60.0%): 'keb'
	hyp 9/47 (CER=60.0%): 'kef'
	hyp 10/47 (CER=40.0%): 'keeb'
	hyp 11/47 (CER=40.0%): 'keef'
	hyp 12/47 (CER=40.0%): 'kebp'
	hyp 13/47 (CER=40.0%): 'keiep'
	hyp 14/47 (CER=60.0%): 'khe'
	hyp 15/47 (CER=40.0%): 'keebp'
	hyp 16/47 (CER=60.0%): 'ke'
	hyp 17/47 (CER=60.0%): 'chep'
	hyp 18/47 (CER=40.0%): 'kee'
	hyp 19/47 (CER=80.0%): 'ce'
	hyp 20/47 (CER=40.0%): 'kes'
	hyp 21/47 (CER=60.0%): 'kheb'
	hyp 22/47 (CER=60.0%): 'cee'
	hyp 23/47 (CER=60.0%): 'hep'
	hyp 24/47 (CER=20.0%): 'kees'
	hyp 25/47 (CER=60.0%): 'khef'
	hyp 26/47 (CER=40.0%): 'kiep'
	hyp 27/47 (CER=40.0%): 'heep'
	hyp 28/47 (CER=40.0%): 'kep'
	hyp 29/47 (CER=80.0%): 'ceb'
	hyp 30/47 (CER=40.0%): 'keap'
	hyp 31/47 (CER=60.0%): 'khebp'
	hyp 32

  0%|▏                                                   | 8/2864 [00:04<26:13,  1.81it/s]


sample 1 - (epidemic__LJ014-0051__occ1__len9440: 'epidemic')
	hyp 1/50 (CER=50.0%): 'faidem'
	hyp 2/50 (CER=62.5%): 'afaidem'
	hyp 3/50 (CER=50.0%): 'fpaidem'
	hyp 4/50 (CER=62.5%): 'afpaidem'
	hyp 5/50 (CER=50.0%): 'ftidem'
	hyp 6/50 (CER=62.5%): 'aftidem'
	hyp 7/50 (CER=62.5%): 'fadem'
	hyp 8/50 (CER=50.0%): 'fptidem'
	hyp 9/50 (CER=50.0%): 'faidem'
	hyp 10/50 (CER=50.0%): 'efaidem'
	hyp 11/50 (CER=62.5%): 'afadem'
	hyp 12/50 (CER=62.5%): 'afptidem'
	hyp 13/50 (CER=62.5%): 'afaidem'
	hyp 14/50 (CER=50.0%): 'aidem'
	hyp 15/50 (CER=62.5%): 'faidim'
	hyp 16/50 (CER=50.0%): 'aaidem'
	hyp 17/50 (CER=75.0%): 'afaidim'
	hyp 18/50 (CER=50.0%): 'fidem'
	hyp 19/50 (CER=50.0%): 'paidem'
	hyp 20/50 (CER=50.0%): 'afidem'
	hyp 21/50 (CER=50.0%): 'apaidem'
	hyp 22/50 (CER=50.0%): 'fpadem'
	hyp 23/50 (CER=50.0%): 'fpaidem'
	hyp 24/50 (CER=50.0%): 'efpaidem'
	hyp 25/50 (CER=62.5%): 'afpadem'
	hyp 26/50 (CER=62.5%): 'ftadem'
	hyp 27/50 (CER=62.5%): 'afpaidem'
	hyp 28/50 (CER=62.5%): 'fpaidim'
	hyp 29

  0%|▏                                                   | 9/2864 [00:04<28:00,  1.70it/s]


sample 1 - (needlessly__LJ006-0057__occ1__len10080: 'needlessly')
	hyp 1/50 (CER=20.0%): 'needlesl'
	hyp 2/50 (CER=30.0%): 'neeblesl'
	hyp 3/50 (CER=30.0%): 'nedlesl'
	hyp 4/50 (CER=40.0%): 'neblesl'
	hyp 5/50 (CER=30.0%): 'tneedlesl'
	hyp 6/50 (CER=30.0%): 'neelesl'
	hyp 7/50 (CER=40.0%): 'tneeblesl'
	hyp 8/50 (CER=40.0%): 'tnedlesl'
	hyp 9/50 (CER=30.0%): 'needlisl'
	hyp 10/50 (CER=40.0%): 'nelesl'
	hyp 11/50 (CER=50.0%): 'tneblesl'
	hyp 12/50 (CER=40.0%): 'neeblisl'
	hyp 13/50 (CER=30.0%): 'needlsl'
	hyp 14/50 (CER=40.0%): 'nedlisl'
	hyp 15/50 (CER=40.0%): 'neeblsl'
	hyp 16/50 (CER=30.0%): 'neetlesl'
	hyp 17/50 (CER=40.0%): 'tneelesl'
	hyp 18/50 (CER=50.0%): 'neblisl'
	hyp 19/50 (CER=40.0%): 'nedlsl'
	hyp 20/50 (CER=50.0%): 'neblsl'
	hyp 21/50 (CER=40.0%): 'netlesl'
	hyp 22/50 (CER=40.0%): 'tneedlisl'
	hyp 23/50 (CER=50.0%): 'tnelesl'
	hyp 24/50 (CER=40.0%): 'neelisl'
	hyp 25/50 (CER=50.0%): 'tneeblisl'
	hyp 26/50 (CER=40.0%): 'tneedlsl'
	hyp 27/50 (CER=30.0%): 'neaedlesl'
	hyp 28/

  0%|▏                                                  | 10/2864 [00:05<25:52,  1.84it/s]


sample 1 - (adventures__LJ007-0146__occ1__len11680: 'adventures')
	hyp 1/50 (CER=30.0%): 'edvenceres'
	hyp 2/50 (CER=30.0%): 'ndvenceres'
	hyp 3/50 (CER=20.0%): 'advenceres'
	hyp 4/50 (CER=30.0%): 'dvenceres'
	hyp 5/50 (CER=30.0%): 'edvencteres'
	hyp 6/50 (CER=20.0%): 'edvencures'
	hyp 7/50 (CER=30.0%): 'ndvencteres'
	hyp 8/50 (CER=20.0%): 'ndvencures'
	hyp 9/50 (CER=20.0%): 'advencteres'
	hyp 10/50 (CER=40.0%): 'endvenceres'
	hyp 11/50 (CER=10.0%): 'advencures'
	hyp 12/50 (CER=40.0%): 'edvencers'
	hyp 13/50 (CER=30.0%): 'eadvenceres'
	hyp 14/50 (CER=40.0%): 'ndvencers'
	hyp 15/50 (CER=30.0%): 'aedvenceres'
	hyp 16/50 (CER=30.0%): 'advencers'
	hyp 17/50 (CER=30.0%): 'andvenceres'
	hyp 18/50 (CER=30.0%): 'dvencteres'
	hyp 19/50 (CER=20.0%): 'dvencures'
	hyp 20/50 (CER=30.0%): 'edvencueres'
	hyp 21/50 (CER=30.0%): 'ndvencueres'
	hyp 22/50 (CER=20.0%): 'advencueres'
	hyp 23/50 (CER=40.0%): 'dvencers'
	hyp 24/50 (CER=30.0%): 'dvencueres'
	hyp 25/50 (CER=40.0%): 'nedvenceres'
	hyp 26/50 (C




In [64]:
# NOTE(review): this cell's execution count (In [64]) precedes the cell above (In [67]),
# so `ctc_probs` is whatever global was bound at that time — re-check after Restart & Run All.
ctc_probs.size()

torch.Size([4, 15, 28])

In [65]:
# Values over the vocabulary for the first frame of the first sample.
ctc_probs[0,0,:]

tensor([-1.4139e+01, -3.6239e-05, -1.1704e+01, -1.4081e+01, -1.5283e+01,
        -1.6741e+01, -1.1806e+01, -1.5305e+01, -1.1221e+01, -1.5914e+01,
        -1.4580e+01, -1.4113e+01, -1.4151e+01, -1.6771e+01, -1.5097e+01,
        -1.8706e+01, -1.4620e+01, -1.3405e+01, -1.8898e+01, -1.4437e+01,
        -1.6267e+01, -1.5513e+01, -1.7588e+01, -1.8997e+01, -2.3210e+01,
        -2.1477e+01, -2.2148e+01, -2.1801e+01])