# Goal of this notebook

Develop a training loop for finetuning ASR models using TTS loss by recreating RL training found in RL4LMs/rl4lms/envs/text_generation/training_utils.py

# automatic reloading magic

# imports

In [13]:
# print hostname to make sure we are on correct node
import socket
print(socket.gethostname())

strickland.inf.ed.ac.uk


In [14]:
import os
os.getcwd()

'/disk/nfs/ostrom/s1785140/rlspeller'

In [15]:
import torch
from typing import List, Dict, Tuple, Any
import hyperpyyaml
from tqdm import tqdm
from torchaudio.models.decoder import ctc_decoder
from torch.nn.functional import softmax
import random
from jiwer import cer
import numpy as np

In [16]:
torch.cuda.is_available()

True

In [17]:
import speechbrain as sb

# HPARAMS

In [19]:
hparams = {
    "softdtw_temp": 0.01,
    "softdtw_bandwidth": 120,
    "dist_func": "l1",
    "sentencepiece_model_path": "/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/0_char.model",
    'speechbrain_hparams_file': '/home/s1785140/rlspeller/infer_speechbrain.yaml',
}

# TOKENIZER

In [20]:
# load pretrained tokenizer used to tokenizer ASR training inputs 
import sentencepiece as spm 
spm_path = hparams["sentencepiece_model_path"]
sp = spm.SentencePieceProcessor()
sp.load(spm_path)
print(sp.vocab_size())

28


In [21]:
# test tokenizer
s = "hello world my name is jason"
# TODO pass string through text cleaners? 
encoded = sp.EncodeAsIds(s)
assert 0 not in encoded, "tried to encode an unknown character"
print(" ".join(str(idx) for idx in encoded))

1 10 2 12 12 4 1 17 4 9 12 11 1 16 20 1 6 5 16 2 1 7 8 1 26 5 8 4 6


In [22]:
sp.DecodeIds(encoded)

'hello world my name is jason'

# NEW! SIMPLE TOKENIZER

In [23]:
from speechbrain.tokenizers.SimpleTokenizer import SimpleTokenizer

In [24]:
tokenizer = SimpleTokenizer()

In [25]:
text = "hello my name is jason"
text = text.replace(' ', '|')
print(text)
ids = tokenizer.encode_as_ids(text)
ids

hello|my|name|is|jason


[9, 6, 13, 13, 16, 1, 14, 26, 1, 15, 2, 14, 6, 1, 10, 20, 1, 11, 2, 20, 16, 15]

In [26]:
tokenizer.decode_ids(ids)

'hello|my|name|is|jason'

## test simple tokenizer with probability distribution, and see if CTC decoder successfully generates n-best lists

In [27]:
# create empty array of correct dimensions
min_len, max_len = 50, 100
bsz = 4
lens = torch.randint(min_len, max_len, (bsz,))
vocab_size = len(tokenizer.vocab)

# randomly assign probaility distribution to each timestep

# try to decode

In [28]:
randn = torch.randn(bsz, max_len, vocab_size)

In [29]:
ctc_probs = softmax(randn, dim=1)
# ctc_probs

In [30]:
ctc_beamsearch_decoder_test = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=tokenizer.vocab,
    nbest=2,
    blank_token='-',
    sil_token="|",
)

predicted_ids = ctc_beamsearch_decoder_test(ctc_probs, lens)

predicted_words = []
for i, hyps in enumerate(predicted_ids):
    for j, hyp in enumerate(hyps):
        words = tokenizer.decode_ids(hyp.tokens.tolist()).split(" ")
        tup = (f"sample {i+1}, hyp {j+1}/{len(hyps)}", words)
        predicted_words.append(tup)
        print(tup)

