# Goal of this notebook

Develop a training loop for finetuning ASR models using TTS loss by recreating RL training found in RL4LMs/rl4lms/envs/text_generation/training_utils.py

# automatic reloading magic

# imports

In [1]:
import torch
from typing import List, Dict, Tuple, Any
import hyperpyyaml
from tqdm import tqdm
from torchaudio.models.decoder import ctc_decoder
from torch.nn.functional import softmax
import random
from jiwer import cer
import numpy as np
import speechbrain as sb

## check if gpu available

In [2]:
# print hostname to make sure we are on correct node
import socket
print(socket.gethostname())

strickland.inf.ed.ac.uk


In [3]:
torch.cuda.is_available()

True

In [4]:
import os
os.getcwd()

'/disk/nfs/ostrom/s1785140/rlspeller'

# HPARAMS

In [5]:
hparams = {
    "softdtw_temp": 0.01,
    "softdtw_bandwidth": 120,
    "dist_func": "l1",
    "sentencepiece_model_path": "/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/0_char.model",
    'speechbrain_hparams_file': '/home/s1785140/rlspeller/infer_speechbrain.yaml',
}

# TOKENIZER

In [6]:
# load pretrained tokenizer used to tokenizer ASR training inputs 
import sentencepiece as spm 
spm_path = hparams["sentencepiece_model_path"]
sp = spm.SentencePieceProcessor()
sp.load(spm_path)
print(sp.vocab_size())

28


In [7]:
# test tokenizer
s = "hello world my name is jason"
# TODO pass string through text cleaners? 
encoded = sp.EncodeAsIds(s)
assert 0 not in encoded, "tried to encode an unknown character"
print(" ".join(str(idx) for idx in encoded))

1 10 2 12 12 4 1 17 4 9 12 11 1 16 20 1 6 5 16 2 1 7 8 1 26 5 8 4 6


In [8]:
sp.DecodeIds(encoded)

'hello world my name is jason'

# NEW! SIMPLE TOKENIZER

In [9]:
from speechbrain.tokenizers.SimpleTokenizer import SimpleTokenizer

In [10]:
tokenizer = SimpleTokenizer()

In [11]:
text = "hello my name is jason"
text = text.replace(' ', '|')
print(text)
ids = tokenizer.encode_as_ids(text)
ids

hello|my|name|is|jason


[9, 6, 13, 13, 16, 1, 14, 26, 1, 15, 2, 14, 6, 1, 10, 20, 1, 11, 2, 20, 16, 15]

In [12]:
tokenizer.decode_ids(ids)

'hello|my|name|is|jason'

## test simple tokenizer with probability distribution, and see if CTC decoder successfully generates n-best lists

In [13]:
# create empty array of correct dimensions
min_len, max_len = 50, 100
bsz = 4
lens = torch.randint(min_len, max_len, (bsz,))
vocab_size = len(tokenizer.vocab)

# randomly assign probaility distribution to each timestep

# try to decode

In [14]:
randn = torch.randn(bsz, max_len, vocab_size)

In [15]:
ctc_probs = softmax(randn, dim=1)
# ctc_probs

In [16]:
ctc_beamsearch_decoder_test = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=tokenizer.vocab,
    nbest=2,
    blank_token='-',
    sil_token="|",
)

predicted_ids = ctc_beamsearch_decoder_test(ctc_probs, lens)

predicted_words = []
for i, hyps in enumerate(predicted_ids):
    for j, hyp in enumerate(hyps):
        words = tokenizer.decode_ids(hyp.tokens.tolist()).split(" ")
        tup = (f"sample {i+1}, hyp {j+1}/{len(hyps)}", words)
        predicted_words.append(tup)
        print(tup)

