In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

# dev train loop + model for respeller

# command line args

In [None]:
# imitate CLAs
import sys
sys.argv = [
    'train.py',
    '--fastpitch-chkpt', 'fastpitch/exps/halved_ljspeech_data/FastPitch_checkpoint_1000.pt',
    '--input-type', 'char',
    '--symbol-set', 'english_basic_lowercase_no_arpabet',
    '--use-mas',
    '--cuda',
    '--n-speakers', '1',
    '--use-sepconv',
    '--add-spaces',
    '--eos-symbol', '$',
    '--epochs', '10000', # NB small number for development!
    '--batch-size', '32',
    '--chkpt-save-dir', '/home/s1785140/respeller/exps/test', 
    '--val-log-interval', '10',
    '--learning-rate', '0.1',
    # '--resume', # resume from latest checkpoint
]

# Imports

In [None]:
'''
Train respeller model

We backpropagate loss from pretrained TTS model to a Grapheme-to-Grapheme (G2G) respeller model to help it respell words
into a simpler form

Intermediate respellings are discrete character sequences
We can backpropagate through these using gumbel softmax and the straight through estimator
'''
import argparse
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import json
import glob
import re
from collections import defaultdict

from fastpitch import models as fastpitch_model
from fastpitch.common.text.text_processing import TextProcessor

from modules.model import EncoderRespeller
from modules.gumbel_vector_quantizer import GumbelVectorQuantizer
from modules.sdtw_cuda_loss import SoftDTW

import librosa
import librosa.display
import matplotlib.pyplot as plt

# Functions

## arguments to parse

In [None]:
def parse_args(parser):
    """Register the respeller training command-line options on ``parser``.

    Options are organised into argument groups (training setup, optimization
    setup, architecture, pretrained tts model, log generated audio).

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, with all options added.
    """
    def _latent_temp_triple(value):
        """Parse 'start,end,decay' (optionally parenthesised) into 3 floats.

        The previous ``type=tuple`` turned any command-line value into a tuple
        of single characters, so the option was only usable via its default.
        """
        fields = [f for f in value.strip().strip('()').split(',') if f.strip()]
        if len(fields) != 3:
            raise argparse.ArgumentTypeError(
                f'expected three comma-separated numbers (start,end,decay), got {value!r}')
        try:
            return tuple(float(f) for f in fields)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f'expected three comma-separated numbers (start,end,decay), got {value!r}')

    parser.add_argument('-o', '--chkpt-save-dir', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')

    train = parser.add_argument_group('training setup')
    train.add_argument('--cuda', action='store_true',
                      help='Enable GPU training')
    train.add_argument('--batch-size', type=int, default=16,
                      help='Batchsize (this is divided by number of GPUs if running Data Distributed Parallel Training)')
    train.add_argument('--seed', type=int, default=1337,
                       help='Seed for PyTorch random number generators')
    train.add_argument('--grad-accumulation', type=int, default=1,
                       help='Training steps to accumulate gradients for')
    train.add_argument('--epochs', type=int, default=100, #required=True,
                       help='Number of total epochs to run')
    train.add_argument('--epochs-per-checkpoint', type=int, default=10,
                       help='Number of epochs per checkpoint')
    train.add_argument('--checkpoint-path', type=str, default=None,
                       help='Checkpoint path to resume train')
    train.add_argument('--resume', action='store_true',
                       help='Resume train from the last available checkpoint')
    train.add_argument('--val-log-interval', type=int, default=5,
                       help='How often to generate melspecs/audio for respellings and log to wandb')

    opt = parser.add_argument_group('optimization setup')
    opt.add_argument('--optimizer', type=str, default='lamb', choices=['adam', 'lamb'],
                     help='Optimization algorithm')
    opt.add_argument('-lr', '--learning-rate', default=0.1, type=float,
                     help='Learning rate')
    opt.add_argument('--weight-decay', default=1e-6, type=float,
                     help='Weight decay')
    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                     help='Clip threshold for gradients')
    opt.add_argument('--warmup-steps', type=int, default=1000,
                     help='Number of steps for lr warmup')

    arch = parser.add_argument_group('architecture')
    arch.add_argument('--d-model', type=int, default=512,
                       help='Hidden dimension of tranformer')
    # Non-string defaults are not passed through `type`, so the default tuple
    # is returned unchanged when the option is omitted.
    arch.add_argument('--latent-temp', type=_latent_temp_triple, default=(2, 0.5, 0.999995),
                       help='Temperature annealling parameters for Gumbel-Softmax (start, end, decay)')

    pretrained_tts = parser.add_argument_group('pretrained tts model')
    # pretrained_tts.add_argument('--fastpitch-with-mas', type=bool, default=True,
    #                   help='Whether or not fastpitch was trained with Monotonic Alignment Search (MAS)')
    pretrained_tts.add_argument('--fastpitch-chkpt', type=str, required=True,
                      help='Path to pretrained fastpitch checkpoint')
    pretrained_tts.add_argument('--input-type', type=str, default='char',
                      choices=['char', 'phone', 'pf', 'unit'],
                      help='Input symbols used, either char (text), phone, pf '
                      '(phonological feature vectors) or unit (quantized acoustic '
                      'representation IDs)')
    pretrained_tts.add_argument('--symbol-set', type=str, default='english_basic_lowercase',
                      help='Define symbol set for input sequences. For quantized '
                      'unit inputs, pass the size of the vocabulary.')
    pretrained_tts.add_argument('--n-speakers', type=int, default=1,
                      help='Condition on speaker, value > 1 enables trainable '
                      'speaker embeddings.')
    # pretrained_tts.add_argument('--use-sepconv', type=bool, default=True,
    #                   help='Use depthwise separable convolutions')

    audio = parser.add_argument_group('log generated audio')
    audio.add_argument('--hifigan', type=str, default='/home/s1785140/pretrained_models/hifigan/ljspeech/LJ_V1/generator_v1',
                       help='Path to HiFi-GAN audio checkpoint')
    audio.add_argument('--hifigan-config', type=str, default='/home/s1785140/pretrained_models/hifigan/ljspeech/LJ_V1/config.json',
                       help='Path to HiFi-GAN audio config file')
    audio.add_argument('--sampling-rate', type=int, default=22050,
                       help='Sampling rate for output audio')
    audio.add_argument('--hop-length', type=int, default=256,
                       help='STFT hop length for estimating audio length from mel size')

    # Placeholder groups with no options yet. They are still created so the
    # parser's group list is unchanged, but no longer bound to locals: the old
    # local named `dist` shadowed the module-level `torch.distributed` alias.
    for placeholder_title in ('dataset parameters',
                              'conditioning on additional attributes',
                              'distributed training setup'):
        parser.add_argument_group(placeholder_title)

    return parser

def load_checkpoint(args, model, filepath):
    """Load model weights from a checkpoint file into ``model``.

    The checkpoint is read onto the CPU, every occurrence of ``'module.'`` is
    removed from the parameter names, and the result is loaded into
    ``model.module`` when that attribute exists, otherwise into ``model``.

    Args:
        args: namespace providing ``local_rank``; only rank 0 prints.
        model: module (optionally a wrapper exposing ``.module``) to populate.
        filepath: path to a checkpoint containing a ``'state_dict'`` entry.

    Returns:
        The ``model`` object that was passed in.
    """
    if args.local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    saved_weights = torch.load(filepath, map_location='cpu')['state_dict']
    renamed = {}
    for param_name, tensor in saved_weights.items():
        renamed[param_name.replace('module.', '')] = tensor
    unwrapped = getattr(model, 'module', model)
    unwrapped.load_state_dict(renamed)
    return model