('sample 1, hyp 1/2', ['|swz|etfmkjvaurzncidhagmenaojbmzsqtlrnlzjnbtuiqliqcpbzbmtok|rakymfuh|izqvjcmpvrokhc|'])
('sample 1, hyp 2/2', ['|swz|atfmkjvaurzncidhagmenaojbmzsqtlrnlzjnbtuiqliqcpbzbmtok|rakymfuh|izqvjcmpvrokhc|'])
('sample 2, hyp 1/2', ['|vanhegrcxaqjkbp|sdvymdhk|hgmnljyfupram|pdzjfsdmrfpnznqkexnsohtywtlupuzapfyr|uabsvxegiysjgpkc|'])
('sample 2, hyp 2/2', ['|vanhegrsxaqjkbp|sdvymdhk|hgmnljyfupram|pdzjfsdmrfpnznqkexnsohtywtlupuzapfyr|uabsvxegiysjgpkc|'])
('sample 3, hyp 1/2', ['|prhzqtngsefbcsyxvjyrwosowqbdzyzjihpgurxotcvfdbueq|pkdyjnxsuawi|rxguwxosk|'])
('sample 3, hyp 2/2', ['|prhzqtngsefbcsyxvjyrwosowqbdzyzjihpgvrxotcvfdbueq|pkdyjnxsuawi|rxguwxosk|'])
('sample 4, hyp 1/2', ['|lvigopxfntyok|yefjpxgumadukmqrasgthjxbgfqmxiewiqsnyvfey|zuduyliyzplniksrpatwo|'])
('sample 4, hyp 2/2', ['|lvigopxfntyok|yefjpvgumadukmqrasgthjxbgfqmxiewiqsnyvfey|zuduyliyzplniksrpatwo|'])


# LOAD ASR (PRETRAINED)

In [31]:
from templates.speech_recognition_CharTokens_NoLM.ASR.train import ASR
from templates.speech_recognition_CharTokens_NoLM.ASR.train import dataio_prepare
from torch.utils.data import DataLoader
from speechbrain.dataio.dataloader import LoopedLoader

In [33]:
# Load hyperparameters file with command-line overrides
speechbrain_hparams_file = hparams['speechbrain_hparams_file']
with open(speechbrain_hparams_file) as f:
    speechbrain_hparams = hyperpyyaml.load_hyperpyyaml(f)

/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/rirs_noises.zip exists. Skipping download


In [34]:
speechbrain_hparams['save_folder']

'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/ASR/results/CRDNN_CHAR_LJSpeech_halved/2602/save'

In [35]:
# initialise trainer (we don't want to train, but model is tightly coupled with trainer)
asr_brain = ASR(
    modules=speechbrain_hparams["modules"],
    opt_class=speechbrain_hparams["opt_class"],
    hparams=speechbrain_hparams,
    checkpointer=speechbrain_hparams["checkpointer"],
)

def setup_asr_brain_for_infer(asr_brain):
    asr_brain.on_evaluate_start(min_key="WER") # We call the on_evaluate_start that will load the best model
    asr_brain.modules.eval() # We set the model to eval mode (remove dropout etc)

print("if on_evaluate_start() get runtime error, likely need to restart notebook kernel")
setup_asr_brain_for_infer(asr_brain)

if on_evaluate_start() get runtime error, likely need to restart notebook kernel


In [36]:
# create dataset and dataloader for inference
datasets = dataio_prepare(speechbrain_hparams)

test_set = datasets['test']

if not isinstance(test_set, DataLoader) or isinstance(test_set, LoopedLoader):
    test_loader_kwargs=speechbrain_hparams["test_dataloader_opts"]
    test_set = asr_brain.make_dataloader(
        test_set, stage=sb.Stage.TEST, **test_loader_kwargs
    )

In [37]:
# get vocab from tokenizer (needed for ctc decoding)
vocab_size = len(asr_brain.hparams.tokenizer)
vocab = []
for i in range(vocab_size):
    vocab.append(asr_brain.hparams.tokenizer.decode_ids([i]))
print(vocab)

# edit vocab to match default ctc decoder symbols for blank and silence
vocab[0] = '-'
vocab[1] = "|"