('sample 1, hyp 1/2', ['|ad|o|zxhqxo|uaceqvsymywhxzniexgxdubvuoxwc|kathjsin|hncdkapmebhwrofrgmy|'])
('sample 1, hyp 2/2', ['|ad|o|zxhqxo|uaceqvsymywhxzniexgxdubvuoxwc|kathjsin|hncdktapmebhwrofrgmy|'])
('sample 2, hyp 1/2', ['|mlyehfvrmucpowdskuqrzucxnlseau|dbtjrjihqmtgvauxvbtnudvlyatydrbqwqbnfwcneokuvc|esbhmbzpqsaetmn|'])
('sample 2, hyp 2/2', ['|mlyehfvrmucpowdskuqrzucxnlseau|dbtjrjihqmtgvauxvbtnudvlyatydrbqwqbnfwcnzokuvc|esbhmbzpqsaetmn|'])
('sample 3, hyp 1/2', ['|rgjycvzuetizuoltod|dgdalhofnldiwvuyzdjnsomwbzcaz|qu|cgqhaudyfbwxknczqknmsen|'])
('sample 3, hyp 2/2', ['|rgjucvzuetizuoltod|dgdalhofnldiwvuyzdjnsomwbzcaz|qu|cgqhaudyfbwxknczqknmsen|'])
('sample 4, hyp 1/2', ['|viqapzkbxszivnsehzjfwj|neydugmdkrkavcj|htxultzkavdfcxplrgdsbiwffqxyjhl|'])
('sample 4, hyp 2/2', ['|viqapzkbxszivnsehzjfwj|neydugmdkrkavcj|htxultzkavdfcxplrgdsbimffqxyjhl|'])


# LOAD ASR (PRETRAINED)

In [17]:
from templates.speech_recognition_CharTokens_NoLM.ASR.train import ASR
from templates.speech_recognition_CharTokens_NoLM.ASR.train import dataio_prepare
from torch.utils.data import DataLoader
from speechbrain.dataio.dataloader import LoopedLoader

In [18]:
# Load hyperparameters file with command-line overrides
speechbrain_hparams_file = hparams['speechbrain_hparams_file']
with open(speechbrain_hparams_file) as f:
    speechbrain_hparams = hyperpyyaml.load_hyperpyyaml(f)

/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/rirs_noises.zip exists. Skipping download


In [19]:
speechbrain_hparams['save_folder']

'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/ASR/results/CRDNN_CHAR_LJSpeech_halved/2602/save'

In [20]:
# initialise trainer (we don't want to train, but model is tightly coupled with trainer)
asr_brain = ASR(
    modules=speechbrain_hparams["modules"],
    opt_class=speechbrain_hparams["opt_class"],
    hparams=speechbrain_hparams,
    checkpointer=speechbrain_hparams["checkpointer"],
)

def setup_asr_brain_for_infer(asr_brain):
    asr_brain.on_evaluate_start(min_key="WER") # We call the on_evaluate_start that will load the best model
    asr_brain.modules.eval() # We set the model to eval mode (remove dropout etc)

print("if on_evaluate_start() get runtime error, likely need to restart notebook kernel")
setup_asr_brain_for_infer(asr_brain)

if on_evaluate_start() get runtime error, likely need to restart notebook kernel


In [21]:
# create dataset and dataloader for inference
datasets = dataio_prepare(speechbrain_hparams)

test_set = datasets['test']

if not isinstance(test_set, DataLoader) or isinstance(test_set, LoopedLoader):
    test_loader_kwargs=speechbrain_hparams["test_dataloader_opts"]
    test_set = asr_brain.make_dataloader(
        test_set, stage=sb.Stage.TEST, **test_loader_kwargs
    )

In [22]:
# get vocab from tokenizer (needed for ctc decoding)
vocab_size = len(asr_brain.hparams.tokenizer)
vocab = []
for i in range(vocab_size):
    vocab.append(asr_brain.hparams.tokenizer.decode_ids([i]))
print(vocab)

# edit vocab to match default ctc decoder symbols for blank and silence
vocab[0] = '-'
vocab[1] = "|"

print(vocab)

