### Only run to clear dataset from colab storage

In [None]:
!rm -rf dataset/

## Package Installation

In [1]:
!pip install torchinfo torch

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collectin

# Variational Auto-Encoder
Based on https://arxiv.org/pdf/2107.03298

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchaudio import datasets, transforms, save
from torchinfo import summary
import os
import collections
from google.colab import drive
import datetime
import wave

## Env Variables
Used for variables that are consistent throughout the whole program for ease of testing and readability.

In [3]:
nhead = 4
dense_net = 1024
epochs = 300
batch_size = 32
mel = transforms.MelSpectrogram(sample_rate=22050, n_fft=1024, win_length=512, hop_length=256, n_mels=80)
inverseMel = transforms.InverseMelScale(sample_rate=22050, n_stft=(1024 // 2) + 1, n_mels=80)
melToWav = transforms.GriffinLim(n_fft=1024, win_length=512, hop_length=256)
loss_weight = [1, 1e-4, 1]
validation = False
retraining = False

## Drive Mounting

Mounting the drive in order to save parameters in the case of a timeout.

In [4]:
drive.mount('/content/drive')

Mounted at /content/drive


## Dataset Loading
Using lj_speech dataset, a public domain speech and text dataset. This dataset will be saved to memory for each session in colab but will need to be redownloaded with each seperate session. If one wants to clear the preexisting dataset if this block times out, please use the utility at the top of this program. IMPORTANT NOTE: please wait for this block to finish running, if it is timed out only a subset of the data will be loaded into memory.

Source: https://keithito.com/LJ-Speech-Dataset/lj_speech

In [5]:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {dev}')
path = os.path.join(os.getcwd(), 'dataset/')
os.makedirs(path, exist_ok=True)
rawData = datasets.LJSPEECH(path, download=True)
print(f'Sample Data: {rawData[0]}')
print(f'Length of Dataset: {len(rawData)}')

Device: cuda


100%|██████████| 2.56G/2.56G [00:12<00:00, 221MB/s]


Sample Data: (tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
          2.1362e-04,  6.1035e-05]]), 22050, 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition', 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition')
Length of Dataset: 13100


## Text Tokenization
Create a dictionary based on the training text.

In [6]:
'''
Create a pair of dictionaries that map text to integers and vice versa.
This is done to facilitate
'''
class Vocabulary():
  def __init__(self):
    self.vocab = []
  '''
  Create a dictionary of characters found in the text based, assign each
  a unique numerical value.

  Return two dictionaries of char->int, int->char
  '''
  def create_dictionary(self, freq):
    self.vocab = list(freq.keys())
    char_to_index = {word: index for index, word in enumerate(self.vocab)}
    index_to_char = {index: word for word, index in char_to_index.items()}

    return (char_to_index, index_to_char)

f = open('dataset/LJSpeech-1.1/metadata.csv')
count = collections.Counter()
for line in f:
  text = line.split('|')[2].lower()
  count.update(list(text))

vocab = Vocabulary()
(char_to_index, index_to_char) = vocab.create_dictionary(count)
print(f'Lengh of vocab: {len(char_to_index)}')
print(list(char_to_index.items()))

Lengh of vocab: 50
[('p', 0), ('r', 1), ('i', 2), ('n', 3), ('t', 4), ('g', 5), (',', 6), (' ', 7), ('h', 8), ('e', 9), ('o', 10), ('l', 11), ('y', 12), ('s', 13), ('w', 14), ('c', 15), ('a', 16), ('d', 17), ('f', 18), ('m', 19), ('x', 20), ('b', 21), ('\n', 22), ('v', 23), ('.', 24), ('u', 25), ('k', 26), ('j', 27), ('"', 28), ('-', 29), (';', 30), ('(', 31), ('z', 32), (')', 33), (':', 34), ("'", 35), ('q', 36), ('!', 37), ('?', 38), ('â', 39), ('é', 40), ('à', 41), ('ê', 42), ('ü', 43), ('è', 44), ('“', 45), ('”', 46), ('’', 47), ('[', 48), (']', 49)]


## Dataset Loader + Collate Func
Needed to pad sequences to the max length, otherwise cannot be used in a dataset loader and would cause batches not to work. This function will take in both the .wav and text, pad the wav before turning it into a spectrogram. It will then tokenize the text using the premade dictionary and then pad it to the max length.