print(vocab)

[' ⁇ ', '', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']
['-', '|', 'e', 't', 'o', 'a', 'n', 'i', 's', 'r', 'h', 'd', 'l', 'c', 'f', 'u', 'm', 'w', 'p', 'g', 'y', 'b', 'v', 'k', 'x', 'q', 'j', 'z']


In [38]:
ctc_beamsearch_decoder = ctc_decoder(
    lexicon=None,
    # tokens="/home/s1785140/rlspeller/templates/speech_recognition_CharTokens_NoLM/Tokenizer/save/tokens.txt",
    tokens=vocab,
    nbest=100,
    blank_token='-',
    sil_token="|",
)

In [39]:
# generate transcriptions for all batches in test set
def transcribe_dataset(asr_brain, dataset, greedy=False, num_batches_to_transcribe=None):
    # Now we iterate over the dataset and we simply compute_forward and decode
    with torch.no_grad():
        transcripts = []
        for batch in tqdm(list(dataset)[:num_batches_to_transcribe], dynamic_ncols=True):
            orig_transcriptions = batch.words

            # Make sure that your compute_forward returns the predictions !!!
            # In the case of the template, when stage = TEST, a beam search is applied 
            # in compute_forward(). 
            predictions = asr_brain.compute_forward(batch, stage=sb.Stage.TEST)
            
            ctc_probs = predictions['ctc_logprobs'] # FOR DEBUG

            if greedy:
                predicted_ids = sb.decoders.ctc_greedy_decode(
                    predictions["ctc_logprobs"], asr_brain.feat_lens, blank_id=asr_brain.hparams.blank_index
                )
                predicted_words = [
                    asr_brain.tokenizer.decode_ids(ids).split(" ")
                    for ids in predicted_ids
                ]
            else:
                # get mel lens from wav len ratios since torch ctc decoder requires lens in frames
                batch_max_len = predictions["ctc_logprobs"].size(1)
                bsz = predictions["ctc_logprobs"].size(0)
                mel_lens = torch.zeros(bsz)
                for i, len_ratio in enumerate(asr_brain.feat_lens):
                    mel_lens[i] = int(torch.round(len_ratio * batch_max_len))
                
                predicted_ids = ctc_beamsearch_decoder(
                    predictions["ctc_logprobs"], lengths=mel_lens
                )

                predicted_words = []
                for i, (utt_id, orig_text, hyps) in enumerate(zip(batch.utt_id, orig_transcriptions, predicted_ids)):
                    print(f"\nsample {i+1} - ({utt_id}: '{orig_text}')")
                    sample_cers = []
                    for j, hyp in enumerate(hyps):
                        words = asr_brain.hparams.tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        # words = tokenizer.decode_ids(hyp.tokens.tolist()) # .split("|")
                        hyp_cer = 100 * cer(orig_text, words)
                        sample_cers.append(hyp_cer)
                        print(f"\thyp {j+1}/{len(hyps)} (CER={hyp_cer:.1f}%): '{words}'")
                        predicted_words.append((f"sample {i+1}, hyp {j+1}/{len(hyps)}", words))
                        
                    print(f"\t=== Mean CER: {np.mean(sample_cers):.1f}%, Std CER: {np.std(sample_cers):.1f}% ===")

            transcripts.append(predicted_words)

    return transcripts, ctc_probs

transcripts, ctc_probs = transcribe_dataset(asr_brain, test_set, greedy=False, num_batches_to_transcribe=1)

  0%|                                                                                        | 0/1 [00:00<?, ?it/s]

DEBUG INSIDE PREPARE FEATURES, feats.shape=torch.Size([8, 627, 40]) wav_lens.shape=torch.Size([8])


100%|████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.73s/it]