def load_respeller_checkpoint(args, model, filepath, optimizer, epoch, total_iter):
    """Restore respeller weights, optimizer state and training counters.

    ``epoch`` and ``total_iter`` are single-element lists used as output
    parameters: ``epoch[0]`` is set to the saved epoch plus one and
    ``total_iter[0]`` to the saved iteration count.

    Args:
        args: namespace providing ``local_rank``; only rank 0 prints.
        model: module (optionally a wrapper exposing ``.module``) to populate.
        filepath: checkpoint path with ``'epoch'``, ``'iteration'``,
            ``'state_dict'`` and ``'optimizer'`` entries.
        optimizer: optimizer whose state is restored in place.
        epoch: one-element list updated in place.
        total_iter: one-element list updated in place.

    Returns:
        The ``model`` object that was passed in.
    """
    if args.local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    saved = torch.load(filepath, map_location='cpu')

    # Resume from the epoch after the one that was saved.
    epoch[0] = saved['epoch'] + 1
    total_iter[0] = saved['iteration']

    weights = {}
    for param_name, tensor in saved['state_dict'].items():
        weights[param_name.replace('module.', '')] = tensor
    unwrapped = getattr(model, 'module', model)
    unwrapped.load_state_dict(weights)
    optimizer.load_state_dict(saved['optimizer'])
    return model

def last_checkpoint(output):
    """Return the path of the newest loadable respeller checkpoint.

    Checkpoints are files named ``respeller_checkpoint_<N>.pt`` inside
    ``output`` and are ordered by the integer ``<N>``. Files matching the glob
    but lacking a numeric suffix are ignored.

    Args:
        output: directory to search.

    Returns:
        The newest checkpoint path if it can be loaded; otherwise the second
        newest (returned without a load check) if one exists; otherwise
        ``None``.
    """
    # Imported locally because the notebook's import cell does not provide it;
    # previously the corrupted-checkpoint path died with NameError.
    import warnings

    # Raw string with an escaped dot and end anchor, so only the numeric
    # suffix of the file name itself is matched.
    suffix_re = re.compile(r'_(\d+)\.pt$')
    numbered = []
    for fpath in glob.glob(f'{output}/respeller_checkpoint_*.pt'):
        match = suffix_re.search(fpath)
        if match is not None:
            numbered.append((int(match.group(1)), fpath))
    saved = [fpath for _, fpath in sorted(numbered)]

    def corrupted(fpath):
        """Report (with a warning) whether ``fpath`` fails to load."""
        try:
            torch.load(fpath, map_location='cpu')
            return False
        except Exception:
            # Deliberately broad: any load failure means "fall back to the
            # previous checkpoint", but KeyboardInterrupt still propagates.
            warnings.warn(f'Cannot load {fpath}')
            return True

    if len(saved) >= 1 and not corrupted(saved[-1]):
        return saved[-1]
    elif len(saved) >= 2:
        return saved[-2]
    else:
        return None

def init_embedding_weights(source_tensor, target_tensor):
    """Overwrite ``target_tensor`` in place with the values of ``source_tensor``.

    Gradient tracking on the target is switched off for the copy and switched
    on afterwards, so the target always ends with ``requires_grad=True``.
    The source is detached and cloned, so it is neither modified nor linked
    to the target through autograd.
    """
    weights = source_tensor.detach().clone()
    target_tensor.requires_grad_(False)
    target_tensor.copy_(weights)
    target_tensor.requires_grad_(True)

def load_pretrained_fastpitch(args):
    """Build a FastPitch model and load its pretrained weights.

    Args:
        args: parsed options; reads ``cuda`` (device choice) and
            ``fastpitch_chkpt`` (checkpoint path), and is forwarded to the
            FastPitch config builder and to ``load_checkpoint``.

    Returns:
        Tuple ``(fastpitch, n_symbols, grapheme_embedding_dim, model_config)``
        where the two integers are the first and second dimensions of the
        encoder's ``word_emb`` weight.
    """
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    model_config = fastpitch_model.get_model_config('FastPitch', args)
    fastpitch = fastpitch_model.get_model(
        'FastPitch', model_config, device, forward_is_infer=True)
    load_checkpoint(args, fastpitch, args.fastpitch_chkpt)

    # Describe the grapheme embedding table of the pretrained model.
    embedding_table = fastpitch.encoder.word_emb.weight
    return fastpitch, embedding_table.size(0), embedding_table.size(1), model_config

# beginning of main(), parse Command Line Args

In [None]:
# First pass: register the respeller options only, tolerating the FastPitch
# model options that are not known to the parser yet.
parser = argparse.ArgumentParser(description='PyTorch Respeller Training', allow_abbrev=False)
parser = parse_args(parser)
args, _unk_args = parser.parse_known_args()

# Second pass: add the FastPitch model options and re-parse; anything still
# unrecognised is treated as a user error.
parser = fastpitch_model.parse_model_args('FastPitch', parser)
args, unk_args = parser.parse_known_args()
if len(unk_args) > 0:
    raise ValueError(f'Invalid options {unk_args}')

if args.cuda:
    args.num_gpus = torch.cuda.device_count()
    args.distributed_run = args.num_gpus > 1
    # The requested batch size is split evenly across the visible GPUs.
    # NOTE(review): raises ZeroDivisionError if --cuda is passed on a host
    # where torch.cuda.device_count() returns 0.
    args.batch_size = int(args.batch_size / args.num_gpus)
else:
    args.distributed_run = False

torch.manual_seed(args.seed)
np.random.seed(args.seed)

# NOTE(review): neither `mp` nor `train` is defined in the cells visible here
# (the training loop below is written as top-level cells, with `def train` only
# in a comment), so this branch would raise NameError on a multi-GPU machine.
# Presumably `mp` is meant to be torch.multiprocessing - confirm before enabling.
if args.distributed_run:
    mp.spawn(train, nprocs=args.num_gpus, args=(args,))

## WANDB - weights and biases init

In [None]:
import wandb
wandb.login() # needed for wandb integration with jupyter notebook

In [None]:
%env "WANDB_NOTEBOOK_NAME" "respeller-dev-train-ipynb"

In [None]:
from datetime import datetime

# datetime object containing current date and time
now = datetime.now()
 
print("now =", now)

# dd/mm/YY H:M:S
dt_string = now.strftime("%d/%m/%Y_%H:%M:%S")
print("date and time =", dt_string)

In [None]:
run = 0

# store important information into WANDB config for easier tracking of experiments
# add all key values from parser
wandb_config = vars(args)
wandb.init(
    project="respeller-dev-train-ipynb",
    name=f"experiment_{run}_{dt_string}",
    config=wandb_config,
)

# 'train()' - forward pass through model to get loss

## create / load models 

In [None]:
rank = 0
device = 'cuda'

args.local_rank = rank
tts, n_symbols, grapheme_embedding_dim, model_config = load_pretrained_fastpitch(args)
tts.to(device)

respeller = EncoderRespeller(n_symbols=n_symbols, pretrained_tts=tts, d_model=args.d_model)
respeller.to(device)

# quantiser = GumbelVectorQuantizer(
#     in_dim=args.d_model,
#     codebook_size=n_symbols,  # number of codebook entries
#     embedding_dim=grapheme_embedding_dim,
#     temp=args.latent_temp,
# )
# quantiser.to(device)

# init_embedding_weights(tts.encoder.word_emb.weight.unsqueeze(0), quantiser.vars)


# batch_size, len_x, len_y, dims = 8, 15, 12, 5
# x = torch.rand((batch_size, len_x, dims), requires_grad=True)
# y = torch.rand((batch_size, len_y, dims))

# criterion = SoftDTW(use_cuda=True, gamma=0.1, dist_func=F.mse_loss)
criterion = SoftDTW(use_cuda=True, gamma=0.1)
# input should be size [bsz, seqlen, dim]
criterion.to(device)

### load HiFiGAN vocoder 

