In [12]:
import math

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import params
from model import GradTTS
from data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from utils import plot_tensor, save_plot
from text.symbols import symbols
from model.utils import sequence_mask
from model.text_encoder import ConvReluNorm, Encoder

In [2]:
# --- Data / speaker conditioning ---------------------------------------------
train_filelist_path = params.train_filelist_path
valid_filelist_path = params.valid_filelist_path

cmudict_path = params.cmudict_path
add_blank = params.add_blank
n_spks = params.n_spks
spk_emb_dim = params.spk_emb_dim  # output width of the speaker / emotion projections

# --- Training ----------------------------------------------------------------
log_dir = params.log_dir
n_epochs = params.n_epochs
batch_size = params.batch_size
out_size = params.out_size
learning_rate = params.learning_rate
random_seed = params.seed

# params.seed used to be read but never applied. Seed torch's global RNG here,
# before any layer or DataLoader is created, so weight initialisation, dropout
# masks and the shuffled batch order are reproducible under Restart & Run All.
torch.manual_seed(random_seed)

# --- Text encoder ------------------------------------------------------------
# One extra symbol id is reserved when add_blank is set.
nsymbols = (len(symbols) + 1) if add_blank else len(symbols)
n_enc_channels = params.n_enc_channels
filter_channels = params.filter_channels
filter_channels_dp = params.filter_channels_dp
n_enc_layers = params.n_enc_layers
enc_kernel = params.enc_kernel
enc_dropout = params.enc_dropout
n_heads = params.n_heads
window_size = params.window_size

# --- Audio feature settings (passed to the datasets below) --------------------
n_feats = params.n_feats
n_fft = params.n_fft
sample_rate = params.sample_rate
hop_length = params.hop_length
win_length = params.win_length
f_min = params.f_min
f_max = params.f_max

# --- Decoder -----------------------------------------------------------------
dec_dim = params.dec_dim
beta_min = params.beta_min
beta_max = params.beta_max
pe_scale = params.pe_scale


In [3]:
# Both datasets take the same audio feature settings, in the same order;
# collect them once instead of repeating the argument list.
feature_args = (n_fft, n_feats, sample_rate, hop_length, win_length, f_min, f_max)

train_dataset = TextMelSpeakerDataset(train_filelist_path, cmudict_path, add_blank, *feature_args)
batch_collate = TextMelSpeakerBatchCollate()
loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,           # every batch has exactly batch_size items
    collate_fn=batch_collate,
    num_workers=8,
)
test_dataset = TextMelSpeakerDataset(valid_filelist_path, cmudict_path, add_blank, *feature_args)

In [25]:
# Pull one shuffled training batch for the shape walkthrough below. The
# iterator is released immediately afterwards (even if next() raises), so it
# does not linger in the kernel namespace together with its worker processes.
loader_iter = iter(loader)
try:
    batch = next(loader_iter)
finally:
    del loader_iter

  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)


In [26]:
# The unpacking cell below indexes these keys directly; fail here with a
# readable message if the collate function stops producing any of them.
required_keys = {'x', 'x_lengths', 'y', 'y_lengths', 'speaker', 'spk', 'emo'}
missing_keys = required_keys.difference(batch)
assert not missing_keys, f"batch is missing keys: {sorted(missing_keys)}"
batch.keys()

dict_keys(['x', 'x_lengths', 'y', 'y_lengths', 'speaker', 'spk', 'emo'])

In [27]:
# Unpack the collated batch into the names used by the walkthrough cells below
# (keys are looked up in the same order as listed).
x, x_lengths, y, y_lengths, speaker, spk, emo = (
    batch[key]
    for key in ('x', 'x_lengths', 'y', 'y_lengths', 'speaker', 'spk', 'emo')
)

In [28]:
# Emotion vectors as delivered by the collate function, before projection.
emo.size()

torch.Size([16, 768])