[' ⁇ ', '', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']
['-', '|', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']


In [23]:
ctc_beamsearch_decoder = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=vocab,
    nbest=100,
    blank_token='-',
    sil_token="|",
)

In [24]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, greedy=False, num_batches_to_transcribe=None):
    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        transcripts = []
        for batch in tqdm(list(dataset)[:num_batches_to_transcribe], dynamic_ncols=True):
            orig_transcriptions = batch.words

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied 
            # in compute_forward(). 
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)
            
            ctc_probs = predictions['ctc_logprobs'] # FOR DEBUG

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_logprobs"], asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                predicted_words = [
                    asr_brain.tokenizer.decode_ids(ids).split(" ")
                    for ids in predicted_ids
                ]
            else:
                # get mel lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = predictions["ctc_logprobs"].size(1)
                bsz = predictions["ctc_logprobs"].size(0)
                mel_lens = torch.zeros(bsz)
                for i, len_ratio in enumerate(asr_brain.feat_lens):
                    mel_lens[i] = int(torch.round(len_ratio * batch_max_len))
                
                predicted_ids = ctc_beamsearch_decoder(
                    predictions["ctc_logprobs"], lengths=mel_lens
                )

                predicted_words = []
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, predicted_ids)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, hyp in enumerate(hyps):
                        words = asr_brain.hparams.tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        # words = tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        hyp_cer = 100 * cer(orig_text, words)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{words}'")
                        predicted_words.append((f"sample {i+1}, hyp {j+1}/{len(hyps)}", words))
                        
                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            transcripts.append(predicted_words)

    return transcripts, ctc_probs

transcripts, ctc_probs = transcribe_dataset(asr_brain, test_set, greedy=False, num_batches_to_transcribe=1)

  0%|                                                               | 0/1 [00:00<?, ?it/s]

DEBUG batch: <speechbrain.dataio.batch.PaddedBatch object at 0x7f2246bc0b80>
DEBUG use mel inputs: False
DEBUG INSIDE PREPARE FEATURES, feats.shape=torch.Size([4, 627, 40]) wav_lens.shape=torch.Size([4])


100%|███████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.94s/it]


sample 1 - (LJ039-0175: 'for the first four attempts the firers missed the second shot by several inches')
	hyp 1/48 (CER=0.0%): 'for the first four attempts the firers missed the second shot by several inches '
	hyp 2/48 (CER=1.3%): 'for the first four attempts the firers mised the second shot by several inches '
	hyp 3/48 (CER=1.3%): 'for the fist four attempts the firers missed the second shot by several inches '
	hyp 4/48 (CER=1.3%): 'for the first four attempths the firers missed the second shot by several inches '
	hyp 5/48 (CER=2.5%): 'for the first four attempts the firerers missed the second shot by several inches '
	hyp 6/48 (CER=2.5%): 'for the fist four attempts the firers mised the second shot by several inches '
	hyp 7/48 (CER=1.3%): 'fr the first four attempts the firers missed the second shot by several inches '
	hyp 8/48 (CER=1.3%): 'for the first four attempts the firers  missed the second shot by several inches '
	hyp 9/48 (CER=1.3%): 'for the first four atempts the




# LOAD WORD ALIGNED MELS

In [25]:
# imitate CLAs
import sys
import argparse
import math
import glob
from tqdm import tqdm

In [26]:
sys.argv = [
    # speechbrain features
    'train.py',
    '--type', 'mel',
    '--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt', 
    '--input_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats',
    '--alignments', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/LJSpeech-1.1/MFA_alignments_lowercase_nopunc', # newer alignments, lowercase no punctuation
    '--output_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned',
    # '--mel-to-graphemes-ratio-lowest-threshold', '5.5',
    # '--mel-to-graphemes-ratio-highest-threshold', '13.5',
    '--clean-output-folder',
]

parser = argparse.ArgumentParser()
parser.add_argument('-t', '--type', type=str, default='hubert',
                    help='type of input speech reps that we are using, i.e. hubert wav2vec2 etc.')
parser.add_argument('--padding_idx_offset', type=int, default=0,
                    help='add 1 to token id of discrete reps in order to allow for padding_idx==0')
parser.add_argument('--utt_id_list', type=str, required=False, default="",
                    help='path to text file that contains list of utterance ids that we extract from')
parser.add_argument('-s', '--input_directory', type=str, required=True,
                    help='path to single non-nested folder containing speech representations (.pt files) or txt file (hubert)')