In [None]:
def load_vocoder(args, device):
    """Load a HiFi-GAN vocoder from checkpoint, ready for inference.

    Args:
        args: parsed options; ``args.hifigan`` is the generator checkpoint
            path and ``args`` is forwarded to the HiFi-GAN config builder.
        device: device (string or ``torch.device``) to build the vocoder on.

    Returns:
        The vocoder with the checkpoint's ``'generator'`` weights loaded,
        weight norm removed, and switched to eval mode.
    """
    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host
    # and avoids materialising tensors on whichever GPU it was saved from.
    checkpoint_data = torch.load(args.hifigan, map_location=device)
    vocoder_config = fastpitch_model.get_model_config('HiFi-GAN', args)
    vocoder = fastpitch_model.get_model('HiFi-GAN', vocoder_config, device)
    vocoder.load_state_dict(checkpoint_data['generator'])
    vocoder.remove_weight_norm()
    vocoder.eval()
    return vocoder

In [None]:
vocoder = load_vocoder(args, device)

## forward pass through model with dummy data

In [None]:
batches = []
symbol_set = 'english_basic_lowercase'
text_cleaners = []
gt_log_mel = torch.load('/home/s1785140/data/ljspeech_fastpitch/mels/LJ001-0001.pt').cuda().unsqueeze(0).transpose(1,2) # intro batch dimension + [bsz, seqlen, dim]
raw_text = 'printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition'

# process text using same processor as fastpitch
tp = TextProcessor(symbol_set, text_cleaners)
text = torch.LongTensor(tp.encode_text(raw_text)).unsqueeze(0).cuda()

batches.append((text, gt_log_mel))

In [None]:
text.size()

In [None]:
text.device

In [None]:
# for batch in batches:
batch = batches[0]
    
###############################################################################################################
# text, ssl_reps, e2e_asr_predictions, gt_log_mel = batch
text, gt_log_mel = batch

###############################################################################################################
# create inputs
# if args.use_acoustic_input:
#     inputs = inputs.concat(ssl_reps)

###############################################################################################################
# forward pass
g_embeddings, g_embedding_indices = respeller(text[:13])

In [None]:
n_symbols

In [None]:
g_embedding_indices.size()

In [None]:
g_embeddings.size()

In [None]:
padding_idx = 0
mask = (g_embedding_indices != padding_idx).unsqueeze(2)
mask.size()

In [None]:
log_mel, dec_lens, _dur_pred, _pitch_pred = tts(g_embeddings, skip_embeddings=True, ids=g_embedding_indices)
# log_mel [bsz, dim, seqlen]
log_mel = log_mel.transpose(1,2)
# log_mel [bsz, seqlen, dim]

print(f'{log_mel.size()=}')
print(f'{gt_log_mel.size()=}')

###############################################################################################################
# calculate val_losses
# respelling_loss = respelling_loss_fn(respelling, e2e_asr_predictions)
acoustic_loss = criterion(log_mel, gt_log_mel)

# average loss over frames 
acoustic_loss = acoustic_loss / dec_lens
# mel_loss = (mel_loss * mel_mask).sum() / mel_mask.sum()

###############################################################################################################
# backward pass
loss = acoustic_loss 

print(f'{loss=}')

# loss.backward()

###############################################################################################################
# log tensorboard metrics

###############################################################################################################
# validation set evaluation

In [None]:
def plot_spectrogram(log_mel, figsize=(15,5), wandb_log=False, image_name=""):
    """Draw a spectrogram with a colorbar on a new pyplot figure.

    Args:
        log_mel: 2-D array handed directly to ``librosa.display.specshow``.
        figsize: size of the new figure.
        wandb_log: when true, also log the current pyplot state to wandb.
        image_name: wandb key and caption used when logging.
    """
    plt.figure(figsize=figsize)
    librosa.display.specshow(log_mel, x_axis='frames', y_axis='linear')
    plt.colorbar()
    if not wandb_log:
        return
    logged_image = wandb.Image(plt, caption=image_name)
    wandb.log({image_name: logged_image})

In [None]:
batch_index = 0
print(f'{log_mel[batch_index].transpose(0,1).size()=}')
plot_spectrogram(log_mel[batch_index].transpose(0,1).detach().cpu().numpy())

In [None]:
# # play audio
# import IPython.display as ipd
# audio = vocoder(log_mel[batch_index].transpose(0,1).detach().unsqueeze(0))
# ipd.Audio(audio, rate=22050)

In [None]:
batch_index = 0
plot_spectrogram(gt_log_mel[batch_index].transpose(0,1).detach().cpu().numpy())

# develop respeller dataset

In [None]:
import torch
import os
from collections import Counter
from tqdm import tqdm

In [None]:
wordaligned_speechreps_dir = '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels' # path to directory that contains folders of word aligned speech reps
wordlist = ['identifies','mash','player','russias','techniques'] # txt file for the words to include speech reps 

In [None]:
token_and_melfilepaths = []
for word in wordlist:
    # find all word aligned mels for the word
    word_dir = os.path.join(wordaligned_speechreps_dir, word)
    mel_files = os.listdir(word_dir)
    for mel_file in mel_files:
        mel_file_path = os.path.join(word_dir, mel_file)
        token_and_melfilepaths.append((word, mel_file_path))
    

In [None]:
token_and_melfilepaths

## process text

In [None]:
from fastpitch.common.text.text_processing import TextProcessor
text_cleaners = []
symbol_set = "english_basic_lowercase_no_arpabet"
tp = TextProcessor(symbol_set, text_cleaners, add_spaces=True, eos_symbol="$")

In [None]:
encoded = torch.IntTensor(tp.encode_text('identifies'))
encoded

In [None]:
decoded = [tp.id_to_symbol[id] for id in encoded.tolist()]
decoded

## process mel

In [None]:
word, fp = token_and_melfilepaths[0]
wordaligned_mel = torch.load(fp)
wordaligned_mel.size() # [seqlen, feats]

## 'class'-ified dataset class

In [None]:
from fastpitch.common.text.text_processing import TextProcessor