In [7]:
def paddingCollate(batch):
  # ignore other form of text and ID for batching
  waveform, _, text, _ = zip(*batch)
  max_len = max(wave.shape[1] for wave in waveform)
  padded_wave = [torch.nn.functional.pad(wave, (0, max_len - wave.shape[1])) for wave in waveform]
  padded_wave = torch.stack(padded_wave, dim=0)
  # transform to mel spectrogram for ease of use
  spec = [mel(wave).permute(0, 2, 1) for wave in padded_wave]
  padded_spec = torch.nn.utils.rnn.pad_sequence(spec, batch_first=True).squeeze(1)

  tokenized_text = [[char_to_index.get(index, char_to_index[' ']) for index in list(sentence.lower())] for sentence in text]
  tokenized_text = [torch.tensor(sent, dtype=torch.long) for sent in tokenized_text]
  padded_text = torch.nn.utils.rnn.pad_sequence(tokenized_text, batch_first=True)

  return padded_text, padded_spec

# reduce dataset size to 5000
reducedDataSize = list(range(0, 5000))
reducedRawData = torch.utils.data.Subset(rawData, reducedDataSize)

# define train and test split
trainSize = int(0.9 * len(reducedDataSize))
testSize = len(reducedDataSize) - trainSize
train, test = torch.utils.data.random_split(reducedRawData, [trainSize, testSize])

trainDatasetLoader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=paddingCollate)
testDatasetLoader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True, collate_fn=paddingCollate)
print(f'Length of Train: {len(trainDatasetLoader) * batch_size}')
print(f'Length of Test: {len(testDatasetLoader) * batch_size}')

Length of Train: 4512
Length of Test: 512


## Text Encoder
Convolution layers with dropout, batch normalization and ReLU activation. Follow this with positional encoding, and then self attention blocks.