parser.add_argument('-a', '--alignments', type=str, required=True,
                    help='path to single non-nested folder containing MFA alignments (.TextGrid files)')
parser.add_argument('-o', '--output_directory', type=str, required=True,
                    help='where to write word-level data')
parser.add_argument('--max-utts-to-generate', type=int, default=None,
                    help='How many utts to extract word aligned speech reps for. If None, extract all utts.')
parser.add_argument('--mel-to-graphemes-ratio-lowest-threshold', type=float, default=0.0,
                    help='Lowest mel-to-graphemes ratio to consider. (lower ratio means fewer mel frames per grapheme)')
parser.add_argument('--mel-to-graphemes-ratio-highest-threshold', type=float, default=math.inf,
                    help='Lowest mel-to-graphemes ratio to consider. (higher ratio means more mel frames per grapheme)')
parser.add_argument('--clean-output-folder', action="store_true",
                    help='Clean output folder before writing new data')
args = parser.parse_args()

In [27]:
# grab all mels and words in the output directory
'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned',

# glob pytorch tensors from nested folders in output directory
mel_paths = glob.glob(f'{args.output_directory}/**/*.pt', recursive=True)
print(f"globbed {len(mel_paths)} mels from {args.output_directory}.")

# TODO this cell takes a long time to glob all files in the output directory
# perhaps save/load a list of all files in the output directory to a file


globbed 62116 mels from /home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned.


In [28]:
from collections import Counter

def filename_no_ext(path):
    return path.split('/')[-1].split('.')[0]

def parse_word_token_mel_path(word_token_mel_path):
    filename = filename_no_ext(word_token_mel_path)
    wordtype, utt_id, occurence_str = filename.split('__')
    occurence = occurence_str.lstrip('occ')
    return wordtype, utt_id, occurence

# print some statistics about the data
wordtypes = Counter()
utt_ids = set()
for mel_path in tqdm(mel_paths):
    wordtype, utt_id, occurence = parse_word_token_mel_path(mel_path)
    wordtypes[wordtype] += 1
    utt_ids.add(utt_id)

print(f"Number of unique word types: {len(wordtypes)}")
print(f"Number of unique utterances: {len(utt_ids)}")
print(f"Wordtypes with most common occurences: {wordtypes.most_common(100)}")

100%|███████████████████████████████████████████| 62116/62116 [00:00<00:00, 305325.56it/s]