class RespellerDataset(torch.utils.data.Dataset):
    """Word tokens paired with their word-aligned mel spectrograms.

    1) collects, for every word in a wordlist, the mel files stored in
       ``<wordaligned_speechreps_dir>/<word>/``
    2) encodes each word into the grapheme indices used by the TTS model's
       text processor

    Attributes set by ``__init__``:
        word_freq: ``Counter`` of examples kept per word.
        token_and_melfilepaths: list of ``(word, mel_file_path)`` pairs.
        tp: the ``TextProcessor`` used for encoding/decoding.
    """
    def __init__(
        self,
        wordaligned_speechreps_dir, # path to directory that contains folders of word aligned speech reps
        wordlist, # path to a .json/.txt wordlist, or a list/set of words
        max_examples_per_wordtype=None,
        text_cleaners=None,
        symbol_set="english_basic_lowercase_no_arpabet",
        add_spaces=True,
        eos_symbol="$",
        **kwargs,
    ):
        """Index all (word, mel file) pairs and build the text processor.

        Args:
            wordaligned_speechreps_dir: directory with one sub-folder per word.
            wordlist: path to a ``.json`` file, path to a text file with one
                word per line, or an iterable of words.
            max_examples_per_wordtype: keep at most this many mel files per
                word; ``None`` (or 0) keeps all of them.
            text_cleaners: cleaners forwarded to ``TextProcessor``;
                ``None`` means no cleaners.
            symbol_set, add_spaces, eos_symbol: forwarded to ``TextProcessor``.
            **kwargs: accepted and ignored.

        Raises:
            FileNotFoundError: if a word has no folder under
                ``wordaligned_speechreps_dir`` (from ``os.listdir``).
        """
        # Avoid sharing one mutable default list between instances.
        if text_cleaners is None:
            text_cleaners = []

        wordlist = sorted(self._load_wordlist(wordlist))

        # create list of all word tokens and their word aligned speech reps
        self.word_freq = Counter()
        self.token_and_melfilepaths = []
        print("Initialising respeller dataset")
        for word in tqdm(wordlist):
            # find all word aligned mels for the word
            word_dir = os.path.join(wordaligned_speechreps_dir, word)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # truncation below (and the dataset order) reproducible.
            mel_files = sorted(os.listdir(word_dir))
            if max_examples_per_wordtype:
                mel_files = mel_files[:max_examples_per_wordtype]
            for mel_file in mel_files:
                mel_file_path = os.path.join(word_dir, mel_file)
                self.token_and_melfilepaths.append((word, mel_file_path))
                self.word_freq[word] += 1

        self.tp = TextProcessor(symbol_set, text_cleaners, add_spaces=add_spaces, eos_symbol=eos_symbol)

    @staticmethod
    def _load_wordlist(wordlist):
        """Return the words named by a file path or held in an iterable."""
        if isinstance(wordlist, str):
            with open(wordlist) as f:
                if wordlist.endswith('.json'):
                    return json.load(f)
                # Skip blank lines: an empty word would resolve to the root
                # directory itself and list word folders as mel files.
                return [line for line in f.read().splitlines() if line.strip()]
        return list(wordlist)

    def get_mel(self, filename):
        """Load the object stored in ``filename`` with ``torch.load``."""
        return torch.load(filename)

    def encode_text(self, text):
        """encode raw text into indices defined by grapheme embedding table of the TTS model"""
        return torch.IntTensor(self.tp.encode_text(text))

    def decode_text(self, encoded):
        """Map a tensor of symbol ids back to a list of symbols."""
        return [self.tp.id_to_symbol[id] for id in encoded.tolist()]

    @staticmethod
    def get_mel_len(melfilepath):
        """Read the integer between ``'seqlen'`` and ``'.pt'`` in a file path."""
        return int(melfilepath.split('seqlen')[1].split('.pt')[0])

    def __getitem__(self, index):
        """Return the word, its encoding, the mel path and the loaded mel."""
        word, mel_filepath = self.token_and_melfilepaths[index]
        encoded_word = self.encode_text(word)
        mel = self.get_mel(mel_filepath)

        return {
            'word': word,
            'encoded_word': encoded_word,
            'mel_filepath': mel_filepath,
            'mel': mel,
        }

    def __len__(self):
        """Number of (word, mel file) examples."""
        return len(self.token_and_melfilepaths)

In [None]:
dataset = RespellerDataset(
    wordaligned_speechreps_dir='/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels', # path to directory that contains folders of word aligned speech reps
    wordlist=['identifies','mash','player','russias','techniques'],
)

In [None]:
batch = []

for itemdict in dataset:
    # unpack dict
    word = itemdict['word'] 
    encoded_word = itemdict['encoded_word'] 
    mel = itemdict['mel'] 
    
    # check
    print(word, encoded_word, mel.size())
    
    batch.append(itemdict)

## collate function

In [None]:
class Collate():
    """Turn a list of dataset items into right-zero-padded batch tensors.

    Items are ordered by decreasing encoded-word length; every returned field
    follows that order.
    """

    def __call__(self, batch):
        """Collate encoded words and their word-aligned mel spectrograms.

        Args:
            batch: list of dicts with keys ``'word'``, ``'encoded_word'``
                (1-D tensor), ``'mel_filepath'`` and ``'mel'`` (2-D tensor
                indexed as ``[frame, mel_bin]``).

        Returns:
            Dict with ``'words'``, ``'text_padded'`` (long),
            ``'text_lengths'`` (long), ``'mel_padded'`` (float32),
            ``'mel_lengths'`` (long) and ``'mel_filepaths'``.
        """
        # Longest encoded word first.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x['encoded_word']) for x in batch]),
            dim=0, descending=True)
        longest_text = input_lengths[0].item()
        ordered_items = [batch[idx] for idx in ids_sorted_decreasing.tolist()]
        n_items = len(batch)

        # Mel padding target: longest spectrogram anywhere in the batch.
        num_mels = batch[0]['mel'].size(1)
        longest_mel = max(item['mel'].size(0) for item in batch)

        text_padded = torch.zeros(n_items, longest_text, dtype=torch.long)
        text_lengths = torch.zeros(n_items, dtype=torch.long)
        mel_padded = torch.zeros(n_items, longest_mel, num_mels, dtype=torch.float32)
        mel_lengths = torch.zeros(n_items, dtype=torch.long)

        words, mel_filepaths = [], []
        for row, item in enumerate(ordered_items):
            words.append(item['word'])
            mel_filepaths.append(item['mel_filepath'])

            encoded = item['encoded_word']
            n_symbols_in_word = encoded.size(0)
            text_padded[row, :n_symbols_in_word] = encoded
            text_lengths[row] = n_symbols_in_word

            spectrogram = item['mel']
            n_frames = spectrogram.size(0)
            mel_padded[row, :n_frames, :] = spectrogram
            mel_lengths[row] = n_frames

        return {
            'words': words,
            'text_padded': text_padded,
            'text_lengths': text_lengths,
            'mel_padded': mel_padded,
            'mel_lengths': mel_lengths,
            'mel_filepaths': mel_filepaths
        }
                # input_lengths, mel_padded, output_lengths,
                # len_x, dur_padded, dur_lens, pitch_padded, speaker)

In [None]:
collate_fn = Collate()
collated = collate_fn(batch)

In [None]:
collated['text_padded'].size()

In [None]:
collated['text_padded']

In [None]:
collated['words']

In [None]:
collated['mel_padded'].size()

## put batch on gpu

In [None]:
def to_gpu(x):
    """Return a contiguous copy/view of ``x``, moved to CUDA when available.

    The transfer is requested as non-blocking. On a machine without CUDA the
    tensor stays on its current device.
    """
    tensor = x.contiguous()
    placed = tensor.cuda(non_blocking=True) if torch.cuda.is_available() else tensor
    return torch.autograd.Variable(placed)

def batch_to_gpu(collated_batch):
    """Split a collated batch into model inputs and targets on the GPU.

    Args:
        collated_batch: dict with keys ``'words'``, ``'text_padded'``,
            ``'text_lengths'``, ``'mel_padded'`` and ``'mel_lengths'``.

    Returns:
        Tuple ``(x, y)``: ``x`` holds ``'words'`` (left as-is),
        ``'text_padded'`` and ``'text_lengths'`` (long); ``y`` holds
        ``'mel_padded'`` (float) and ``'mel_lengths'`` (long). The tensors go
        through ``to_gpu``.
    """
    words, text_padded, text_lengths, mel_padded, mel_lengths = (
        collated_batch[key]
        for key in ('words', 'text_padded', 'text_lengths',
                    'mel_padded', 'mel_lengths')
    )

    # `words` stays on the host: it is a list of strings used only for eval.
    inputs = dict(
        words=words,
        text_padded=to_gpu(text_padded).long(),
        text_lengths=to_gpu(text_lengths).long(),
    )
    targets = dict(
        mel_padded=to_gpu(mel_padded).float(),
        mel_lengths=to_gpu(mel_lengths).long(),
    )
    return (inputs, targets)

In [None]:
# batch_to_gpu(collated)

# full train + dev datasets

In [None]:
train_dataset = RespellerDataset(
    wordaligned_speechreps_dir='/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels', # path to directory that contains folders of word aligned speech reps
    wordlist='/home/s1785140/data/ljspeech_fastpitch/respeller_train_words.json',
    max_examples_per_wordtype=2,
)

In [None]:
len(train_dataset)

In [None]:
sum(train_dataset.word_freq.values())