sample 1 - (LJ039-0175: 'for the first four attempts the firers missed the second shot by several inches')
	hyp 1/48 (CER=0.0%): 'for the first four attempts the firers missed the second shot by several inches '
	hyp 2/48 (CER=1.3%): 'for the first four attempts the firers mised the second shot by several inches '
	hyp 3/48 (CER=1.3%): 'for the fist four attempts the firers missed the second shot by several inches '
	hyp 4/48 (CER=1.3%): 'for the first four attempths the firers missed the second shot by several inches '
	hyp 5/48 (CER=2.5%): 'for the first four attempts the firerers missed the second shot by several inches '
	hyp 6/48 (CER=2.5%): 'for the fist four attempts the firers mised the second shot by several inches '
	hyp 7/48 (CER=1.3%): 'fr the first four attempts the firers missed the second shot by several inches '
	hyp 8/48 (CER=1.3%): 'for the first four attempts the firers  missed the second shot by several inches '
	hyp 9/48 (CER=1.3%): 'for the first four atempts the




# LOAD WORD ALIGNED MELS

In [40]:
# imitate CLAs
import sys
import argparse
import math
import glob
from tqdm import tqdm

In [41]:
sys.argv = [
    # speechbrain features
    'train.py',
    '--type', 'mel',
    '--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt', 
    '--input_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats',
    '--alignments', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/LJSpeech-1.1/MFA_alignments_lowercase_nopunc', # newer alignments, lowercase no punctuation
    '--output_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned',
    # '--mel-to-graphemes-ratio-lowest-threshold', '5.5',
    # '--mel-to-graphemes-ratio-highest-threshold', '13.5',
    '--clean-output-folder',
]

parser = argparse.ArgumentParser()
parser.add_argument('-t', '--type', type=str, default='hubert',
                    help='type of input speech reps that we are using, i.e. hubert wav2vec2 etc.')
parser.add_argument('--padding_idx_offset', type=int, default=0,
                    help='add 1 to token id of discrete reps in order to allow for padding_idx==0')
parser.add_argument('--utt_id_list', type=str, required=False, default="",
                    help='path to text file that contains list of utterance ids that we extract from')
parser.add_argument('-s', '--input_directory', type=str, required=True,
                    help='path to single non-nested folder containing speech representations (.pt files) or txt file (hubert)')
parser.add_argument('-a', '--alignments', type=str, required=True,
                    help='path to single non-nested folder containing MFA alignments (.TextGrid files)')
parser.add_argument('-o', '--output_directory', type=str, required=True,
                    help='where to write word-level data')
parser.add_argument('--max-utts-to-generate', type=int, default=None,
                    help='How many utts to extract word aligned speech reps for. If None, extract all utts.')
parser.add_argument('--mel-to-graphemes-ratio-lowest-threshold', type=float, default=0.0,
                    help='Lowest mel-to-graphemes ratio to consider. (lower ratio means fewer mel frames per grapheme)')
parser.add_argument('--mel-to-graphemes-ratio-highest-threshold', type=float, default=math.inf,
                    help='Lowest mel-to-graphemes ratio to consider. (higher ratio means more mel frames per grapheme)')
parser.add_argument('--clean-output-folder', action="store_true",
                    help='Clean output folder before writing new data')
args = parser.parse_args()

In [42]:
# grab all mels and words in the output directory
'/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned',

# glob pytorch tensors from nested folders in output directory
mel_paths = glob.glob(f'{args.output_directory}/**/*.pt', recursive=True)
print(f"globbed {len(mel_paths)} mels from {args.output_directory}.")


globbed 62116 mels from /home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned.


In [44]:
MAX_TOKENS_TO_TRANSCRIBE = 5

# load mels into list
mel_paths = []
mels = []
words = []
for mel_path in tqdm(list(mel_paths)[:MAX_TOKENS_TO_TRANSCRIBE]):
    mel_paths.append(mel_path)

    mel = torch.load(mel_path)
    mels.append(mel)

    # also get word from path
    word = mel_path.split('/')[-2]
    words.append(word)

0it [00:00, ?it/s]


# TRANSCRIBE WORD ALIGNED MELS