In [8]:
'''
Sinosoidal Positional Encoding
Input: batch x seq len x embed len
Output: batch x seq len x embed len

Applies the sinosoidal positional encoding to the input's embedding. Should be
functional accross all batches.
'''
def sinosoidal_position_encoding(token_size, embedding_dim, batch_size):
  pos = torch.arange(0, token_size, device=dev).unsqueeze(1)
  emb = torch.zeros(token_size, batch_size, embedding_dim, device=dev)

  dividing_term = torch.pow(10000, 2*torch.arange(0, embedding_dim //2, device=dev)/embedding_dim)

  # based on definition in Attention is All You Need
  emb[:, 0, 0::2] = torch.sin(pos/dividing_term)
  emb[:, 0, 1::2] = torch.cos(pos/dividing_term)

  emb = emb.permute(1, 2, 0)

  return emb

'''
Convolutional Layer Stack for Text Encoder
Input: batch x seq len x embed len
Output: batch x seq len x embed len
Convolute each 1D sequence, batch norm, ReLU, and dropout.
'''
class ConvStack(torch.nn.Module):
  def __init__(self, D, K):
    super(ConvStack, self).__init__()
    self.conv = torch.nn.Conv1d(D, D, K, bias=False)
    self.norm = torch.nn.BatchNorm1d(D)
    self.relu = torch.nn.ReLU()
    self.dropout = torch.nn.Dropout1d()
  def forward(self, X):
    convolution = self.conv(X)
    normalization = self.norm(convolution)
    relu = self.relu(normalization)
    output = self.dropout(relu)
    return output

'''
Self Attention Layer
Input: batch x seq len x embed len
Output: batch x seq len x embed len
Can take in either the direct query value or create the values on the fly.
'''
class SelfAttentionLayer(torch.nn.Module):
  def __init__(self, data_dim):
    super(SelfAttentionLayer, self).__init__()
    self.query = nn.Linear(data_dim, data_dim)
    self.value = nn.Linear(data_dim, data_dim)
    self.key = nn.Linear(data_dim, data_dim)

  def forward(self, X):
    Q = self.query(X)
    V = self.value(X)
    K = self.key(X)
    attn = nn.functional.scaled_dot_product_attention(Q, V, K)
    return attn


'''
Text Encoder Block
Input: batch x seq len x em
Output: batch x seq len x embed len

Needs to utilize multiple permutations in order to work within Pytorch's built in encoder,
sinusoidal positional encoding, and self attention.
'''
class TextEncoder(torch.nn.Module):
  def __init__(self, embedding_size, conv_size, K):
    super(TextEncoder, self).__init__()
    self.embedding = torch.nn.Embedding(embedding_size, conv_size)
    self.stack1 = ConvStack(conv_size, K)
    self.stack2 = ConvStack(conv_size, K)
    self.stack3 = ConvStack(conv_size, K)
    self.stack4 = ConvStack(conv_size, K)
    self.stack5 = ConvStack(conv_size, K)
    self.attention1 = SelfAttentionLayer(conv_size)
    self.attention2 = SelfAttentionLayer(conv_size)
    self.attention3 = SelfAttentionLayer(conv_size)
    self.attention4 = SelfAttentionLayer(conv_size)

  def forward(self, X):
    embed = self.embedding(X)
    # permutate to shift order from [batch, seq len, embed len] to [batch, embed len, seq len]
    embed = embed.permute(0, 2, 1)
    conv1 = self.stack1(embed)
    conv2 = self.stack2(conv1)
    conv3 = self.stack3(conv2)
    conv4 = self.stack4(conv3)
    conv5 = self.stack5(conv4)

    # add positional encoding to embedding, no grad
    with torch.no_grad():
      pos = sinosoidal_position_encoding(conv5.size(2), conv5.size(1), conv5.size(0))
    conv_pos = conv5 + pos
    conv_pos = conv_pos.permute(0, 2, 1)
    attn1 = self.attention1(conv_pos)
    attn2 = self.attention2(attn1)
    attn3 = self.attention3(attn2)
    attn4 = self.attention4(attn3)

    return attn4

if validation:
  model = TextEncoder(2000, 256, 5).to(dev)
  input_text = rawData[0][3].lower()
  input_tensor = torch.tensor([char_to_index[item] for item in list(input_text)], device=dev).unsqueeze(0)
  print(summary(model, input_data=input_tensor))

## Posterior Encoder
Fully connected layers w/ dropout and ReLU activation followed by sinusoidal positional encoding.

Attention layers are then stacked upon one another with Q = encoded spectrogram, K, V = encoded text. This should take the form of a self attention layer, a cross-attention layer, and a feed-forward NN with a hidden layer of 1024 and output size of 256.


In [9]:
'''
Self Attention Layer for Posterior Encoder
Input: batch x seq len x embed len
Output: batch x seq len x embed len
Designed to work with self attention between
'''
class PosteriorSelfAttentionLayer(torch.nn.Module):
  def __init__(self, data_dim):
    super(PosteriorSelfAttentionLayer, self).__init__()
    self.query = nn.Linear(data_dim, data_dim)
    self.value = nn.Linear(data_dim, data_dim)
    self.key = nn.Linear(data_dim, data_dim)
    self.d = data_dim
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, V, K):
    # based on definition in Attention is All You Need
    QK = Q @ K
    Y = self.softmax(QK).div(self.d ** 0.5)
    output = (Y @ V.mT)
    return output # permutate to return a form that can be added

'''
Attention Block for Posterior Encoder
Input: batch x seq len x embed len
Output: batch x seq len x embed len
Utilizes a self attention layer, and a cross attention layer with a feed-forward NN.
'''
class AttentionBlock(torch.nn.Module):
  def __init__(self, data_dim, head_num, hidden_dim):
    super(AttentionBlock, self).__init__()
    self.d = data_dim
    self.attn_self = PosteriorSelfAttentionLayer(data_dim)
    self.Q = nn.Linear(data_dim, data_dim)
    self.V = nn.Linear(data_dim, data_dim)
    self.K = nn.Linear(data_dim, data_dim)
    self.attn_cross = nn.MultiheadAttention(data_dim, head_num, batch_first=True)
    self.linear1 = nn.Linear(data_dim, hidden_dim)
    self.relu1 = nn.ReLU()
    self.linear2 = nn.Linear(hidden_dim, 128)

  def forward(self, Q, V, K):
    attn_s = self.attn_self(Q, V, K)
    #attn_c = self.attn_cross(attn_s, V.permute(0, 2, 1), K.permute(0, 2, 1))
    lin1 = self.linear1(attn_s)
    relu = self.relu1(lin1)
    return self.linear2(relu)


'''
Posterior Encoder
Input: batch x seq len x embed len, Spectrogram Y
Output: batch x posterior probability

Returns the posterior probability of the spectrogram given the encoded and transformed text.
'''
class PosteriorEncoder(torch.nn.Module):
  def __init__(self, wav_dim, emb_dim):
    super(PosteriorEncoder, self).__init__()
    self.linear1 = nn.Linear(wav_dim, 256)
    self.relu1 = torch.nn.ReLU()
    self.dropout1 = torch.nn.Dropout()
    self.linear2 = nn.Linear(256, emb_dim)
    self.relu2 = torch.nn.ReLU()
    self.dropout2 = torch.nn.Dropout()
    # in this case, Q is transformed waveform (tgt) and V, K are the encoded text (memory)
    self.attn1 = nn.TransformerDecoderLayer(emb_dim, nhead, dense_net, batch_first=True)
    self.attn2 = nn.TransformerDecoderLayer(emb_dim, nhead, dense_net, batch_first=True)


  def forward(self, spectrogram, text):
    # prenet for spectrogram
    lin1 = self.linear1(spectrogram)
    relu1 = self.relu1(lin1)
    drop1 = self.dropout1(relu1)
    lin2 = self.linear2(drop1)
    relu2 = self.relu2(lin2)
    transformed_spec = self.dropout2(relu2)
    # add sinusoidal positional encoding
    with torch.no_grad():
      pos = sinosoidal_position_encoding(transformed_spec.size(1), transformed_spec.size(2), transformed_spec.size(0)).permute(0, 2, 1)
    pos_spec = transformed_spec + pos
    attn_1 = self.attn1(pos_spec, text)
    attn = self.attn2(attn_1, text)
    return attn

if validation:
  sampleSpec = rawData[0][0]
  model = PosteriorEncoder(80, 256).to(dev)
  input_waveform_tensor = mel(sampleSpec).permute(0, 2, 1).to(dev)
  input_tensor = torch.randint(0, 2000, (1, 131, 256), dtype=torch.float, device=dev)
  print(summary(model, input_data=[input_waveform_tensor, input_tensor]))



## Prior Encoder
Utilizes Glow Blocks, which contain an actnorm layer, an invertible 1x1 convolution layer, and an affine-coupling layer. This implementation is based off of this explanation: https://ameroyer.github.io/reading-notes/generative%20models/2019/05/07/glow_generative_flow_with_invertible_1x1_convolution.html

In [10]:
'''
ActNorm Layer
Similar to batch norm, but parameters aren't determined for each batch but rather
 a set bias and log standard deviation is initalized and trained.
'''
class ActNormLayer(torch.nn.Module):
  def __init__(self, data_dim):
    super(ActNormLayer, self).__init__()
    self.initalized = False
    self.log_weight = nn.Parameter(torch.ones(1, 1, data_dim, device=dev))
    self.bias = nn.Parameter(torch.zeros(1, 1, data_dim, device=dev))

  def forward(self, X):
    if not self.initalized:
      with torch.no_grad(): # don't want to calculate each time and propogated, want mean and std to be set once
        mean = X.mean(dim=(0, 1), keepdim=True)
        std = X.std(dim=(0, 1), keepdim=True)
        #print('in initialization')
        self.mean = mean.to(dev)
        self.log_std = torch.log(std + 0.000001).to(dev)
        #self.bias.copy_(1 / (std + 0.000001)) # find mean and variance along sequence length and batch size
        #self.log_weight.data.copy_(-1 * mean / (std + 0.000001))
        self.initalized = True
    #print('in forward')
    # expand to fill batch size
    log_weight = self.log_weight.expand(X.shape[0], -1, -1)
    bias = self.bias.expand(X.shape[0], -1, -1)
    mean = self.mean.expand(X.shape[0], -1, -1)
    log_std = self.log_std.expand(X.shape[0], -1, -1)

    #torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) # clip gradients in order to prevent exploding gradients in actnorm
    return (X - mean) * torch.exp(log_weight) * torch.exp(-1 * log_std) + bias

'''
Invertible 1x1 convolution.
'''
class Conv1x1(torch.nn.Module):
  def __init__(self, data_dim):
    super(Conv1x1, self).__init__()
    W = torch.randn(data_dim, data_dim)
    P, L, U = torch.linalg.lu(W) # LU decomp as mentioned in paper, enforce invertibility
    self.W = torch.unsqueeze(P @ L @ U, 2).to(dev) # unsqueeze 2nd dim to add a third dim

  def forward(self, X):
    return nn.functional.conv1d(X, self.W, bias=None).permute(0, 2, 1)

class AffineCouplingLayer(torch.nn.Module):
  def __init__(self, data_dim):
    super(AffineCouplingLayer, self).__init__()
    self.split_id = data_dim // 2
    self.text_reduce = nn.Linear(data_dim, data_dim//2)
    self.decoder = nn.TransformerDecoderLayer(data_dim//2, nhead, dense_net, batch_first=True)
    self.s_transform = nn.Linear(data_dim // 2, data_dim // 2)
    self.t_transform = nn.Linear(data_dim // 2, data_dim // 2)

  def forward(self, Z_i, X):
    z_1, z_2 = torch.split(Z_i, self.split_id, dim=2)
    X_reduced = self.text_reduce(X)
    decoded = self.decoder(z_1, X_reduced)
    s = self.s_transform(decoded)
    t = self.t_transform(decoded)
    y_1 = z_1
    y_2 = z_2 * torch.exp(s) + t
    return torch.cat((y_1, y_2), dim=2)

class glowBlock(torch.nn.Module):
  def __init__(self, data_dim):
    super(glowBlock, self).__init__()
    self.act_norm = ActNormLayer(data_dim)
    self.norm = nn.LayerNorm(data_dim)
    self.conv = Conv1x1(data_dim)
    self.coupling = AffineCouplingLayer(data_dim)

  def forward(self, X, Z_i):
    norm = self.act_norm(Z_i)
    #norm = self.norm(Z_i)
    norm = norm.permute(0, 2, 1)
    convoluted = self.conv(norm)
    coupled = self.coupling(convoluted, X)
    return coupled

class PriorEncoder(torch.nn.Module):
  def __init__(self, data_dim):
    super(PriorEncoder, self).__init__()
    self.glow1 = glowBlock(data_dim)
    self.glow2 = glowBlock(data_dim)
    self.glow3 = glowBlock(data_dim)
    self.glow4 = glowBlock(data_dim)
    self.glow5 = glowBlock(data_dim)
    self.glow6 = glowBlock(data_dim)

  def forward(self, X, Z_i):
    g1 = self.glow1(X, Z_i)
    g2 = self.glow2(X, g1)
    g3 = self.glow3(X, g2)
    g4 = self.glow4(X, g3)
    g5 = self.glow5(X, g4)
    g6 = self.glow6(X, g5)
    return g6

if validation:
  Z_i = torch.randn(1, 127, 256).to(dev)
  X = torch.randint(0, 50, (1, 131, 256), dtype=torch.float).to(dev)
  model = PriorEncoder(256).to(dev)
  print(summary(model, input_data=[X, Z_i]))

## Length Predictor
From the encoded text representation, the predicted length of the Mel Spectrogram will be predicted. The loss for this specific module will not be passed back to the text encoder.


In [11]:
class LengthPredictor(torch.nn.Module):
  def __init__(self, text_dim):
    super(LengthPredictor, self).__init__()
    self.linear1 = nn.Linear(text_dim, 256)
    self.relu = nn.ReLU()
    self.linear2 = nn.Linear(256, 1)
    self.softplus = nn.Softplus() # enforce an always positive value for character-utterance length

  def forward(self, X):
    lin1 = self.linear1(X)
    relu = self.relu(lin1)
    lin2 = self.linear2(relu)
    output = self.softplus(lin2)
    sum = torch.sum(output, dim=1)
    return sum

if validation:
  model = LengthPredictor(512)
  input_tensor = torch.zeros([1, 128, 512])
  print(summary(model, input_data=input_tensor))

## Decoder Block

In [12]:
class PostNetBlock(torch.nn.Module):
  def __init__(self, D, K, last_layer=False):
    super(PostNetBlock, self).__init__()
    self.conv = nn.Conv1d(D, D, K, bias=False, padding = (K - 1) // 2)
    self.bn = nn.BatchNorm1d(D)
    self.act = nn.Tanh()
    self.ll = last_layer

  def forward(self, pred_spec):
    c = self.conv(pred_spec)
    batch_norm = self.bn(c)
    activation = self.act(batch_norm)
    if self.ll:
      return batch_norm
    else:
      return activation

class Decoder(torch.nn.Module):
  def __init__(self, emb_dim, K):
    super(Decoder, self).__init__()
    self.deco1 = nn.TransformerDecoderLayer(emb_dim, nhead, dense_net, batch_first=True)
    self.deco2 = nn.TransformerDecoderLayer(emb_dim, nhead, dense_net, batch_first=True)
    self.conv1 = PostNetBlock(emb_dim, K)
    self.conv2 = PostNetBlock(emb_dim, K)
    self.conv3 = PostNetBlock(emb_dim, K)
    self.conv4 = PostNetBlock(emb_dim, K)
    self.conv5 = PostNetBlock(emb_dim, K, last_layer=True)
    self.dim_red = nn.Linear(emb_dim, 80)

  def forward(self, X, Z):
    # in this case, Q is transformed waveform (tgt) and V, K are the encoded text (memory)
    decoder1 = self.deco1(Z, X)
    Y_initial = self.deco2(decoder1, X)
    Y_initial = Y_initial.permute(0, 2, 1)
    c1 = self.conv1(Y_initial)
    c2 = self.conv2(c1)
    c3 = self.conv3(c2)
    c4 = self.conv4(c3)
    c5 = self.conv5(c4)
    complement = torch.add(Y_initial, c5)
    complement = complement.permute(0, 2, 1)
    Y_hat = self.dim_red(complement)
    return Y_hat

if validation:
  Z_i = torch.randn(1, 127, 256)
  X = torch.randint(0, 2000, (1, 131, 256), dtype=torch.float)
  model = Decoder(256, 5)
  print(summary(model, input_data=[X, Z_i]))

## VAENAR Training Model
Construct the training model as defined in the paper

In [13]:
class VARNAR_train(torch.nn.Module):
  def __init__(self, dict_size, data_dim, wave_dim, K):
    super(VARNAR_train, self).__init__()
    self.textEnc = TextEncoder(dict_size, data_dim, K)
    self.posteriorEnc = PosteriorEncoder(wave_dim, data_dim)
    self.lenPred = LengthPredictor(data_dim)
    self.priorEnc = PriorEncoder(data_dim)
    self.decoder = Decoder(data_dim, K)
    self.data_dim = data_dim

  def forward(self, X, Y, length):
    encodedText = self.textEnc(X)
    detachedEncodedText = encodedText.detach()
    pred_length = self.lenPred(detachedEncodedText)
    predictedGaussian = torch.randn(X.shape[0], length, self.data_dim).to(dev)
    posteriorDist = self.posteriorEnc(Y, encodedText)
    priorDist = self.priorEnc(encodedText, predictedGaussian)
    decoded_val = self.decoder(encodedText, posteriorDist)

    return decoded_val, priorDist, posteriorDist, pred_length

if validation:
  model = VARNAR_train(50, 256, 80, 5).to(dev)
  input_text = rawData[0][2].lower()
  sample_X = torch.tensor([char_to_index[item] for item in list(input_text)]).unsqueeze(0).to(dev)
  sample_Y = mel(sampleSpec).permute(0, 2, 1).to(dev)
  print(summary(model, input_data=[sample_X, sample_Y, torch.tensor(sample_Y.size(1)).to(dev)]))

## Training

In [None]:
train_model = VARNAR_train(50, 256, 80, 5)
train_model.to(dev)

if retraining:
  para_pth = '/content/drive/MyDrive/TTS_Parameters/Attempt 2/05-12 22:31.pth'
  params = torch.load(para_pth, weights_only=True)
  train_model.load_state_dict(params)
  kl_weight = 0.1 * ((len(trainDatasetLoader) // batch_size) * 0.001 * epochs)
else:
  kl_weight = 0.1

train_model.train()

optimizer = torch.optim.Adam(train_model.parameters(), lr=1e-4)

len_loss_tracker = []
Z_loss_tracker = []
Y_loss_tracker = []

for i in range(epochs):
  batch_num = 1
  lenLossTotal = 0
  zLossTotal = 0
  yLossTotal = 0
  for x_batch, y_batch in trainDatasetLoader:
    if batch_num == len(trainDatasetLoader) // 4:
      print('1/4 done')
    elif batch_num == len(trainDatasetLoader) // 2:
      print('1/2 done')
    elif batch_num == (len(trainDatasetLoader) // 4) * 3:
      print('3/4 done')
    x_batch = x_batch.to(dev)
    y_batch = y_batch.to(dev)
    length = torch.tensor(y_batch.shape[1], device=dev)
    #print(f'Length shape: {length.shape}')
    y_hat, Z_hat, Z, len_hat = train_model(x_batch, y_batch, length)
    #print(f'y_hat shape: {y_hat.shape}')
    #print(f'Z shape: {Z.shape}')
    #print(f'Z_hat shape: {Z_hat.shape}')
    #print(f'Len_hat shape: {len_hat.shape}')

    lenLossTotal += (len_loss := torch.nn.functional.mse_loss(torch.log(len_hat.squeeze(1)), torch.log(length.repeat(len_hat.shape[0]))))
    zLossTotal += (latent_loss := torch.nn.functional.kl_div(torch.log_softmax(Z_hat + 1e-7, dim=-1), torch.softmax(Z + 1e-7, dim=-1), reduction='batchmean', log_target=False))
    yLossTotal += (spec_loss := torch.nn.functional.mse_loss(y_hat, y_batch))
    total_loss = sum([len_loss * loss_weight[0], latent_loss * loss_weight[1] * kl_weight, spec_loss * loss_weight[2]])

    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    '''
    len_loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()

    otherLoss = sum([latent_loss, spec_loss])
    otherLoss.backward()
    optimizer.step()
    optimizer.zero_grad()
    '''
    kl_weight = min(1.0, kl_weight + 0.001)
    batch_num += 1

  if i % 3 == 0 and i != 0:
    now = datetime.datetime.now()
    torch.save(train_model.state_dict(), f'/content/drive/MyDrive/TTS_Parameters/{now.strftime("%m-%d %H:%M")}.pth')
  len_loss_tracker.append(lenLossTotal)
  Z_loss_tracker.append(zLossTotal)
  Y_loss_tracker.append(yLossTotal)


  print(f'Epoch {i+1} | len loss: {lenLossTotal / len(trainDatasetLoader)} | Z loss: {zLossTotal / len(trainDatasetLoader)} | Y loss: {yLossTotal / len(trainDatasetLoader)}')
  with open('/content/drive/MyDrive/TTS_Parameters/loss_vals.txt', 'a') as f:
    f.write(f'Epoch {i+1}\n')
    f.write(f'len loss: {lenLossTotal / len(trainDatasetLoader)}\n')
    f.write(f'Z loss: {zLossTotal / len(trainDatasetLoader)}\n')
    f.write(f'Y loss: {yLossTotal / len(trainDatasetLoader)}\n\n')




1/4 done
1/2 done
3/4 done
Epoch 1 | len loss: 3.204684019088745 | Z loss: 89047.515625 | Y loss: 1408.248046875
1/4 done
1/2 done
3/4 done
Epoch 2 | len loss: 0.7950267195701599 | Z loss: 51475.75390625 | Y loss: 1332.198486328125
1/4 done
1/2 done
3/4 done
Epoch 3 | len loss: 0.14021655917167664 | Z loss: 38403.953125 | Y loss: 1276.15087890625
1/4 done
1/2 done
3/4 done
Epoch 4 | len loss: 0.043970610946416855 | Z loss: 30126.427734375 | Y loss: 1227.4072265625
1/4 done
1/2 done
3/4 done
Epoch 5 | len loss: 0.025022249668836594 | Z loss: 21078.6953125 | Y loss: 1185.4874267578125
1/4 done
1/2 done
3/4 done
Epoch 6 | len loss: 0.025027254596352577 | Z loss: 7093.0 | Y loss: 1142.99755859375
1/4 done
1/2 done
3/4 done
Epoch 7 | len loss: 0.02961771748960018 | Z loss: 2938.192138671875 | Y loss: 1104.245361328125
1/4 done
1/2 done
3/4 done
Epoch 8 | len loss: 0.04104795306921005 | Z loss: 2209.349365234375 | Y loss: 1069.6773681640625
1/4 done
1/2 done
3/4 done
Epoch 9 | len loss: 0.02

## Testing

In [None]:
train_model.eval()

len_loss_tracker = []
Z_loss_tracker = []
Y_loss_tracker = []

len_accuracy = 0
z_accuracy = 0
y_accuracy = 0
for x_batch, y_batch in testDatasetLoader:
  if batch_num == len(testDatasetLoader) // 4:
    print('1/4 done')
  elif batch_num == len(testDatasetLoader) // 2:
    print('1/2 done')
  elif batch_num == (len(testDatasetLoader) // 4) * 3:
    print('3/4 done')
  x_batch = x_batch.to(dev)
  y_batch = y_batch.to(dev)
  length = torch.tensor(y_batch.shape[1], device=dev)
  #print(f'Length shape: {length.shape}')
  y_hat, Z_hat, Z, len_hat = train_model(x_batch, y_batch, length)
  #print(f'y_hat shape: {y_hat.shape}')
  #print(f'Z shape: {Z.shape}')
  #print(f'Z_hat shape: {Z_hat.shape}')
  #print(f'Len_hat shape: {len_hat.shape}')

  len_accuracy += (len_loss := torch.nn.functional.mse_loss(torch.log(len_hat.squeeze(1)), torch.log(length.repeat(len_hat.shape[0]))))
  z_accuracy += (latent_loss := torch.nn.functional.kl_div(torch.log_softmax(Z_hat + 1e-7, dim=-1), torch.softmax(Z + 1e-7, dim=-1), reduction='batchmean', log_target=False))
  yLossTotal += (spec_loss := torch.nn.functional.mse_loss(y_hat, y_batch))
  y_accuracy = sum([len_loss * loss_weight[0], latent_loss * loss_weight[1] * kl_weight, spec_loss * loss_weight[2]])

len_accuracy /= len(testDatasetLoader)
z_accuracy /= len(testDatasetLoader)
y_accuracy /= len(testDatasetLoader)


## VAENAR Inference Model
Same structure as the other model but different flow of data (using inferenced Z rather than the ground truth one)

In [14]:
class VAENAR_test(torch.nn.Module):
  def __init__(self, dict_size, data_dim, wave_dim, K):
    super(VAENAR_test, self).__init__()
    self.textEnc = TextEncoder(dict_size, data_dim, K)
    self.posteriorEnc = PosteriorEncoder(wave_dim, data_dim)
    self.lenPred = LengthPredictor(data_dim)
    self.priorEnc = PriorEncoder(data_dim)
    self.decoder = Decoder(data_dim, K)
    self.data_dim = data_dim

  def forward(self, X):
    encodedText = self.textEnc(X)
    detachedEncodedText = encodedText.detach()
    pred_length = self.lenPred(detachedEncodedText)
    pred_length = int(pred_length.item())
    predictedGaussian = torch.randn(X.shape[0], pred_length, self.data_dim).to(dev)
    #posteriorDist = self.posteriorEnc(Y, encodedText)
    priorDist = self.priorEnc(encodedText, predictedGaussian)
    decoded_val = self.decoder(encodedText, priorDist)

    return decoded_val

if validation:
  model = VAENAR_test(50, 256, 80, 5).to(dev)
  input_text = rawData[0][2].lower()
  sample_X = torch.tensor([char_to_index[item] for item in list(input_text)]).unsqueeze(0).to(dev)
  sample_Y = mel(sampleSpec).permute(0, 2, 1).to(dev)
  print(summary(model, input_data=[sample_X, sample_Y]))

## Inference
Run if this folder isn't created

In [15]:
!mkdir ModelOutputs

In [19]:
test_model = VAENAR_test(50, 512, 80, 5)
test_model.to(dev)

para_pth = '/content/drive/MyDrive/TTS_Parameters/Attempt 2/05-12 22:22.pth'
testParams = torch.load(para_pth, weights_only=True, map_location=torch.device('cpu'))
test_model.load_state_dict(testParams, strict=False)

test_model.eval()

X = input('Give an input for the model: ')
X_tensor = torch.tensor([char_to_index[item] for item in list(X.lower())]).unsqueeze(0).to(dev)
y_hat = test_model(X_tensor)
invertedSpec = inverseMel(y_hat.permute(0, 2, 1).cpu())
wavFile = melToWav(invertedSpec)

now = datetime.datetime.now()
save(f'/content/ModelOutputs/{now.strftime("%m-%d %H:%M")}.wav', wavFile, 22050)
print('Model has outputted')

Give an input for the model: Printing, in the only sense with which we are at present concerned
Model has outputted