In [None]:
val_dataset = RespellerDataset(
    wordaligned_speechreps_dir='/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels', # path to directory that contains folders of word aligned speech reps
    wordlist='/home/s1785140/data/ljspeech_fastpitch/respeller_dev_words.json',
)

In [None]:
len(val_dataset)

In [None]:
sum(val_dataset.word_freq.values())

# create torch dataloader

In [None]:
from torch.utils.data import DataLoader

In [None]:
# TODO, implement distributed training?
train_sampler = None
shuffle = True
num_cpus = 1 
train_loader = DataLoader(train_dataset, num_workers=2*num_cpus, shuffle=shuffle,
                          sampler=train_sampler, batch_size=args.batch_size,
                          pin_memory=False, drop_last=True,
                          collate_fn=collate_fn)

In [None]:
# for batch in train_loader:
#     print(batch)

# FULL train() loop

## init dl logger

In [None]:
import fastpitch.common.tb_dllogger as logger

def touch_file(path):
    """Create an empty file at ``path`` if nothing exists there yet.

    Missing parent directories are created. An existing file is left
    untouched.

    Args:
        path: file path; may be a bare file name with no directory part.
    """
    if os.path.exists(path):
        return
    basedir = os.path.dirname(path)
    # dirname is '' for a bare file name, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if basedir:
        os.makedirs(basedir, exist_ok=True)
    with open(path, 'w') as f:
        f.write("")

# initialise logger
tb_subsets = ['train', 'val']
log_fpath = args.log_file or os.path.join(args.chkpt_save_dir, 'nvlog.json')
touch_file(log_fpath)

# Initialisation is best-effort so the cell can be re-run in a live kernel
# (presumably the logger refuses to be initialised twice). Only ordinary
# errors are swallowed - KeyboardInterrupt still propagates - and the real
# error is printed so an unrelated failure is not mistaken for a re-init.
try:
    logger.init(log_fpath, args.chkpt_save_dir, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')
except Exception as err:
    print(f"WARNING DLLLoggerAlreadyInitialized error raised ({type(err).__name__}: {err})")

## imports

In [None]:
from torch_optimizer import Lamb
import time
from fastpitch.common.utils import mask_from_lens
from collections import OrderedDict

## functions

In [None]:
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    """Set the learning rate of every param group for the current iteration.

    With warmup enabled the scale grows linearly as
    ``total_iter / warmup_iters ** 1.5`` up to ``warmup_iters`` and decays as
    ``1 / sqrt(total_iter)`` afterwards. With warmup disabled
    (``warmup_iters`` is ``None`` or ``0``) the base rate is used unchanged.

    Args:
        total_iter: number of training iterations so far.
        opt: optimizer-like object exposing ``param_groups`` (list of dicts).
        learning_rate: base learning rate to scale.
        warmup_iters: warmup length in iterations; ``None``/``0`` disables it.
    """
    # `None` (the default) used to reach `total_iter > None` and raise
    # TypeError; treat it the same as an explicit 0.
    if not warmup_iters:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1. / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)

    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale

In [None]:
def log_stdout(logger, subset, epoch_iters, total_steps, loss, took):
    """Send the total loss and the iteration time to ``logger``.

    Args:
        logger: object exposing ``log(step, tb_total_steps=, subset=, data=)``.
        subset: subset label forwarded to the logger.
        epoch_iters: step identifier passed as the first positional argument.
        total_steps: forwarded as ``tb_total_steps``.
        loss: value logged under ``'Loss/Total'``.
        took: value logged under ``'Time/Iter time'``.
    """
    metrics = OrderedDict()
    metrics['Loss/Total'] = loss
    metrics['Time/Iter time'] = took
    logger.log(epoch_iters, tb_total_steps=total_steps, subset=subset, data=metrics)

In [None]:
def maybe_save_checkpoint(args, model, optimizer, epoch,
                          total_iter, config):
    """Save a respeller checkpoint when one is due for this epoch.

    A checkpoint is written only on rank 0, and only when ``epoch`` is a
    multiple of ``args.epochs_per_checkpoint`` (if that is positive) or when
    ``epoch`` has reached ``args.epochs``. The file is named
    ``respeller_checkpoint_<epoch>.pt`` inside ``args.chkpt_save_dir``.

    Args:
        args: namespace providing ``local_rank``, ``epochs_per_checkpoint``,
            ``epochs`` and ``chkpt_save_dir``.
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: current epoch number.
        total_iter: total iterations so far, stored as ``'iteration'``.
        config: stored unchanged under ``'config'``.
    """
    if args.local_rank != 0:
        return

    intermediate = (args.epochs_per_checkpoint > 0
                    and epoch % args.epochs_per_checkpoint == 0)
    final = epoch >= args.epochs
    if not (intermediate or final):
        return

    # torch.save does not create missing directories; without this a fresh
    # output directory would lose the checkpoint after a full training epoch.
    if args.chkpt_save_dir:
        os.makedirs(args.chkpt_save_dir, exist_ok=True)

    fpath = os.path.join(args.chkpt_save_dir, f"respeller_checkpoint_{epoch}.pt")
    print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
    checkpoint = {'epoch': epoch,
                  'iteration': total_iter,
                  'config': config,
                  'state_dict': model.state_dict(),
                  'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, fpath)

## pre-training loop stuff

In [None]:
# def train(rank, args):


# handle GPU
rank = 0
args.local_rank = rank
device = torch.device('cuda' if args.cuda else 'cpu')

# load models
tts, n_symbols, grapheme_embedding_dim, model_config = load_pretrained_fastpitch(args)
respeller = EncoderRespeller(n_symbols=n_symbols, pretrained_tts=tts, d_model=args.d_model)
# quantiser = GumbelVectorQuantizer(
#     in_dim=args.d_model,
#     codebook_size=n_symbols,  # number of codebook entries
#     embedding_dim=grapheme_embedding_dim,
#     temp=args.latent_temp,
# )
# init_embedding_weights(tts.encoder.word_emb.weight.unsqueeze(0), quantiser.vars)
criterion = SoftDTW(use_cuda=True, gamma=0.1) # input should be size [bsz, seqlen, dim]

tts.to(device)
respeller.to(device)
# quantiser.to(device)
criterion.to(device)

# load optimiser and assign to it the weights to be trained
kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
          weight_decay=args.weight_decay)
optimizer = Lamb(respeller.trainable_parameters(), **kw)

# (optional) load checkpoint for respeller
start_epoch = [1]
start_iter = [0]
assert args.checkpoint_path is None or args.resume is False, (
    "Specify a single checkpoint source")
if args.checkpoint_path is not None:
    ch_fpath = args.checkpoint_path
elif args.resume:
    ch_fpath = last_checkpoint(args.chkpt_save_dir)
else:
    ch_fpath = None
if ch_fpath is not None:
    load_respeller_checkpoint(args, respeller, ch_fpath, optimizer, start_epoch, start_iter)
    
start_epoch = start_epoch[0]
total_iter = start_iter[0]
    
# create datasets, collate func, dataloader
train_dataset = RespellerDataset(
    wordaligned_speechreps_dir='/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels', # path to directory that contains folders of word aligned speech reps
    wordlist='/home/s1785140/data/ljspeech_fastpitch/respeller_train_words.json',
    max_examples_per_wordtype=2,
)
val_dataset = RespellerDataset(
    wordaligned_speechreps_dir='/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels', # path to directory that contains folders of word aligned speech reps
    wordlist='/home/s1785140/data/ljspeech_fastpitch/respeller_dev_words.json',
)
num_cpus = 1 # TODO change to CLA?
train_loader = DataLoader(train_dataset, num_workers=2*num_cpus, shuffle=True,
                          sampler=None, batch_size=args.batch_size,
                          pin_memory=False, drop_last=True,
                          collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, num_workers=2*num_cpus, shuffle=False,
                          sampler=None, batch_size=args.batch_size,
                          pin_memory=False, collate_fn=collate_fn)