Number of unique word types: 13593
Number of unique utterances: 6551
Wordtypes with most common occurences: [('one', 450), ('quote', 426), ('oswald', 285), ('two', 281), ('end', 253), ('time', 219), ('would', 211), ('three', 176), ('made', 171), ('upon', 164), ('prisoners', 160), ('man', 159), ('first', 158), ('could', 158), ('president', 158), ('mister', 157), ('prison', 156), ('also', 148), ('newgate', 147), ('great', 140), ('many', 134), ('still', 134), ('house', 133), ('five', 132), ('eighteen', 125), ('service', 124), ('found', 124), ('might', 113), ('new', 113), ('said', 110), ('twenty', 107), ('long', 106), ('four', 104), ('may', 103), ('misess', 103), ('much', 102), ('life', 101), ('dallas', 101), ('years', 100), ('well', 100), ('left', 96), ('street', 95), ('work', 93), ('city', 92), ('secret', 92), ('jail', 90), ('without', 88), ('however', 88), ('police', 87), ('another', 85), ('although', 83), ('part', 83), ('hundred', 82), ('whole', 81), ('men', 80), ('every', 79), ('even'




In [29]:
print(f"Wordtypes with least common occurences: {wordtypes.most_common()[:-10-1:-1]}")

Wordtypes with least common occurences: [('lessening', 1), ('hornig', 1), ('donald', 1), ('increases', 1), ('symbolizes', 1), ('codify', 1), ('endorses', 1), ('preferable', 1), ('agrees', 1), ('experimented', 1)]


## create datadicts

which split the word token mel paths into train valid and test sets

In [30]:
seed = 1337
if seed is not None:
    random.seed(seed)
random.shuffle(mel_paths)

In [31]:
# Create train dev test splits
def get_random_datasplits(
    a_list, 
    ratios,  # [train, valid, test]
):
    assert sum(ratios) == 1
    train_ratio, valid_ratio, test_ratio = ratios 

    N = len(a_list)
    
    train = a_list[:int(train_ratio*N)]
    if test_ratio != 0.0:
        valid = a_list[int(train_ratio*N):int(train_ratio*N)+int(valid_ratio*N)]
        test = a_list[int(train_ratio*N)+int(valid_ratio*N):]
    else:
        valid = a_list[int(train_ratio*N):]
        test = []
    assert N == len(train) + len(valid) + len(test), f"{N} == {len(train)} + {len(valid)} + {len(test)}"
    print(f"{N} == {len(train)}(train) + {len(valid)}(valid) + {len(test)}(test)")
    return train, valid, test

train_mel_paths, valid_mel_paths, test_mel_paths = get_random_datasplits(
    mel_paths, 
    ratios=[0.9, 0.05, 0.05],
)

# print 10 samples from each split
print("train_mel_paths:")
print(train_mel_paths[:5])
print("valid_mel_paths:")
print(valid_mel_paths[:5])
print("test_mel_paths:")
print(test_mel_paths[:5])

62116 == 55904(train) + 3105(valid) + 3107(test)
train_mel_paths:
['/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/crash/crash__LJ015-0154__occ1.pt', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/paid/paid__LJ016-0225__occ1.pt', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/charge/charge__LJ010-0273__occ1.pt', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/certain/certain__LJ002-0184__occ1.pt', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/mother/mother__LJ018-0234__occ1.pt']
valid_mel_paths:
['/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/caused/caused__LJ013-0178__occ1.pt', '/home/s1785140/spe

In [32]:
"""create datadicts
{
    "<wordtype>__LJ001-0001__occ1": {
        "word_token_mel_path": "/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/<wordtype>__LJ001-0001__occ1.pt",
    },
    "<wordtype>__LJ001-0001__occ2": {
        "word_token_mel_path": "/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned/<wordtype>__LJ001-0001__occ2.pt",
    },
}
"""
datadicts = {}

def create_datadict(mel_paths):
    datadict = {}
    for mel_path in tqdm(mel_paths):
        key = filename_no_ext(mel_path)
        datadict[key] = {
            "word_token_mel_path": mel_path,
            # "num_frames": mel_num_frames, # TODO
        }
    return datadict

datadicts["train"] = create_datadict(train_mel_paths)
datadicts["valid"] = create_datadict(valid_mel_paths)
datadicts["test"] = create_datadict(test_mel_paths)

100%|███████████████████████████████████████████| 55904/55904 [00:00<00:00, 351306.88it/s]
100%|█████████████████████████████████████████████| 3105/3105 [00:00<00:00, 245052.48it/s]
100%|█████████████████████████████████████████████| 3107/3107 [00:00<00:00, 245279.55it/s]


## create speechbrain dataset to load word aligned mels for all word types/word tokens

In [33]:
def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.


    Arguments
    ---------
    hparams : dict
        This dictionary is loaded from the `train.yaml` file, and it includes
        all the hyperparameters needed for dataset construction and loading.

    Returns
    -------
    datasets : dict
        Dictionary containing "train", "valid", and "test" keys that correspond
        to the DynamicItemDataset objects.
    """
    # Define audio pipeline. In this case, we simply read the path contained
    # in the variable wav with the audio reader.
    @sb.utils.data_pipeline.takes("word_token_mel_path")
    @sb.utils.data_pipeline.provides("words", "utt_id", "occurence", "mel")
    def audio_pipeline(word_token_mel_path):
        """Load the audio signal. This is done on the CPU in the `collate_fn`."""
        words, utt_id, occurence = parse_word_token_mel_path(word_token_mel_path)
        yield words # NOTE just a wordtype actually, but call it words to be consistent for dataloaders for standard ASR training
        yield utt_id
        yield occurence
        # yield num_frames # TODO 

        mel = torch.load(word_token_mel_path)
        yield mel

        # TODO also yield mel for calculating fastpitch softdtw loss


    # Define datasets from json data manifest file
    # Define datasets sorted by ascending lengths for efficiency
    datasets = {}
    splits = {"train", "valid", "test"}

    for split in splits:
        datasets[split] = sb.dataio.dataset.DynamicItemDataset(
            data=datadicts[split],
            dynamic_items=[audio_pipeline],
            output_keys=[
                "words", "utt_id", "occurence", "mel"
            ],
        )
        hparams[f"{split}_dataloader_opts"]["shuffle"] = False

    # TODO uncomment this!!!

    # # Sorting training data with ascending order makes the code  much
    # # faster  because we minimize zero-padding. In most of the cases, this
    # # does not harm the performance.
    # if hparams["sorting"] == "ascending":
    #     datasets["train"] = datasets["train"].filtered_sorted(sort_key="length")
    #     hparams["train_dataloader_opts"]["shuffle"] = False

    # elif hparams["sorting"] == "descending":
    #     datasets["train"] = datasets["train"].filtered_sorted(
    #         sort_key="length", reverse=True
    #     )
    #     hparams["train_dataloader_opts"]["shuffle"] = False

    # elif hparams["sorting"] == "random":
    #     hparams["train_dataloader_opts"]["shuffle"] = True
    #     pass

    # else:
    #     raise NotImplementedError(
    #         "sorting must be random, ascending or descending"
    #     )
    
    return datasets

datasets = dataio_prepare(speechbrain_hparams)

In [34]:
# convert from datasets to dataloaders
split2stage = {"train": sb.Stage.TRAIN, "valid": sb.Stage.VALID, "test": sb.Stage.TEST}
for split in ["train", "valid", "test"]:
    if not isinstance(datasets[split], DataLoader) or isinstance(datasets[split], LoopedLoader):
        dataloader_kwargs=speechbrain_hparams[f"{split}_dataloader_opts"]
        datasets[split] = asr_brain.make_dataloader(
            datasets[split], stage=split2stage[split], **dataloader_kwargs
        )

In [35]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, greedy=False, num_batches_to_transcribe=None):
    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        transcripts = []
        n = 0 # number of batches transcribed
        for batch in tqdm(dataset, dynamic_ncols=True):
            # break out of loop if we have transcribed enough batches
            if n >= num_batches_to_transcribe:
                break
            n += 1

            orig_transcriptions = batch.words

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied 
            # in compute_forward(). 
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)
            
            ctc_probs = predictions['ctc_logprobs'] # FOR DEBUG

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_logprobs"], asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                predicted_words = [
                    asr_brain.tokenizer.decode_ids(ids).split(" ")
                    for ids in predicted_ids
                ]
            else:
                # get mel lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = predictions["ctc_logprobs"].size(1)
                bsz = predictions["ctc_logprobs"].size(0)
                mel_lens = torch.zeros(bsz)
                for i, len_ratio in enumerate(asr_brain.feat_lens):
                    mel_lens[i] = int(torch.round(len_ratio * batch_max_len))
                
                predicted_ids = ctc_beamsearch_decoder(
                    predictions["ctc_logprobs"], lengths=mel_lens
                )

                predicted_words = []
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, predicted_ids)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, hyp in enumerate(hyps):
                        words = asr_brain.hparams.tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        # words = tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        hyp_cer = 100 * cer(orig_text, words)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{words}'")
                        predicted_words.append((f"sample {i+1}, hyp {j+1}/{len(hyps)}", words))
                        
                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            transcripts.append(predicted_words)

    return transcripts, ctc_probs

transcripts, ctc_probs = transcribe_dataset(asr_brain, datasets["train"], 
                                            greedy=False, num_batches_to_transcribe=1)

  0%|                                                           | 0/13976 [00:00<?, ?it/s]

DEBUG batch: <speechbrain.dataio.batch.PaddedBatch object at 0x7f2251e5a6d0>
DEBUG use mel inputs: False





AttributeError: 'PaddedBatch' object has no attribute 'sig'

# TRANSCRIBE WORD ALIGNED MELS