In [29]:
# Linear projections of the speaker and emotion vectors down to spk_emb_dim.
# Input widths come from the batch itself (previously hard-coded as 256 / 768),
# so the cell keeps working if the upstream embedding extractors change size.
# Reading batch[...] rather than spk / emo keeps this correct even after the
# projection cell below has overwritten those names.
spk_emb = torch.nn.Linear(batch['spk'].shape[-1], spk_emb_dim)
emo_emb = torch.nn.Linear(batch['emo'].shape[-1], spk_emb_dim)

In [30]:
# Project from the raw batch tensors, not from spk / emo themselves: the old
# `spk = spk_emb(spk)` fed already-projected vectors back into the Linear layer
# on a re-run and failed on the input width. This form is safe to re-run.
spk = spk_emb(batch['spk'])
emo = emo_emb(batch['emo'])

In [31]:
# Both conditioning vectors should now be (batch, spk_emb_dim).
(spk.size(), emo.size())

(torch.Size([16, 32]), torch.Size([16, 32]))

In [32]:
# Token embedding table. The next cell multiplies the embeddings by
# sqrt(n_enc_channels), which presupposes weights drawn with std
# n_enc_channels ** -0.5 (the Transformer-style convention; presumably what
# model.text_encoder's own text encoder does — confirm). With PyTorch's default
# N(0, 1) init the scaled activations would instead have std ~sqrt(n_enc_channels).
emb = torch.nn.Embedding(nsymbols, n_enc_channels)
torch.nn.init.normal_(emb.weight, 0.0, n_enc_channels ** -0.5);

In [33]:
# Embed the token ids, scale by sqrt(n_enc_channels), and move channels to
# dim 1, giving (batch, channels, time) for the convolutional prenet.
# Reading batch['x'] instead of `x` makes the cell safe to re-run: previously a
# second run fed float activations back into nn.Embedding and raised.
x = emb(batch['x']) * math.sqrt(n_enc_channels)
x = torch.transpose(x, 1, -1)
# Padding mask of shape (batch, 1, time), in x's dtype so it can multiply x.
# sequence_mask presumably marks positions < x_lengths as valid — verify.
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

In [34]:
# Expected: (batch, 1, longest sequence in the batch).
x_mask.size()

torch.Size([16, 1, 101])

In [35]:
from model.text_encoder import ConvReluNorm, Encoder

In [36]:
# Convolutional prenet over the embedded tokens; all three channel arguments
# are n_enc_channels, so x keeps its shape.
# NOTE: x is overwritten, so re-running this cell on its own stacks a freshly
# initialised prenet on top of the previous output — re-run the embedding
# cell first.
prenet_kwargs = dict(kernel_size=5, n_layers=3, p_dropout=0.5)
prenet = ConvReluNorm(n_enc_channels, n_enc_channels, n_enc_channels, **prenet_kwargs)

x = prenet(x, x_mask)

In [37]:
# Tile each per-utterance speaker / emotion vector along x's last axis and
# append both as extra channels: n_enc_channels + 2 * spk_emb_dim in total.
# NOTE: re-running this cell appends the conditioning channels a second time.
n_frames = x.shape[-1]
spk_frames = spk.unsqueeze(-1).repeat(1, 1, n_frames)
emo_frames = emo.unsqueeze(-1).repeat(1, 1, n_frames)
x = torch.cat([x, spk_frames, emo_frames], dim=1)

In [38]:
# Channel dim should now be n_enc_channels + 2 * spk_emb_dim.
x.size()

torch.Size([16, 256, 101])

In [39]:
# The encoder's input width must match the channels produced by the
# conditioning concat: token channels plus one speaker and one emotion block.
encoder_in_channels = n_enc_channels + 2 * spk_emb_dim
encoder = Encoder(encoder_in_channels, filter_channels, n_heads, n_enc_layers,
                  enc_kernel, enc_dropout, window_size=window_size)

In [40]:
# Run the encoder over the conditioned sequence, passing the padding mask built
# earlier. The displayed output shape matches the input (next cell).
# NOTE: x is overwritten, so this cell is not idempotent — re-run the cells
# above it before running it again.
x = encoder(x, x_mask)

In [41]:
# Encoder output: same (batch, channels, time) layout as its input.
x.size()

torch.Size([16, 256, 101])