# load pretrained hifigan

# log spectrograms and generated audio for first few validation wordtypes

# train loop
respeller.train()
# quantiser.train()
tts.eval()

print('Finished setting up models + dataloaders')

## validate() fn

In [None]:
import PIL
import plotly

def log_spectrogram(log_mel, figsize=(15,5), image_name=""):
    """Render a spectrogram on its own figure and return that figure.

    Args:
        log_mel: 2-D array handed directly to ``librosa.display.specshow``.
        figsize: size of the created figure.
        image_name: text used as the axes title.

    Returns:
        The matplotlib figure holding the spectrogram and its colorbar. The
        figure is not closed here.
    """
    figure, axes = plt.subplots(figsize=figsize)
    mesh = librosa.display.specshow(
        log_mel, x_axis='frames', y_axis='linear', ax=axes)
    axes.set_title(image_name)
    figure.colorbar(mesh, ax=axes)
    return figure

def get_spectrograms_plots(y, fnames, step, n=4, label='Predicted spectrogram', mas=False):
    """Plot spectrograms for up to ``n`` evenly strided utterances in a batch.

    Args:
        y: indexable batch of model outputs or reference targets. Per the
            inline notes below, ``y[0]`` holds the spectrograms and
            ``y[1]``/``y[2]`` the length information - TODO confirm against
            the caller.
        fnames: one file name per batch item; the base name without extension
            becomes the utterance id.
        step: unused here; kept for interface compatibility.
        n: maximum number of utterances to plot.
        label: ``'Predicted spectrogram'`` or ``'Reference spectrogram'``.
        mas: for reference spectrograms, read lengths from ``y[2]`` instead
            of summing ``y[1]`` along axis 1.

    Returns:
        Tuple ``(image_names, spectrogram_figs)`` of equal length.

    Raises:
        ValueError: if ``label`` is not one of the two supported values.
    """
    bs = len(fnames)
    n = min(n, bs)
    if n <= 0:
        # Nothing to plot; also avoids dividing by zero below.
        return [], []
    s = bs // n
    fnames = fnames[::s]
    # print(f"inside get_spectrograms_plots(), {fnames=}")
    if label == 'Predicted spectrogram':
        # y: mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
        mel_specs = y[0][::s].transpose(1, 2).cpu().numpy()
        # squeeze() collapses a single selected item to a 0-d array, which
        # cannot be iterated by zip(); atleast_1d restores the batch axis.
        mel_lens = np.atleast_1d(y[1][::s].squeeze().cpu().numpy()) - 1
    elif label == 'Reference spectrogram':
        # y: mel_padded, dur_padded, dur_lens, pitch_padded
        mel_specs = y[0][::s].cpu().numpy()
        if mas:
            mel_lens = y[2][::s].cpu().numpy()  # output_lengths
        else:
            mel_lens = y[1][::s].cpu().numpy().sum(axis=1) - 1
    else:
        raise ValueError(f'Unsupported spectrogram label: {label!r}')

    image_names = []
    spectrogram_figs = []
    for mel_spec, mel_len, fname in zip(mel_specs, mel_lens, fnames):
        # Drop the frames beyond this utterance's length.
        mel_spec = mel_spec[:, :mel_len]
        utt_id = os.path.splitext(os.path.basename(fname))[0]
        image_name = f'val/{label}/{utt_id}'
        fig = log_spectrogram(mel_spec, image_name=image_name)
        image_names.append(image_name)
        spectrogram_figs.append(fig)
    return image_names, spectrogram_figs

In [None]:
def generate_audio(y, fnames, step, vocoder=None, sampling_rate=22050, hop_length=256,
                   n=4, label='Predicted audio', mas=False):
    """Generate audio from spectrograms for roughly n utterances in a batch.

    Utterances are picked with a stride of ``len(fnames) // min(n, len(fnames))``,
    so slightly more than ``n`` items can be returned when ``n`` does not divide
    the batch size.

    Args:
        y: tuple of tensors whose layout depends on ``label``:
            * 'Predicted audio': ``y[0]`` is the mel batch (transposed with
              ``transpose(1, 2)`` before vocoding); ``y[1]`` holds one length
              per utterance (1 is subtracted from each).
            * 'Copy synthesis': ``y[0]`` is vocoded as is; lengths come from
              ``y[2]`` when ``mas`` is true, else ``y[1].sum(axis=1) - 1``.
            * 'Reference audio': audio is loaded from disk (the mel path is
              rewritten from ``mels/<id>.pt`` to ``wavs/<id>.wav``); lengths are
              read as for 'Copy synthesis'.
        fnames: one mel file path per utterance in the batch.
        step: unused; kept for interface compatibility with existing callers.
        vocoder: callable mapping a mel batch to a waveform batch. Required for
            'Predicted audio' and 'Copy synthesis'. Its output is assumed to hold
            one waveform per batch row (e.g. [bsz, 1, T] or [bsz, T]).
        sampling_rate: sample rate used when loading reference audio.
        hop_length: number of audio samples per mel frame, used to trim padding.
        n: approximate number of utterances to convert.
        label: which mode to run (see above).
        mas: selects where reference lengths are read from (see above).

    Returns:
        List of 1D numpy arrays, each trimmed to ``mel_len * hop_length`` samples
        and peak-normalised (silent or empty clips are returned un-normalised).

    Raises:
        ValueError: if ``label`` is unknown, or a vocoder is needed but missing.
    """
    bs = len(fnames)
    if bs == 0 or n <= 0:
        # nothing to generate; also avoids a ZeroDivisionError in the stride
        return []
    s = bs // min(n, bs)
    fnames = fnames[::s]

    def vocode(mels):
        """Run the vocoder and return a [n_selected, n_samples] numpy array."""
        if vocoder is None:
            raise ValueError(f"label={label!r} requires a vocoder")
        wavs = vocoder(mels).cpu().numpy()
        # Flatten everything except the batch dim. Unlike squeeze(), this keeps
        # the batch dim when only a single utterance is selected.
        return wavs.reshape(wavs.shape[0], -1)

    def reference_lengths():
        """Lengths for ground-truth style inputs (see docstring)."""
        if mas:
            return y[2][::s].cpu().numpy()  # output_lengths
        return y[1][::s].cpu().numpy().sum(axis=1) - 1

    with torch.no_grad():
        if label == 'Predicted audio':
            # y: mel_out, dec_lens, ...
            audios = vocode(y[0][::s].transpose(1, 2))
            # reshape(-1) keeps a 1-element array iterable (squeeze() would not)
            mel_lens = y[1][::s].cpu().numpy().reshape(-1) - 1
        elif label == 'Copy synthesis':
            # y: mel_padded, dur_padded, dur_lens, pitch_padded
            audios = vocode(y[0][::s])
            mel_lens = reference_lengths()
        elif label == 'Reference audio':
            audios = []
            for fname in fnames:
                wav = re.sub(r'mels/(.+)\.pt', r'wavs/\1.wav', fname)
                audio, _ = librosa.load(wav, sr=sampling_rate)
                audios.append(audio)
            mel_lens = reference_lengths()
        else:
            raise ValueError(
                f"Unsupported label {label!r}; expected 'Predicted audio', "
                f"'Copy synthesis' or 'Reference audio'"
            )

    audios_to_return = []
    for audio, mel_len in zip(audios, mel_lens):
        # trim samples that correspond to padded mel frames
        audio = audio[:mel_len * hop_length]
        # peak-normalise; skip silent/empty clips to avoid 0/0 -> NaN
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak
        audios_to_return.append(audio)

    return audios_to_return


In [None]:
def decode_indices(indices):
    """Convert a batch of symbol ids into strings.

    Args:
        indices: [bsz, seqlen] tensor of symbol ids.

    Returns:
        List of ``bsz`` strings. Every id in a row is looked up in the
        module-level text processor's ``tp.id_to_symbol`` mapping and the
        symbols are concatenated; no ids are filtered out.
    """
    return [
        ''.join(tp.id_to_symbol[symbol_id] for symbol_id in indices[row].tolist())
        for row in range(indices.size(0))
    ]

def select(x, bsz, n):
    """Pick the batch items that will be visualised / converted to audio.

    Every ``bsz // min(n, bsz)``-th element of ``x`` is kept, starting at index 0,
    so the result can hold slightly more than ``n`` items when ``n`` does not
    divide ``bsz``.
    """
    stride = bsz // min(n, bsz)
    return x[::stride]

def log_wandb_table(
    names, 
    vocoded_gt_audios,
    orig_words,
    respellings,
    orig_pred_spec_figs,
    orig_pred_audios,
    pred_spec_figs,
    pred_audios,
    sampling_rate=22050,
    train=False,
):  
    """Log one wandb table comparing original spellings with respellings.

    All list arguments are parallel (one entry per utterance); rows are built by
    zipping them, so the table has as many rows as the shortest list.

    Args:
        names: utterance names; also used as caption for every image/audio cell.
        vocoded_gt_audios: waveforms vocoded from the ground-truth mels.
        orig_words: original spellings (strings).
        respellings: respelled words (strings).
        orig_pred_spec_figs: matplotlib figures for the original spellings.
        orig_pred_audios: waveforms synthesised from the original spellings.
        pred_spec_figs: matplotlib figures for the respellings.
        pred_audios: waveforms synthesised from the respellings.
        sampling_rate: sample rate attached to every audio cell.
        train: log under "train_table" when true, otherwise "val_table".

    Side effects:
        Calls ``wandb.log`` once and closes every figure in both figure lists.
    """
    column_names = [
        "names", 
        "orig spelling", 
        "orig spelling spec", 
        "orig spelling audio",
        "vocoded gt audio",
        "respelling", 
        "respelling spec", 
        "respelling audio",
    ]
    table = wandb.Table(columns=column_names)

    def as_audio(waveform, caption):
        """Wrap a waveform as a wandb audio cell at the configured sample rate."""
        return wandb.Audio(waveform, caption=caption, sample_rate=sampling_rate)

    per_utterance = zip(
        names,
        orig_words,
        orig_pred_spec_figs,
        orig_pred_audios,
        vocoded_gt_audios,
        respellings,
        pred_spec_figs,
        pred_audios,
    )
    for (utt_name, word, word_fig, word_audio,
         gt_audio, respelt, respelt_fig, respelt_audio) in per_utterance:
        # cell order must match `column_names`
        table.add_data(
            utt_name,
            word,
            wandb.Image(word_fig, caption=utt_name),
            as_audio(word_audio, utt_name),
            as_audio(gt_audio, utt_name),
            respelt,
            wandb.Image(respelt_fig, caption=utt_name),
            as_audio(respelt_audio, utt_name),
        )

    table_key = "train_table" if train else "val_table"
    wandb.log({table_key: table})

    # figures are no longer needed once wrapped for wandb; close them to save memory
    for open_fig in orig_pred_spec_figs + pred_spec_figs:
        plt.close(open_fig)
          
def validate(
    respeller_model, 
    tts_model, 
    vocoder,
    criterion,
    valset, 
    epoch, 
    batch_size, 
    collate_fn, 
    sampling_rate,
    hop_length,
    audio_interval=5,
    n=None, # how many tokens to plot and generate audio for, if None then do the whole first batch
    only_log_table=False,
    train=False,
):
    """Handles all the validation scoring and logging.

    For the first batch (on epochs where ``epoch % audio_interval == 0``, or
    always when ``epoch`` is None) a wandb table is logged containing, per word:
    - vocoded ground-truth audio
    - predicted mel spec + vocoded audio from the TTS model for the original spelling
    - the respelling, with its predicted mel spec + vocoded audio

    Unless ``only_log_table`` is set, the mean loss over all batches is logged
    as ``val/epoch_loss``.

    Args:
        respeller_model: respeller; switched to eval mode for the duration of
            the call and restored to its previous mode afterwards.
        tts_model: TTS model used both inside ``forward_pass`` and to synthesise
            the original spellings.
        vocoder: callable mapping mel batches to waveforms.
        criterion: loss called as ``criterion(pred_mel, gt_mel)``; its result is
            divided by the decoder lengths and averaged over the batch.
        valset: dataset to iterate over.
        epoch: current epoch, or None to force table logging.
        batch_size, collate_fn: forwarded to the ``DataLoader``.
        sampling_rate, hop_length: forwarded to audio generation / logging.
        audio_interval: log the table every this many epochs.
        n: how many items of the first batch to log; None means the whole batch.
        only_log_table: only log the table (first batch), skip the epoch loss.
        train: log the table as "train_table" instead of "val_table".
    """
    # Loop-invariant: does this call need to produce the table at all?
    log_table = (epoch % audio_interval == 0) if epoch is not None else True
    if only_log_table and not log_table:
        # Nothing would be logged; skip the (expensive) pass over the dataset.
        return

    was_training = respeller_model.training
    respeller_model.eval()

    try:
        with torch.no_grad():
            val_loader = DataLoader(valset, num_workers=4, shuffle=False,
                                    sampler=None,
                                    batch_size=batch_size, pin_memory=False,
                                    collate_fn=collate_fn)
            val_losses = 0.0
            num_batches = 0

            for i, batch in enumerate(val_loader):
                num_batches += 1

                # get loss over batch
                x, y = batch_to_gpu(batch)
                pred_mel, dec_lens, g_embedding_indices = forward_pass(respeller_model, tts_model, x)
                val_losses += (criterion(pred_mel, y["mel_padded"]) / dec_lens).mean().item()

                # log spectrograms and generated audio for first few utterances
                if i == 0 and log_table:
                    fnames = batch['mel_filepaths']
                    bsz = len(fnames)
                    n_log = bsz if n is None else n

                    # get original word and respellings for logging
                    original_words = decode_indices(x['text_padded'])
                    respellings = decode_indices(g_embedding_indices)

                    # The helpers below ignore their `step` argument; pass the
                    # epoch rather than reading the notebook-global `total_iter`.
                    # vocode original recorded speech
                    vocoded_gt = generate_audio(
                        (y['mel_padded'], y['mel_lengths']), fnames, epoch, vocoder,
                        sampling_rate, hop_length, n=n_log, label='Predicted audio', mas=True)

                    # get melspec + generated audio for original spellings
                    # (use the `tts_model` argument, not the notebook-global `tts`)
                    orig_pred_mel, orig_dec_lens, _dur_pred, _pitch_pred = tts_model(
                        inputs=x['text_padded'],
                        skip_embeddings=False,
                    )
                    orig_pred_mel = orig_pred_mel.transpose(1, 2)
                    _orig_token_names, orig_pred_spec_figs = get_spectrograms_plots(
                        (orig_pred_mel, orig_dec_lens), fnames, epoch,
                        n=n_log, label='Predicted spectrogram', mas=True)
                    orig_pred_audios = generate_audio(
                        (orig_pred_mel, orig_dec_lens), fnames, epoch, vocoder,
                        sampling_rate, hop_length, n=n_log, label='Predicted audio', mas=True)

                    # get melspec + generated audio for respellings
                    token_names, pred_spec_figs = get_spectrograms_plots(
                        (pred_mel, dec_lens), fnames, epoch,
                        n=n_log, label='Predicted spectrogram', mas=True)
                    pred_audios = generate_audio(
                        (pred_mel, dec_lens), fnames, epoch, vocoder,
                        sampling_rate, hop_length, n=n_log, label='Predicted audio', mas=True)

                    # log everything to wandb table (keep only the utterance id as name)
                    token_names = [name.split('/')[-1] for name in token_names]
                    log_wandb_table(
                        names=token_names,
                        vocoded_gt_audios=vocoded_gt,
                        orig_words=select(original_words, bsz, n=n_log),
                        orig_pred_spec_figs=orig_pred_spec_figs,
                        orig_pred_audios=orig_pred_audios,
                        respellings=select(respellings, bsz, n=n_log),
                        pred_spec_figs=pred_spec_figs,
                        pred_audios=pred_audios,
                        sampling_rate=sampling_rate,
                        train=train,
                    )

                if log_table and only_log_table:
                    break # leave for loop after first iteration

            # guard against an empty loader (would otherwise divide by zero)
            if not only_log_table and num_batches > 0:
                wandb.log({'val/epoch_loss': val_losses / num_batches})
    finally:
        # restore training mode even if validation raised
        if was_training:
            respeller_model.train()

## train loop

In [None]:
def forward_pass(respeller, tts, x):
    """Run one batch through the respeller and then the TTS model.

    Args:
        respeller: callable taking the padded text ids and returning
            ``(embeddings, embedding_indices)``.
        tts: callable invoked with ``inputs=<embeddings>``,
            ``ids=<embedding_indices>`` and ``skip_embeddings=True``; it returns
            a 4-tuple whose first two items are the mel output and the decoder
            lengths (the predicted durations and pitch are discarded here).
        x: batch inputs, a dict of the form
            {
                'words': words,
                'text_padded': text_padded,
                'text_lengths': text_lengths,
            }
            Only ``x['text_padded']`` is read.

    Returns:
        ``(log_mel, dec_lens, g_embedding_indices)`` where ``log_mel`` has been
        transposed from [bsz, dim, seqlen] to [bsz, seqlen, dim].
    """
    respelling_embeddings, respelling_ids = respeller(x['text_padded'])

    # The TTS model receives the respeller's embeddings directly, so its own
    # embedding lookup is skipped.
    tts_outputs = tts(
        inputs=respelling_embeddings,
        ids=respelling_ids,
        skip_embeddings=True,
    )
    mel_channels_first, decoder_lengths = tts_outputs[0], tts_outputs[1]
    if len(tts_outputs) != 4:
        # mirror the 4-way tuple unpacking contract of the TTS output
        raise ValueError(
            f"expected 4 values from tts, got {len(tts_outputs)}"
        )

    # [bsz, dim, seqlen] -> [bsz, seqlen, dim]
    mel_frames_first = mel_channels_first.transpose(1, 2)

    # NOTE: no decoder mask is built or returned here; the acoustic loss is
    # computed on the padded tensors by the caller.
    return mel_frames_first, decoder_lengths, respelling_ids

In [None]:
# Main training loop: one pass over `train_loader` per epoch. Only the respeller
# is optimised here (gradient clipping is applied to respeller.trainable_parameters());
# the loss is computed between the TTS model's mel output for the respelling and
# the ground-truth mel. Relies on names defined in earlier cells (start_epoch,
# total_iter, train_loader, optimizer, criterion, respeller, tts, vocoder, ...).
for epoch in range(start_epoch, args.epochs + 1):
    # logging metrics
    epoch_start_time = time.perf_counter()
    iter_loss = 0
    epoch_loss = 0.0  # running sum of per-iteration losses, averaged at epoch end
    epoch_iter = 0    # number of batches processed in this epoch
    num_iters = len(train_loader)
    # epoch_mel_loss = 0.0
    # epoch_num_frames = 0
    # epoch_frames_per_sec = 0.0
    # iter_num_frames = 0
    # iter_meta = {}

    # iterate over all batches in epoch
    for batch in train_loader:        
        # if epoch_iter == 100:
        #     break # NB quit training loop, FOR DEVELOPMENT!!!
        
        if epoch_iter == num_iters: # useful for gradient accumulation
            break
                    
        total_iter += 1  # global step counter across epochs
        epoch_iter += 1
        iter_start_time = time.perf_counter()

        # presumably applies a warmup-based learning-rate schedule keyed on the
        # global step -- defined outside this chunk, confirm there
        adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                             args.warmup_steps)

        optimizer.zero_grad()

        x, y = batch_to_gpu(batch) # x: inputs, y: targets
        gt_mel = y["mel_padded"]
        
        # # y: targets
        # y = {
        #     'mel_padded': mel_padded, 
        #     'mel_lengths': mel_lengths,
        # }
        
        # forward pass through models (respeller -> tts); the respelling indices
        # are only needed for logging, so they are ignored during training
        pred_mel, dec_lens, _g_embedding_indices = forward_pass(respeller, tts, x)
        
        # TODO: DO WE NEED MASK IF WE USE SOFTDTW LOSS? 
        # I THINK IT AUTOMATICALLY WILL ALIGN PADDED FRAMES WITH EACH OTHER???
        
        # calculate loss (one value per utterance -- assumed from the per-item
        # division by dec_lens below)
        loss = criterion(pred_mel, gt_mel)
        # print('raw loss from softdtw', loss)
        
        # normalise each utterance's loss by its predicted (decoder) length
        loss = loss / dec_lens
        # print('loss avg according to dec seqlens', loss)
        
        # average over the batch to get a scalar for backprop
        loss = loss.mean()
        # print('loss avged across batch', loss)
        
        # backpropagation of loss
        loss.backward()
        
        # clip gradients and run optimizer
        torch.nn.utils.clip_grad_norm_(respeller.trainable_parameters(), args.grad_clip_thresh)
        optimizer.step()
        # logger.log_grads_tb(total_iter, model)
        
        # log metrics to wandb (terminal logging is disabled below)
        iter_loss = loss.item()
        iter_time = time.perf_counter() - iter_start_time
        epoch_loss += iter_loss
        
        # NB commented out to avoid crashing jupyter notebook
        # log_stdout(
        #     logger,
        #     'train',
        #     (epoch, epoch_iter, num_iters),
        #     total_iter,
        #     iter_loss,
        #     iter_time,
        # )
        
        wandb.log({
            "train/iter_loss": iter_loss,
            "train/iter_time": iter_time,
        })

    ### Finished Epoch! (all batches processed)
             
    epoch_time = time.perf_counter() - epoch_start_time
    
    # NOTE(review): divides by epoch_iter, which is 0 if train_loader yields no batches
    wandb.log({
        "train/epoch_num": epoch,
        "train/epoch_time": epoch_time,
        "train/epoch_loss": epoch_loss / epoch_iter,
    })
    
    # log audio and respellings for training set words
    # (only_log_table=True: no epoch loss is logged for this call; the table is
    # produced from the first batch on epochs divisible by args.val_log_interval)
    validate(
        respeller_model=respeller, 
        tts_model=tts, 
        vocoder=vocoder,
        criterion=criterion,
        valset=train_dataset, 
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        epoch=epoch,
        sampling_rate=args.sampling_rate,
        hop_length=args.hop_length,
        audio_interval=args.val_log_interval,
        only_log_table=True,
        train=True,
    )
        
    # log audio and respellings for val set words, plus the mean validation
    # loss over the whole val set ('val/epoch_loss')
    validate(
        respeller_model=respeller, 
        tts_model=tts, 
        vocoder=vocoder,
        criterion=criterion,
        valset=val_dataset, 
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        epoch=epoch,
        sampling_rate=args.sampling_rate,
        hop_length=args.hop_length,
        audio_interval=args.val_log_interval,
    )

    # checkpointing policy lives in maybe_save_checkpoint (defined outside this chunk)
    maybe_save_checkpoint(args, respeller, optimizer, 
                          epoch, total_iter, model_config)

    logger.flush()
        
print("\n *** Finished training! ***")

# wandb.finish() # useful in jupyter notebooks