From 2e5ccc99dee805bb53e1fd960ae1efaa1c12ca7f Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Tue, 25 Aug 2020 05:52:34 +0000 Subject: [PATCH] Format code with black --- audio_processing.py | 15 +- data_utils.py | 70 +++++-- distributed.py | 141 +++++++------ gradient_reversal.py | 18 +- hparams.py | 50 ++--- layers.py | 59 ++++-- logger.py | 32 +-- loss_function.py | 76 ++++--- loss_scaler.py | 37 ++-- model.py | 454 ++++++++++++++++++++++++++++-------------- multiproc.py | 9 +- plotting_utils.py | 37 ++-- residual_encoder.py | 210 +++++++++++-------- speaker_classifier.py | 34 ++-- stft.py | 61 +++--- text/__init__.py | 66 +++--- text/cleaners.py | 99 ++++----- text/cmudict.py | 159 +++++++++++---- text/numbers.py | 92 ++++----- text/symbols.py | 30 +-- train.py | 337 +++++++++++++++++++------------ utils.py | 16 +- 22 files changed, 1297 insertions(+), 805 deletions(-) diff --git a/audio_processing.py b/audio_processing.py index b5af7f7..3a44673 100644 --- a/audio_processing.py +++ b/audio_processing.py @@ -4,8 +4,15 @@ import librosa.util as librosa_util -def window_sumsquare(window, n_frames, hop_length=200, win_length=800, - n_fft=800, dtype=np.float32, norm=None): +def window_sumsquare( + window, + n_frames, + hop_length=200, + win_length=800, + n_fft=800, + dtype=np.float32, + norm=None, +): """ # from librosa 0.6 Compute the sum-square envelope of a window function at a given hop length. @@ -46,13 +53,13 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, # Compute the squared window at the desired length win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 win_sq = librosa_util.pad_center(win_sq, n_fft) # Fill the envelope for i in range(n_frames): sample = i * hop_length - x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] return x diff --git a/data_utils.py b/data_utils.py index f8d6be3..f8b028c 100644 --- a/data_utils.py +++ b/data_utils.py @@ -14,6 +14,7 @@ class TextMelLoader(torch.utils.data.Dataset): 2) normalizes text and converts them to sequences of one-hot vectors 3) computes mel-spectrograms from audio files. 
""" + def __init__(self, audiopaths_and_text, hparams): self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) self.text_cleaners = hparams.text_cleaners @@ -23,26 +24,39 @@ def __init__(self, audiopaths_and_text, hparams): self.audio_dtype = hparams.audio_dtype self.use_librosa = hparams.use_librosa self.stft = layers.TacotronSTFT( - hparams.filter_length, hparams.hop_length, hparams.win_length, - hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, - hparams.mel_fmax) + hparams.filter_length, + hparams.hop_length, + hparams.win_length, + hparams.n_mel_channels, + hparams.sampling_rate, + hparams.mel_fmin, + hparams.mel_fmax, + ) random.seed(hparams.seed) random.shuffle(self.audiopaths_and_text) def get_mel_text_pair(self, audiopath_and_text): # separate filename and text audiopath, text = audiopath_and_text[0], audiopath_and_text[1] - speaker, lang = int(float(audiopath_and_text[2])), int(float(audiopath_and_text[3])) + speaker, lang = ( + int(float(audiopath_and_text[2])), + int(float(audiopath_and_text[3])), + ) text = self.get_text(text) mel = self.get_mel(audiopath) return (text, mel, speaker, lang) def get_mel(self, filename): if not self.load_mel_from_disk: - audio, sampling_rate = load_wav_to_torch(filename, self.use_librosa, self.audio_dtype, self.sampling_rate) + audio, sampling_rate = load_wav_to_torch( + filename, self.use_librosa, self.audio_dtype, self.sampling_rate + ) if sampling_rate != self.stft.sampling_rate: - raise ValueError("{} SR doesn't match target {} SR".format( - sampling_rate, self.stft.sampling_rate)) + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.stft.sampling_rate + ) + ) audio_norm = audio / self.max_wav_value audio_norm = audio_norm.unsqueeze(0) audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) @@ -50,14 +64,16 @@ def get_mel(self, filename): melspec = torch.squeeze(melspec, 0) else: melspec = torch.from_numpy(np.load(filename)) - assert melspec.size(0) == self.stft.n_mel_channels, ( - 'Mel dimension mismatch: given {}, expected {}'.format( - melspec.size(0), self.stft.n_mel_channels)) + assert ( + melspec.size(0) == self.stft.n_mel_channels + ), "Mel dimension mismatch: given {}, expected {}".format( + melspec.size(0), self.stft.n_mel_channels + ) return melspec def get_text(self, text): - text = '*'+text+'`' + text = "*" + text + "`" text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) return text_norm @@ -68,9 +84,10 @@ def __len__(self): return len(self.audiopaths_and_text) -class TextMelCollate(): +class TextMelCollate: """ Zero-pads model inputs and targets based on number of frames per setep """ + def __init__(self, n_frames_per_step): self.n_frames_per_step = n_frames_per_step @@ -80,27 +97,29 @@ def __call__(self, batch): ------ batch: [text_normalized, mel_normalized] """ - + speakers = torch.tensor([batch[i][2] for i in range(len(batch))]) langs = torch.tensor([batch[i][3] for i in range(len(batch))]) - + # Right zero-pad all one-hot text sequences to max input length input_lengths, ids_sorted_decreasing = torch.sort( - torch.LongTensor([len(x[0]) for x in batch]), - dim=0, descending=True) + torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True + ) max_input_len = input_lengths[0] text_padded = torch.LongTensor(len(batch), max_input_len) text_padded.zero_() for i in range(len(ids_sorted_decreasing)): text = batch[ids_sorted_decreasing[i]][0] - text_padded[i, :text.size(0)] = text + text_padded[i, : text.size(0)] = text # 
Right zero-pad mel-spec num_mels = batch[0][1].size(0) max_target_len = max([x[1].size(1) for x in batch]) if max_target_len % self.n_frames_per_step != 0: - max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step + max_target_len += ( + self.n_frames_per_step - max_target_len % self.n_frames_per_step + ) assert max_target_len % self.n_frames_per_step == 0 # include mel padded and gate padded @@ -111,9 +130,16 @@ def __call__(self, batch): output_lengths = torch.LongTensor(len(batch)) for i in range(len(ids_sorted_decreasing)): mel = batch[ids_sorted_decreasing[i]][1] - mel_padded[i, :, :mel.size(1)] = mel - gate_padded[i, mel.size(1)-1:] = 1 + mel_padded[i, :, : mel.size(1)] = mel + gate_padded[i, mel.size(1) - 1 :] = 1 output_lengths[i] = mel.size(1) - return text_padded, input_lengths, mel_padded, gate_padded, \ - output_lengths, speakers, langs + return ( + text_padded, + input_lengths, + mel_padded, + gate_padded, + output_lengths, + speakers, + langs, + ) diff --git a/distributed.py b/distributed.py index cce7494..fb3c8c9 100644 --- a/distributed.py +++ b/distributed.py @@ -3,6 +3,7 @@ from torch.nn.modules import Module from torch.autograd import Variable + def _flatten_dense_tensors(tensors): """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of same dense type. @@ -19,6 +20,7 @@ def _flatten_dense_tensors(tensors): flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) return flat + def _unflatten_dense_tensors(flat, tensors): """View a flat buffer using the sizes of tensors. Assume that tensors are of same dense type, and that flat is given by _flatten_dense_tensors. @@ -39,7 +41,7 @@ def _unflatten_dense_tensors(flat, tensors): return tuple(outputs) -''' +""" This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py launcher included with this example. It assumes that your run is using multiprocess with 1 GPU/process, that the model is on the correct device, and that torch.set_device has been @@ -47,16 +49,19 @@ def _unflatten_dense_tensors(flat, tensors): Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, and will be allreduced at the finish of the backward pass. -''' -class DistributedDataParallel(Module): +""" + +class DistributedDataParallel(Module): def __init__(self, module): super(DistributedDataParallel, self).__init__() - #fallback for PyTorch 0.3 - if not hasattr(dist, '_backend'): + # fallback for PyTorch 0.3 + if not hasattr(dist, "_backend"): self.warn_on_half = True else: - self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + self.warn_on_half = ( + True if dist._backend == dist.dist_backend.GLOO else False + ) self.module = module @@ -66,7 +71,7 @@ def __init__(self, module): dist.broadcast(p, 0) def allreduce_params(): - if(self.needs_reduction): + if self.needs_reduction: self.needs_reduction = False buckets = {} for param in self.module.parameters(): @@ -77,9 +82,11 @@ def allreduce_params(): buckets[tp].append(param) if self.warn_on_half: if torch.cuda.HalfTensor in buckets: - print("WARNING: gloo dist backend for half parameters may be extremely slow." + - " It is recommended to use the NCCL backend in this case. This currently requires" + - "PyTorch built from top of tree master.") + print( + "WARNING: gloo dist backend for half parameters may be extremely slow." + + " It is recommended to use the NCCL backend in this case. 
This currently requires" + + "PyTorch built from top of tree master." + ) self.warn_on_half = False for tp in buckets: @@ -88,12 +95,16 @@ def allreduce_params(): coalesced = _flatten_dense_tensors(grads) dist.all_reduce(coalesced) coalesced /= dist.get_world_size() - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + for buf, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads) + ): buf.copy_(synced) for param in list(self.module.parameters()): + def allreduce_hook(*unused): param._execution_engine.queue_callback(allreduce_params) + if param.requires_grad: param.register_hook(allreduce_hook) @@ -101,7 +112,7 @@ def forward(self, *inputs, **kwargs): self.needs_reduction = True return self.module(*inputs, **kwargs) - ''' + """ def _sync_buffers(self): buffers = list(self.module._all_buffers()) if len(buffers) > 0: @@ -118,56 +129,66 @@ def train(self, mode=True): dist._clear_group_cache() super(DistributedDataParallel, self).train(mode) self.module.train(mode) - ''' -''' -Modifies existing model to do gradient allreduce, but doesn't change class -so you don't need "module" -''' -def apply_gradient_allreduce(module): - if not hasattr(dist, '_backend'): - module.warn_on_half = True - else: - module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False - - for p in module.state_dict().values(): - if not torch.is_tensor(p): - continue - dist.broadcast(p, 0) + """ - def allreduce_params(): - if(module.needs_reduction): - module.needs_reduction = False - buckets = {} - for param in module.parameters(): - if param.requires_grad and param.grad is not None: - tp = param.data.dtype - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(param) - if module.warn_on_half: - if torch.cuda.HalfTensor in buckets: - print("WARNING: gloo dist backend for half parameters may be extremely slow." + - " It is recommended to use the NCCL backend in this case. This currently requires" + - "PyTorch built from top of tree master.") - module.warn_on_half = False - for tp in buckets: - bucket = buckets[tp] - grads = [param.grad.data for param in bucket] - coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced) - coalesced /= dist.get_world_size() - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) +""" +Modifies existing model to do gradient allreduce, but doesn't change class +so you don't need "module" +""" - for param in list(module.parameters()): - def allreduce_hook(*unused): - Variable._execution_engine.queue_callback(allreduce_params) - if param.requires_grad: - param.register_hook(allreduce_hook) - def set_needs_reduction(self, input, output): - self.needs_reduction = True +def apply_gradient_allreduce(module): + if not hasattr(dist, "_backend"): + module.warn_on_half = True + else: + module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + for p in module.state_dict().values(): + if not torch.is_tensor(p): + continue + dist.broadcast(p, 0) + + def allreduce_params(): + if module.needs_reduction: + module.needs_reduction = False + buckets = {} + for param in module.parameters(): + if param.requires_grad and param.grad is not None: + tp = param.data.dtype + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if module.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print( + "WARNING: gloo dist backend for half parameters may be extremely slow." + + " It is recommended to use the NCCL backend in this case. 
This currently requires" + + "PyTorch built from top of tree master." + ) + module.warn_on_half = False + + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced) + coalesced /= dist.get_world_size() + for buf, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads) + ): + buf.copy_(synced) + + for param in list(module.parameters()): + + def allreduce_hook(*unused): + Variable._execution_engine.queue_callback(allreduce_params) + + if param.requires_grad: + param.register_hook(allreduce_hook) + + def set_needs_reduction(self, input, output): + self.needs_reduction = True - module.register_forward_hook(set_needs_reduction) - return module + module.register_forward_hook(set_needs_reduction) + return module diff --git a/gradient_reversal.py b/gradient_reversal.py index 6d258ff..963a5d6 100644 --- a/gradient_reversal.py +++ b/gradient_reversal.py @@ -1,16 +1,18 @@ import torch -class reverse_grad(torch.autograd.Function) : - + +class reverse_grad(torch.autograd.Function): @staticmethod - def forward(ctx, x) : + def forward(ctx, x): return x.view_as(x) - + @staticmethod - def backward(ctx, grad_output) : - return - (grad_output.clamp(-0.5,0.5)) + def backward(ctx, grad_output): + return -(grad_output.clamp(-0.5, 0.5)) + rg = reverse_grad.apply -def grad_reverse(x) : - return rg(x) \ No newline at end of file + +def grad_reverse(x): + return rg(x) diff --git a/hparams.py b/hparams.py index 085e4a0..572a136 100644 --- a/hparams.py +++ b/hparams.py @@ -19,18 +19,20 @@ def create_hparams(hparams_string=None, verbose=False): dist_url="tcp://localhost:54091", cudnn_enabled=True, cudnn_benchmark=False, - ignore_layers=['embedding.weight', 'decoder.prenet', 'decoder.linear_projection'], - + ignore_layers=[ + "embedding.weight", + "decoder.prenet", + "decoder.linear_projection", + ], ################################ # Data Parameters # ################################ load_mel_from_disk=False, - audio_dtype = 'np.float32', #Data type of input audio files. If not 'np.int16' ; will be converted to it. - use_librosa = False, #If you want to use librosa for loading file and automatically resampling to sampling_rate - training_files = '../txts/shuffled_train_file.txt', - validation_files='../txts/final_validation.txt', - text_cleaners=['transliteration_cleaners'], - + audio_dtype="np.float32", # Data type of input audio files. If not 'np.int16' ; will be converted to it. 
+ use_librosa=False, # If you want to use librosa for loading file and automatically resampling to sampling_rate + training_files="../txts/shuffled_train_file.txt", + validation_files="../txts/final_validation.txt", + text_cleaners=["transliteration_cleaners"], ################################ # Audio Parameters # ################################ @@ -42,18 +44,15 @@ def create_hparams(hparams_string=None, verbose=False): n_mel_channels=80, mel_fmin=0.0, mel_fmax=8000.0, - ################################ # Model Parameters # ################################ n_symbols=len(symbols), symbols_embedding_dim=512, - # Encoder parameters encoder_kernel_size=5, encoder_n_convolutions=3, encoder_embedding_dim=512, - # Decoder parameters n_frames_per_step=5, # More than 1 is supported now decoder_rnn_dim=1024, @@ -62,58 +61,51 @@ def create_hparams(hparams_string=None, verbose=False): gate_threshold=0.5, p_attention_dropout=0.1, p_decoder_dropout=0.1, - # Attention parameters attention_rnn_dim=1024, attention_dim=128, - # Location Layer parameters attention_location_n_filters=32, attention_location_kernel_size=31, - # Mel-post processing network parameters postnet_embedding_dim=512, postnet_kernel_size=5, postnet_n_convolutions=5, - ################################ # Optimization Hyperparameters # ################################ use_saved_learning_rate=False, learning_rate=1e-4, - anneal = 100, #number of iterations to anneal lr from 0 to 'learning_rate' + anneal=100, # number of iterations to anneal lr from 0 to 'learning_rate' weight_decay=1e-7, grad_clip_thresh=1.0, batch_size=10, mask_padding=True, # set model's padded outputs to padded values - ############################### # Speaker and Lang Embeddings # ############################### - speaker_embedding_dim = 64, - lang_embedding_dim = 3, - n_langs = 2, - n_speakers = 917, - + speaker_embedding_dim=64, + lang_embedding_dim=3, + n_langs=2, + n_speakers=917, ############################### ## Speaker Classifier Params ## ############################### hidden_sc_dim=256, - ############################## ## Residual Encoder Params ## ############################## - residual_encoding_dim = 32, #16 for q(z_l|X) and 16 for q(z_o|X) - dim_yo = 917, #(==n_speakers) dim(y_{o}) - dim_yl = 10, #K - mcn = 2 #n for monte carlo sampling of q(z_l|X)and q(z_o|X) + residual_encoding_dim=32, # 16 for q(z_l|X) and 16 for q(z_o|X) + dim_yo=917, # (==n_speakers) dim(y_{o}) + dim_yl=10, # K + mcn=2, # n for monte carlo sampling of q(z_l|X)and q(z_o|X) ) if hparams_string: - tf.logging.info('Parsing command line hparams: %s', hparams_string) + tf.logging.info("Parsing command line hparams: %s", hparams_string) hparams.parse(hparams_string) if verbose: - tf.logging.info('Final parsed hparams: %s', hparams.values()) + tf.logging.info("Final parsed hparams: %s", hparams.values()) return hparams diff --git a/layers.py b/layers.py index 615a64a..4eb7003 100644 --- a/layers.py +++ b/layers.py @@ -6,33 +6,48 @@ class LinearNorm(torch.nn.Module): - def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): super(LinearNorm, self).__init__() self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(w_init_gain)) + self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) def forward(self, x): return self.linear_layer(x) class ConvNorm(torch.nn.Module): - def 
__init__(self, in_channels, out_channels, kernel_size=1, stride=1, - padding=None, dilation=1, bias=True, w_init_gain='linear'): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + ): super(ConvNorm, self).__init__() if padding is None: - assert(kernel_size % 2 == 1) + assert kernel_size % 2 == 1 padding = int(dilation * (kernel_size - 1) / 2) - self.conv = torch.nn.Conv1d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, - bias=bias) + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) torch.nn.init.xavier_uniform_( - self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) def forward(self, signal): conv_signal = self.conv(signal) @@ -40,17 +55,25 @@ def forward(self, signal): class TacotronSTFT(torch.nn.Module): - def __init__(self, filter_length=1024, hop_length=256, win_length=1024, - n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, - mel_fmax=8000.0): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + sampling_rate=22050, + mel_fmin=0.0, + mel_fmax=8000.0, + ): super(TacotronSTFT, self).__init__() self.n_mel_channels = n_mel_channels self.sampling_rate = sampling_rate self.stft_fn = STFT(filter_length, hop_length, win_length) mel_basis = librosa_mel_fn( - sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer('mel_basis', mel_basis) + self.register_buffer("mel_basis", mel_basis) def spectral_normalize(self, magnitudes): output = dynamic_range_compression(magnitudes) @@ -70,8 +93,8 @@ def mel_spectrogram(self, y): ------- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) """ - assert(torch.min(y.data) >= -1) - assert(torch.max(y.data) <= 1) + assert torch.min(y.data) >= -1 + assert torch.max(y.data) <= 1 magnitudes, phases = self.stft_fn.transform(y) magnitudes = magnitudes.data diff --git a/logger.py b/logger.py index ed1977e..1939089 100644 --- a/logger.py +++ b/logger.py @@ -9,12 +9,11 @@ class Tacotron2Logger(SummaryWriter): def __init__(self, logdir): super(Tacotron2Logger, self).__init__(logdir) - def log_training(self, reduced_loss, grad_norm, learning_rate, duration, - iteration): - self.add_scalar("training.loss", reduced_loss, iteration) - self.add_scalar("grad.norm", grad_norm, iteration) - self.add_scalar("learning.rate", learning_rate, iteration) - self.add_scalar("duration", duration, iteration) + def log_training(self, reduced_loss, grad_norm, learning_rate, duration, iteration): + self.add_scalar("training.loss", reduced_loss, iteration) + self.add_scalar("grad.norm", grad_norm, iteration) + self.add_scalar("learning.rate", learning_rate, iteration) + self.add_scalar("duration", duration, iteration) def log_validation(self, reduced_loss, model, y, y_pred, iteration): self.add_scalar("validation.loss", reduced_loss, iteration) @@ -23,7 +22,7 @@ def log_validation(self, reduced_loss, model, y, y_pred, iteration): # plot distribution of parameters for tag, value in model.named_parameters(): - tag = tag.replace('.', '/') + tag = tag.replace(".", "/") self.add_histogram(tag, value.data.cpu().numpy(), iteration) 
# plot alignment, mel target and predicted, gate target and predicted @@ -31,18 +30,27 @@ def log_validation(self, reduced_loss, model, y, y_pred, iteration): self.add_image( "alignment", plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), - iteration, dataformats='HWC') + iteration, + dataformats="HWC", + ) self.add_image( "mel_target", plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), - iteration, dataformats='HWC') + iteration, + dataformats="HWC", + ) self.add_image( "mel_predicted", plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), - iteration, dataformats='HWC') + iteration, + dataformats="HWC", + ) self.add_image( "gate", plot_gate_outputs_to_numpy( gate_targets[idx].data.cpu().numpy(), - torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), - iteration, dataformats='HWC') + torch.sigmoid(gate_outputs[idx]).data.cpu().numpy(), + ), + iteration, + dataformats="HWC", + ) diff --git a/loss_function.py b/loss_function.py index ae6b016..a9dba3b 100644 --- a/loss_function.py +++ b/loss_function.py @@ -1,12 +1,11 @@ import torch -from torch import nn -from torch.distributions.kl import kl_divergence as kld +from torch import nn +from torch.distributions.kl import kl_divergence as kld from torch.distributions.categorical import Categorical from torch.distributions.normal import Normal class Tacotron2Loss(nn.Module): - def __init__(self, hparams): super(Tacotron2Loss, self).__init__() self.ce_loss = nn.CrossEntropyLoss() @@ -17,37 +16,58 @@ def forward(self, model_output, targets, re, batched_speakers): gate_target.requires_grad = False gate_target = gate_target.view(-1, 1) - mel_out, mel_out_postnet, gate_out, alignments, spkr_clsfir_logits = model_output + ( + mel_out, + mel_out_postnet, + gate_out, + alignments, + spkr_clsfir_logits, + ) = model_output gate_out = gate_out.view(-1, 1) - mel_loss = nn.MSELoss()(mel_out, mel_target) + \ - nn.MSELoss()(mel_out_postnet, mel_target) + mel_loss = nn.MSELoss()(mel_out, mel_target) + nn.MSELoss()( + mel_out_postnet, mel_target + ) gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) - + means, stddevs = re.q_zo_given_X_at_x.mean, re.q_zo_given_X_at_x.stddev - kl_loss = kld(Normal(means[0], stddevs[0]), re.p_zo_given_yo.distrib_lis[batched_speakers[0]]).mean() - for i, speaker in enumerate(batched_speakers[1:], 1) : - kl_loss += kld(Normal(means[i], stddevs[i]), re.p_zo_given_yo.distrib_lis[speaker]).mean() -# print("speaker kl loss:- ", kl_loss/batched_speakers.shape[0]) - for i in range(re.p_zl_given_yl.n_disc) : - ans = ( re.q_yl_given_X[i]*kld(re.q_zl_given_X_at_x, re.p_zl_given_yl.distrib_lis[i]).mean(dim=1) ).sum() - #print("Difference in i-th dimension of residual :- ", ans, re.q_yl_given_X[i]) + kl_loss = kld( + Normal(means[0], stddevs[0]), + re.p_zo_given_yo.distrib_lis[batched_speakers[0]], + ).mean() + for i, speaker in enumerate(batched_speakers[1:], 1): + kl_loss += kld( + Normal(means[i], stddevs[i]), re.p_zo_given_yo.distrib_lis[speaker] + ).mean() + # print("speaker kl loss:- ", kl_loss/batched_speakers.shape[0]) + for i in range(re.p_zl_given_yl.n_disc): + ans = ( + re.q_yl_given_X[i] + * kld(re.q_zl_given_X_at_x, re.p_zl_given_yl.distrib_lis[i]).mean(dim=1) + ).sum() + # print("Difference in i-th dimension of residual :- ", ans, re.q_yl_given_X[i]) kl_loss += ans -# print("speaker+resdiual kl loss:- ", kl_loss/batched_speakers.shape[0]) - for i in range(re.q_yl_given_X.shape[1]) : - kl_loss += kld( Categorical(re.q_yl_given_X[:,i]), re.y_l) - kl_loss = kl_loss/batched_speakers.shape[0] 
- - index_into_spkr_logits = batched_speakers.repeat_interleave(spkr_clsfir_logits.shape[1]) - spkr_clsfir_logits = spkr_clsfir_logits.reshape(-1, spkr_clsfir_logits.shape[-1]) - mask_index = spkr_clsfir_logits.abs().sum(dim=1)!=0 + # print("speaker+resdiual kl loss:- ", kl_loss/batched_speakers.shape[0]) + for i in range(re.q_yl_given_X.shape[1]): + kl_loss += kld(Categorical(re.q_yl_given_X[:, i]), re.y_l) + kl_loss = kl_loss / batched_speakers.shape[0] + + index_into_spkr_logits = batched_speakers.repeat_interleave( + spkr_clsfir_logits.shape[1] + ) + spkr_clsfir_logits = spkr_clsfir_logits.reshape( + -1, spkr_clsfir_logits.shape[-1] + ) + mask_index = spkr_clsfir_logits.abs().sum(dim=1) != 0 spkr_clsfir_logits = spkr_clsfir_logits[mask_index] index_into_spkr_logits = index_into_spkr_logits[mask_index] - speaker_loss = self.ce_loss(spkr_clsfir_logits, index_into_spkr_logits) #/batched_speakers.shape[0] - + speaker_loss = self.ce_loss( + spkr_clsfir_logits, index_into_spkr_logits + ) # /batched_speakers.shape[0] + print("Mel Loss:- ", mel_loss) -# print("gate_loss :- ", gate_loss) -# print("speaker_loss :- ", speaker_loss) + # print("gate_loss :- ", gate_loss) + # print("speaker_loss :- ", speaker_loss) print("kl loss:- ", kl_loss) -# print("Total Loss:- ", (mel_loss + gate_loss) + 0.02*speaker_loss +kl_loss) + # print("Total Loss:- ", (mel_loss + gate_loss) + 0.02*speaker_loss +kl_loss) - return (50*mel_loss + gate_loss) + 0.02*speaker_loss +kl_loss + return (50 * mel_loss + gate_loss) + 0.02 * speaker_loss + kl_loss diff --git a/loss_scaler.py b/loss_scaler.py index 0662a60..b414401 100644 --- a/loss_scaler.py +++ b/loss_scaler.py @@ -1,7 +1,7 @@ import torch -class LossScaler: +class LossScaler: def __init__(self, scale=1): self.cur_scale = scale @@ -25,15 +25,12 @@ def scale_gradient(self, module, grad_in, grad_out): return tuple(self.loss_scale * g for g in grad_in) def backward(self, loss): - scaled_loss = loss*self.loss_scale + scaled_loss = loss * self.loss_scale scaled_loss.backward() -class DynamicLossScaler: - def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000): +class DynamicLossScaler: + def __init__(self, init_scale=2 ** 32, scale_factor=2.0, scale_window=1000): self.cur_scale = init_scale self.cur_iter = 0 self.last_overflow_iter = -1 @@ -42,7 +39,7 @@ def __init__(self, # `params` is a list / generator of torch.Variable def has_overflow(self, params): -# return False + # return False for p in params: if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): return True @@ -52,20 +49,20 @@ def has_overflow(self, params): # `x` is a torch.Tensor def _has_inf_or_nan(x): cpu_sum = float(x.float().sum()) - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + if cpu_sum == float("inf") or cpu_sum == -float("inf") or cpu_sum != cpu_sum: return True return False # `overflow` is boolean indicating whether we overflowed in gradient def update_scale(self, overflow): if overflow: - #self.cur_scale /= self.scale_factor - self.cur_scale = max(self.cur_scale/self.scale_factor, 1) + # self.cur_scale /= self.scale_factor + self.cur_scale = max(self.cur_scale / self.scale_factor, 1) self.last_overflow_iter = self.cur_iter else: if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: self.cur_scale *= self.scale_factor -# self.cur_scale = 1 + # self.cur_scale = 1 self.cur_iter += 1 @property @@ -76,9 +73,10 @@ def scale_gradient(self, module, grad_in, grad_out): return tuple(self.loss_scale * g for g in 
grad_in) def backward(self, loss): - scaled_loss = loss*self.loss_scale + scaled_loss = loss * self.loss_scale scaled_loss.backward() + ############################################################## # Example usage below here -- assuming it's in a separate file ############################################################## @@ -106,9 +104,11 @@ def backward(self, loss): for t in range(500): y_pred = x.mm(w1).clamp(min=0).mm(w2) loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale - print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) - print('Iter {} scaled loss: {}'.format(t, loss.data[0])) - print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + print("Iter {} loss scale: {}".format(t, loss_scaler.loss_scale)) + print("Iter {} scaled loss: {}".format(t, loss.data[0])) + print( + "Iter {} unscaled loss: {}".format(t, loss.data[0] / loss_scaler.loss_scale) + ) # Run backprop optimizer.zero_grad() @@ -120,12 +120,11 @@ def backward(self, loss): # If no overflow, unscale grad and update as usual if not has_overflow: for param in parameters: - param.grad.data.mul_(1. / loss_scaler.loss_scale) + param.grad.data.mul_(1.0 / loss_scaler.loss_scale) optimizer.step() # Otherwise, don't do anything -- ie, skip iteration else: - print('OVERFLOW!') + print("OVERFLOW!") # Update loss scale for next iteration loss_scaler.update_scale(has_overflow) - diff --git a/model.py b/model.py index fe8a6ae..cc86ed3 100644 --- a/model.py +++ b/model.py @@ -9,17 +9,23 @@ from gradient_reversal import grad_reverse from speaker_classifier import speaker_classifier + class LocationLayer(nn.Module): - def __init__(self, attention_n_filters, attention_kernel_size, - attention_dim): + def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): super(LocationLayer, self).__init__() padding = int((attention_kernel_size - 1) / 2) - self.location_conv = ConvNorm(2, attention_n_filters, - kernel_size=attention_kernel_size, - padding=padding, bias=False, stride=1, - dilation=1) - self.location_dense = LinearNorm(attention_n_filters, attention_dim, - bias=False, w_init_gain='tanh') + self.location_conv = ConvNorm( + 2, + attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, + bias=False, + stride=1, + dilation=1, + ) + self.location_dense = LinearNorm( + attention_n_filters, attention_dim, bias=False, w_init_gain="tanh" + ) def forward(self, attention_weights_cat): processed_attention = self.location_conv(attention_weights_cat) @@ -29,21 +35,28 @@ def forward(self, attention_weights_cat): class Attention(nn.Module): - def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, - attention_location_n_filters, attention_location_kernel_size): + def __init__( + self, + attention_rnn_dim, + embedding_dim, + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ): super(Attention, self).__init__() - self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, - bias=False, w_init_gain='tanh') - self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, - w_init_gain='tanh') + self.query_layer = LinearNorm( + attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = LinearNorm( + embedding_dim, attention_dim, bias=False, w_init_gain="tanh" + ) self.v = LinearNorm(attention_dim, 1, bias=False) - self.location_layer = LocationLayer(attention_location_n_filters, - attention_location_kernel_size, - attention_dim) + self.location_layer = LocationLayer( + 
attention_location_n_filters, attention_location_kernel_size, attention_dim + ) self.score_mask_value = -float("inf") - def get_alignment_energies(self, query, processed_memory, - attention_weights_cat): + def get_alignment_energies(self, query, processed_memory, attention_weights_cat): """ PARAMS ------ @@ -58,14 +71,21 @@ def get_alignment_energies(self, query, processed_memory, processed_query = self.query_layer(query.unsqueeze(1)) processed_attention_weights = self.location_layer(attention_weights_cat) - energies = self.v(torch.tanh( - processed_query + processed_attention_weights + processed_memory)) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) energies = energies.squeeze(-1) return energies - def forward(self, attention_hidden_state, memory, processed_memory, - attention_weights_cat, mask): + def forward( + self, + attention_hidden_state, + memory, + processed_memory, + attention_weights_cat, + mask, + ): """ PARAMS ------ @@ -76,7 +96,8 @@ def forward(self, attention_hidden_state, memory, processed_memory, mask: binary mask for padded data """ alignment = self.get_alignment_energies( - attention_hidden_state, processed_memory, attention_weights_cat) + attention_hidden_state, processed_memory, attention_weights_cat + ) if mask is not None: alignment.data.masked_fill_(mask, self.score_mask_value) @@ -93,8 +114,11 @@ def __init__(self, in_dim, sizes): super(Prenet, self).__init__() in_sizes = [in_dim] + sizes[:-1] self.layers = nn.ModuleList( - [LinearNorm(in_size, out_size, bias=False) - for (in_size, out_size) in zip(in_sizes, sizes)]) + [ + LinearNorm(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes) + ] + ) def forward(self, x): for linear in self.layers: @@ -113,32 +137,49 @@ def __init__(self, hparams): self.convolutions.append( nn.Sequential( - ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, - kernel_size=hparams.postnet_kernel_size, stride=1, - padding=int((hparams.postnet_kernel_size - 1) / 2), - dilation=1, w_init_gain='tanh'), - nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ConvNorm( + hparams.n_mel_channels, + hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, + stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, + w_init_gain="tanh", + ), + nn.BatchNorm1d(hparams.postnet_embedding_dim), + ) ) for i in range(1, hparams.postnet_n_convolutions - 1): self.convolutions.append( nn.Sequential( - ConvNorm(hparams.postnet_embedding_dim, - hparams.postnet_embedding_dim, - kernel_size=hparams.postnet_kernel_size, stride=1, - padding=int((hparams.postnet_kernel_size - 1) / 2), - dilation=1, w_init_gain='tanh'), - nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ConvNorm( + hparams.postnet_embedding_dim, + hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, + stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, + w_init_gain="tanh", + ), + nn.BatchNorm1d(hparams.postnet_embedding_dim), + ) ) self.convolutions.append( nn.Sequential( - ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, - kernel_size=hparams.postnet_kernel_size, stride=1, - padding=int((hparams.postnet_kernel_size - 1) / 2), - dilation=1, w_init_gain='linear'), - nn.BatchNorm1d(hparams.n_mel_channels)) + ConvNorm( + hparams.postnet_embedding_dim, + hparams.n_mel_channels, + kernel_size=hparams.postnet_kernel_size, + stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, + 
w_init_gain="linear", + ), + nn.BatchNorm1d(hparams.n_mel_channels), ) + ) def forward(self, x): for i in range(len(self.convolutions) - 1): @@ -153,24 +194,34 @@ class Encoder(nn.Module): - Three 1-d convolution banks - Bidirectional LSTM """ + def __init__(self, hparams): super(Encoder, self).__init__() convolutions = [] for _ in range(hparams.encoder_n_convolutions): conv_layer = nn.Sequential( - ConvNorm(hparams.encoder_embedding_dim, - hparams.encoder_embedding_dim, - kernel_size=hparams.encoder_kernel_size, stride=1, - padding=int((hparams.encoder_kernel_size - 1) / 2), - dilation=1, w_init_gain='relu'), - nn.BatchNorm1d(hparams.encoder_embedding_dim)) + ConvNorm( + hparams.encoder_embedding_dim, + hparams.encoder_embedding_dim, + kernel_size=hparams.encoder_kernel_size, + stride=1, + padding=int((hparams.encoder_kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ), + nn.BatchNorm1d(hparams.encoder_embedding_dim), + ) convolutions.append(conv_layer) self.convolutions = nn.ModuleList(convolutions) - self.lstm = nn.LSTM(hparams.encoder_embedding_dim, - int(hparams.encoder_embedding_dim / 2), 1, - batch_first=True, bidirectional=True) + self.lstm = nn.LSTM( + hparams.encoder_embedding_dim, + int(hparams.encoder_embedding_dim / 2), + 1, + batch_first=True, + bidirectional=True, + ) def forward(self, x, input_lengths): for conv in self.convolutions: @@ -180,14 +231,12 @@ def forward(self, x, input_lengths): # pytorch tensor are not reversible, hence the conversion input_lengths = input_lengths.cpu().numpy() - x = nn.utils.rnn.pack_padded_sequence( - x, input_lengths, batch_first=True) + x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True) self.lstm.flatten_parameters() outputs, _ = self.lstm(x) - outputs, _ = nn.utils.rnn.pad_packed_sequence( - outputs, batch_first=True) + outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) return outputs @@ -216,43 +265,57 @@ def __init__(self, hparams): self.gate_threshold = hparams.gate_threshold self.p_attention_dropout = hparams.p_attention_dropout self.p_decoder_dropout = hparams.p_decoder_dropout - + self.speaker_embedding_dim = hparams.encoder_embedding_dim self.lang_embedding_dim = hparams.lang_embedding_dim self.residual_encoding_dim = hparams.residual_encoding_dim - + self.mcn = hparams.mcn - + self.prenet = Prenet( - hparams.n_mel_channels * hparams.n_frames_per_step \ - + hparams.lang_embedding_dim + hparams.speaker_embedding_dim \ + hparams.n_mel_channels * hparams.n_frames_per_step + + hparams.lang_embedding_dim + + hparams.speaker_embedding_dim + hparams.residual_encoding_dim, - [hparams.prenet_dim, hparams.prenet_dim]) + [hparams.prenet_dim, hparams.prenet_dim], + ) self.attention_rnn = nn.LSTMCell( hparams.prenet_dim + hparams.encoder_embedding_dim, - hparams.attention_rnn_dim) + hparams.attention_rnn_dim, + ) self.attention_layer = Attention( - hparams.attention_rnn_dim, hparams.encoder_embedding_dim, - hparams.attention_dim, hparams.attention_location_n_filters, - hparams.attention_location_kernel_size) + hparams.attention_rnn_dim, + hparams.encoder_embedding_dim, + hparams.attention_dim, + hparams.attention_location_n_filters, + hparams.attention_location_kernel_size, + ) self.decoder_rnn = nn.LSTMCell( hparams.attention_rnn_dim + hparams.encoder_embedding_dim, - hparams.decoder_rnn_dim, 1) + hparams.decoder_rnn_dim, + 1, + ) self.linear_projection = LinearNorm( hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, - hparams.n_mel_channels * hparams.n_frames_per_step) + 
hparams.n_mel_channels * hparams.n_frames_per_step, + ) self.gate_layer = LinearNorm( - hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, - bias=True, w_init_gain='sigmoid') - - self.speaker_embeds = nn.Embedding(hparams.n_speakers, hparams.speaker_embedding_dim) - + hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, + 1, + bias=True, + w_init_gain="sigmoid", + ) + + self.speaker_embeds = nn.Embedding( + hparams.n_speakers, hparams.speaker_embedding_dim + ) + self.lang_embeds = nn.Embedding(hparams.n_langs, hparams.lang_embedding_dim) self.residual_encoder = residual_encoders(hparams) @@ -268,8 +331,9 @@ def get_go_frame(self, memory): decoder_input: all zeros frames """ B = memory.size(0) - decoder_input = Variable(memory.data.new( - B, self.n_mel_channels * self.n_frames_per_step).zero_()) + decoder_input = Variable( + memory.data.new(B, self.n_mel_channels * self.n_frames_per_step).zero_() + ) return decoder_input def initialize_decoder_states(self, memory, mask): @@ -284,22 +348,21 @@ def initialize_decoder_states(self, memory, mask): B = memory.size(0) MAX_TIME = memory.size(1) - self.attention_hidden = Variable(memory.data.new( - B, self.attention_rnn_dim).zero_()) - self.attention_cell = Variable(memory.data.new( - B, self.attention_rnn_dim).zero_()) + self.attention_hidden = Variable( + memory.data.new(B, self.attention_rnn_dim).zero_() + ) + self.attention_cell = Variable( + memory.data.new(B, self.attention_rnn_dim).zero_() + ) - self.decoder_hidden = Variable(memory.data.new( - B, self.decoder_rnn_dim).zero_()) - self.decoder_cell = Variable(memory.data.new( - B, self.decoder_rnn_dim).zero_()) + self.decoder_hidden = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_()) + self.decoder_cell = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_()) - self.attention_weights = Variable(memory.data.new( - B, MAX_TIME).zero_()) - self.attention_weights_cum = Variable(memory.data.new( - B, MAX_TIME).zero_()) - self.attention_context = Variable(memory.data.new( - B, self.encoder_embedding_dim).zero_()) + self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_()) + self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_()) + self.attention_context = Variable( + memory.data.new(B, self.encoder_embedding_dim).zero_() + ) self.memory = memory self.processed_memory = self.attention_layer.memory_layer(memory) @@ -320,7 +383,9 @@ def parse_decoder_inputs(self, decoder_inputs): decoder_inputs = decoder_inputs.transpose(1, 2) decoder_inputs = decoder_inputs.reshape( decoder_inputs.size(0), - int(decoder_inputs.size(1)/self.n_frames_per_step), -1) + int(decoder_inputs.size(1) / self.n_frames_per_step), + -1, + ) # (B, T_out/n_frames_ps, n_mel_channels*n_frames_ps) -> (T_out/n_frames_ps, B, n_mel_channels*n_frames_ps) decoder_inputs = decoder_inputs.transpose(0, 1) return decoder_inputs @@ -344,12 +409,11 @@ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): # (T_out, B) -> (B, T_out) gate_outputs = torch.stack(gate_outputs).transpose(0, 1) gate_outputs = gate_outputs.contiguous() - gate_outputs = gate_outputs.repeat_interleave(self.n_frames_per_step,1) + gate_outputs = gate_outputs.repeat_interleave(self.n_frames_per_step, 1) # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() # decouple frames per step - mel_outputs = mel_outputs.view( - mel_outputs.size(0), -1, self.n_mel_channels) + mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, 
self.n_mel_channels) # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) mel_outputs = mel_outputs.transpose(1, 2) @@ -369,47 +433,76 @@ def decode(self, decoder_input): """ cell_input = torch.cat((decoder_input, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( - cell_input, (self.attention_hidden, self.attention_cell)) + cell_input, (self.attention_hidden, self.attention_cell) + ) self.attention_hidden = F.dropout( - self.attention_hidden, self.p_attention_dropout, self.training) + self.attention_hidden, self.p_attention_dropout, self.training + ) attention_weights_cat = torch.cat( - (self.attention_weights.unsqueeze(1), - self.attention_weights_cum.unsqueeze(1)), dim=1) + ( + self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1), + ), + dim=1, + ) self.attention_context, self.attention_weights = self.attention_layer( - self.attention_hidden, self.memory, self.processed_memory, - attention_weights_cat, self.mask) + self.attention_hidden, + self.memory, + self.processed_memory, + attention_weights_cat, + self.mask, + ) self.attention_weights_cum += self.attention_weights - decoder_input = torch.cat( - (self.attention_hidden, self.attention_context), -1) + decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( - decoder_input, (self.decoder_hidden, self.decoder_cell)) + decoder_input, (self.decoder_hidden, self.decoder_cell) + ) self.decoder_hidden = F.dropout( - self.decoder_hidden, self.p_decoder_dropout, self.training) + self.decoder_hidden, self.p_decoder_dropout, self.training + ) decoder_hidden_attention_context = torch.cat( - (self.decoder_hidden, self.attention_context), dim=1) - decoder_output = self.linear_projection( - decoder_hidden_attention_context) + (self.decoder_hidden, self.attention_context), dim=1 + ) + decoder_output = self.linear_projection(decoder_hidden_attention_context) gate_prediction = self.gate_layer(decoder_hidden_attention_context) return decoder_output, gate_prediction, self.attention_weights - def concat_speaker_lang_res_embeds(self, decoder_inputs, speaker, lang, residual_encoding) : - ''' + def concat_speaker_lang_res_embeds( + self, decoder_inputs, speaker, lang, residual_encoding + ): + """ decoder_inputs = [max_audio_len, batch_size, n_mel_filters*n_frames_per_step] residual_encoding = [batch_size*mcn, residual_encoding_dim] speaker = speaker number lang = language number RETURNS: [max_audio_len, batch_size*mcn, n_mel_filters*n_frames_per_step+speaker_embed+lang_embed+residual_embed] - ''' - speaker_embeds, lang_embeds = self.speaker_embeds(speaker), self.lang_embeds(lang) - decoder_inputs = decoder_inputs.transpose(0,1).transpose(1,2).repeat(self.mcn, 1, 1) - speaker_embeds, lang_embeds = speaker_embeds.repeat(self.mcn, 1), lang_embeds.repeat(self.mcn, 1) - to_append = torch.cat([speaker_embeds, lang_embeds, residual_encoding], dim=-1) - to_append = to_append.repeat(decoder_inputs.shape[2],1,1).transpose(0,1).transpose(1,2) - return torch.cat([decoder_inputs, to_append], dim=1).transpose(2,1).transpose(1,0) + """ + speaker_embeds, lang_embeds = ( + self.speaker_embeds(speaker), + self.lang_embeds(lang), + ) + decoder_inputs = ( + decoder_inputs.transpose(0, 1).transpose(1, 2).repeat(self.mcn, 1, 1) + ) + speaker_embeds, lang_embeds = ( + speaker_embeds.repeat(self.mcn, 1), + lang_embeds.repeat(self.mcn, 1), + ) + to_append = torch.cat([speaker_embeds, lang_embeds, residual_encoding], dim=-1) + 
to_append = ( + to_append.repeat(decoder_inputs.shape[2], 1, 1) + .transpose(0, 1) + .transpose(1, 2) + ) + return ( + torch.cat([decoder_inputs, to_append], dim=1) + .transpose(2, 1) + .transpose(1, 0) + ) def forward(self, memory, decoder_inputs, memory_lengths, speaker, lang): """ Decoder forward pass for training @@ -428,33 +521,42 @@ def forward(self, memory, decoder_inputs, memory_lengths, speaker, lang): decoder_input = self.get_go_frame(memory).unsqueeze(0) decoder_inputs = self.parse_decoder_inputs(decoder_inputs) - decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) #[seq_len, batch_size, n_mel_channels*n_frames_per_step] - - flattened_decoder_inputs = decoder_inputs.transpose(0,1) - flattened_decoder_inputs = flattened_decoder_inputs.reshape(flattened_decoder_inputs.size(0), -1, int(decoder_inputs.size(2)/self.n_frames_per_step) ).transpose(0,1) + decoder_inputs = torch.cat( + (decoder_input, decoder_inputs), dim=0 + ) # [seq_len, batch_size, n_mel_channels*n_frames_per_step] + + flattened_decoder_inputs = decoder_inputs.transpose(0, 1) + flattened_decoder_inputs = flattened_decoder_inputs.reshape( + flattened_decoder_inputs.size(0), + -1, + int(decoder_inputs.size(2) / self.n_frames_per_step), + ).transpose(0, 1) residual_encoding = self.residual_encoder(flattened_decoder_inputs) - - decoder_inputs = self.concat_speaker_lang_res_embeds(decoder_inputs, speaker, lang, residual_encoding) + + decoder_inputs = self.concat_speaker_lang_res_embeds( + decoder_inputs, speaker, lang, residual_encoding + ) decoder_inputs = self.prenet(decoder_inputs) - memory = memory.repeat(self.mcn,1,1) + memory = memory.repeat(self.mcn, 1, 1) memory_lengths = memory_lengths.repeat(self.mcn) - decoder_input = decoder_input.repeat(1,self.mcn,1) - + decoder_input = decoder_input.repeat(1, self.mcn, 1) + self.initialize_decoder_states( - memory, mask=~get_mask_from_lengths(memory_lengths)) + memory, mask=~get_mask_from_lengths(memory_lengths) + ) mel_outputs, gate_outputs, alignments = [], [], [] while len(mel_outputs) < decoder_inputs.size(0) - 1: decoder_input = decoder_inputs[len(mel_outputs)] - mel_output, gate_output, attention_weights = self.decode( - decoder_input) + mel_output, gate_output, attention_weights = self.decode(decoder_input) mel_outputs += [mel_output.squeeze(1)] gate_outputs += [gate_output.squeeze(1)] alignments += [attention_weights] mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( - mel_outputs, gate_outputs, alignments) + mel_outputs, gate_outputs, alignments + ) return mel_outputs, gate_outputs, alignments @@ -472,18 +574,22 @@ def inference(self, memory, speaker, lang): gate_outputs: gate outputs from the decoder alignments: sequence of attention weights from the decoder """ - self.mcn=1 + self.mcn = 1 decoder_input = self.get_go_frame(memory) self.initialize_decoder_states(memory, mask=None) mel_outputs, gate_outputs, alignments = [], [], [] while True: - - residual_encoding = self.residual_encoder.infer(speaker) #torch.zeros((1,32), device='cuda:0') #batch_sizeXresidual_encoding_dim - + + residual_encoding = self.residual_encoder.infer( + speaker + ) # torch.zeros((1,32), device='cuda:0') #batch_sizeXresidual_encoding_dim + decoder_input = decoder_input.unsqueeze(1) - decoder_input = self.concat_speaker_lang_res_embeds(decoder_input, speaker, lang, residual_encoding).squeeze(1) + decoder_input = self.concat_speaker_lang_res_embeds( + decoder_input, speaker, lang, residual_encoding + ).squeeze(1) decoder_input = self.prenet(decoder_input) 
mel_output, gate_output, alignment = self.decode(decoder_input) @@ -494,14 +600,15 @@ def inference(self, memory, speaker, lang): if torch.sigmoid(gate_output.data) > self.gate_threshold: break - elif len(mel_outputs)*self.n_frames_per_step >= self.max_decoder_steps: + elif len(mel_outputs) * self.n_frames_per_step >= self.max_decoder_steps: print("Warning! Reached max decoder steps") break decoder_input = mel_output mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( - mel_outputs, gate_outputs, alignments) + mel_outputs, gate_outputs, alignments + ) return mel_outputs, gate_outputs, alignments @@ -513,8 +620,7 @@ def __init__(self, hparams): self.fp16_run = hparams.fp16_run self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step - self.embedding = nn.Embedding( - hparams.n_symbols, hparams.symbols_embedding_dim) + self.embedding = nn.Embedding(hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std self.embedding.weight.data.uniform_(-val, val) @@ -525,7 +631,15 @@ def __init__(self, hparams): self.speaker_classifier = speaker_classifier(hparams) def parse_batch(self, batch): - text_padded, input_lengths, mel_padded, gate_padded, output_lengths, speaker, lang = batch + ( + text_padded, + input_lengths, + mel_padded, + gate_padded, + output_lengths, + speaker, + lang, + ) = batch text_padded = to_gpu(text_padded).long() input_lengths = to_gpu(input_lengths).long() max_len = torch.max(input_lengths.data).item() @@ -535,19 +649,39 @@ def parse_batch(self, batch): speaker = to_gpu(speaker).long() lang = to_gpu(lang).long() return ( - (text_padded, input_lengths, mel_padded, max_len, output_lengths, speaker, lang), - (mel_padded.repeat(self.mcn,1,1), gate_padded.repeat(self.mcn,1))) + ( + text_padded, + input_lengths, + mel_padded, + max_len, + output_lengths, + speaker, + lang, + ), + (mel_padded.repeat(self.mcn, 1, 1), gate_padded.repeat(self.mcn, 1)), + ) def parse_output(self, outputs, output_lengths=None): if self.mask_padding and output_lengths is not None: mask = ~get_mask_from_lengths(output_lengths) mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) - if mask.size(2)%self.n_frames_per_step != 0 : - to_append = torch.ones( mask.size(0), mask.size(1), (self.n_frames_per_step-mask.size(2)%self.n_frames_per_step) ).bool().to(mask.device) + if mask.size(2) % self.n_frames_per_step != 0: + to_append = ( + torch.ones( + mask.size(0), + mask.size(1), + ( + self.n_frames_per_step + - mask.size(2) % self.n_frames_per_step + ), + ) + .bool() + .to(mask.device) + ) mask = torch.cat([mask, to_append], dim=-1) mask = mask.permute(1, 0, 2) - mask = mask.repeat(self.mcn,1,1) + mask = mask.repeat(self.mcn, 1, 1) outputs[0].data.masked_fill_(mask, 0.0) outputs[1].data.masked_fill_(mask, 0.0) outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies @@ -561,33 +695,49 @@ def forward(self, inputs): embedded_inputs = self.embedding(text_inputs).transpose(1, 2) encoder_outputs = self.encoder(embedded_inputs, text_lengths) - + encdr_out_for_spkr_clsfir = grad_reverse(encoder_outputs) - spkr_clsfir_logits = self.speaker_classifier(encdr_out_for_spkr_clsfir, text_lengths) + spkr_clsfir_logits = self.speaker_classifier( + encdr_out_for_spkr_clsfir, text_lengths + ) mel_outputs, gate_outputs, alignments = self.decoder( - encoder_outputs, mels, memory_lengths=text_lengths, speaker=speaker, lang=lang) - + encoder_outputs, + 
mels, + memory_lengths=text_lengths, + speaker=speaker, + lang=lang, + ) + mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet return self.parse_output( - [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, spkr_clsfir_logits], - output_lengths) + [ + mel_outputs, + mel_outputs_postnet, + gate_outputs, + alignments, + spkr_clsfir_logits, + ], + output_lengths, + ) def inference(self, inputs, speaker, language): - self.mcn=1 + self.mcn = 1 embedded_inputs = self.embedding(inputs).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) - + mel_outputs, gate_outputs, alignments = self.decoder.inference( - encoder_outputs, speaker, language) + encoder_outputs, speaker, language + ) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet outputs = self.parse_output( - [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments] + ) return outputs diff --git a/multiproc.py b/multiproc.py index 060ff93..b22d917 100644 --- a/multiproc.py +++ b/multiproc.py @@ -5,17 +5,16 @@ argslist = list(sys.argv)[1:] num_gpus = torch.cuda.device_count() -argslist.append('--n_gpus={}'.format(num_gpus)) +argslist.append("--n_gpus={}".format(num_gpus)) workers = [] job_id = time.strftime("%Y_%m_%d-%H%M%S") argslist.append("--group_name=group_{}".format(job_id)) for i in range(num_gpus): - argslist.append('--rank={}'.format(i)) - stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), - "w") + argslist.append("--rank={}".format(i)) + stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), "w") print(argslist) - p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) + p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout) workers.append(p) argslist = argslist[:-1] diff --git a/plotting_utils.py b/plotting_utils.py index ca7e168..8d754ce 100644 --- a/plotting_utils.py +++ b/plotting_utils.py @@ -1,4 +1,5 @@ import matplotlib + matplotlib.use("Agg") import matplotlib.pylab as plt import numpy as np @@ -6,21 +7,20 @@ def save_figure_to_numpy(fig): # save it to a numpy array. 
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) return data def plot_alignment_to_numpy(alignment, info=None): fig, ax = plt.subplots(figsize=(6, 4)) - im = ax.imshow(alignment, aspect='auto', origin='lower', - interpolation='none') + im = ax.imshow(alignment, aspect="auto", origin="lower", interpolation="none") fig.colorbar(im, ax=ax) - xlabel = 'Decoder timestep' + xlabel = "Decoder timestep" if info is not None: - xlabel += '\n\n' + info + xlabel += "\n\n" + info plt.xlabel(xlabel) - plt.ylabel('Encoder timestep') + plt.ylabel("Encoder timestep") plt.tight_layout() fig.canvas.draw() @@ -31,8 +31,7 @@ def plot_alignment_to_numpy(alignment, info=None): def plot_spectrogram_to_numpy(spectrogram): fig, ax = plt.subplots(figsize=(12, 3)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") plt.colorbar(im, ax=ax) plt.xlabel("Frames") plt.ylabel("Channels") @@ -46,10 +45,24 @@ def plot_spectrogram_to_numpy(spectrogram): def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): fig, ax = plt.subplots(figsize=(12, 3)) - ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, - color='green', marker='+', s=1, label='target') - ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, - color='red', marker='.', s=1, label='predicted') + ax.scatter( + range(len(gate_targets)), + gate_targets, + alpha=0.5, + color="green", + marker="+", + s=1, + label="target", + ) + ax.scatter( + range(len(gate_outputs)), + gate_outputs, + alpha=0.5, + color="red", + marker=".", + s=1, + label="predicted", + ) plt.xlabel("Frames (Green target, Red predicted)") plt.ylabel("Gate State") diff --git a/residual_encoder.py b/residual_encoder.py index b168bd0..8b4e4d7 100644 --- a/residual_encoder.py +++ b/residual_encoder.py @@ -2,139 +2,177 @@ import torch.nn as nn -class residual_encoder(nn.Module) : - ''' +class residual_encoder(nn.Module): + """ Neural network that can be used to parametrize q(z_{l}|x) and q(z_{o}|x) - ''' + """ + def __init__(self, hparams, log_min_std_dev=-1): super(residual_encoder, self).__init__() self.conv1 = nn.Conv1d(hparams.n_mel_channels, 512, 3, 1) - self.bi_lstm = nn.LSTM(512, 256, 2, bidirectional = True, batch_first=True) + self.bi_lstm = nn.LSTM(512, 256, 2, bidirectional=True, batch_first=True) self.linear = nn.Linear(512, 32) - self.residual_encoding_dim = int(hparams.residual_encoding_dim/2) - self.register_buffer('min_std_dev', torch.exp(torch.tensor([log_min_std_dev]).float()) ) + self.residual_encoding_dim = int(hparams.residual_encoding_dim / 2) + self.register_buffer( + "min_std_dev", torch.exp(torch.tensor([log_min_std_dev]).float()) + ) def forward(self, x): - ''' + """ x.shape = [batch_size, seq_len, n_mel_channels] returns single sample from the distribution q(z_{l}|X) or q(z_{o}|X) of size [batch_size, 16] - ''' - x = self.conv1(x.transpose(2,1)).transpose(2,1) - output, (_,_) = self.bi_lstm(x) + """ + x = self.conv1(x.transpose(2, 1)).transpose(2, 1) + output, (_, _) = self.bi_lstm(x) seq_len = output.shape[1] - output = output.sum(dim=1)/seq_len + output = output.sum(dim=1) / seq_len x = self.linear(output) - mean, log_variance = x[:,:self.residual_encoding_dim], x[:,self.residual_encoding_dim:] + mean, log_variance = ( + x[:, : self.residual_encoding_dim], + x[:, 
self.residual_encoding_dim :], + ) std_dev = torch.sqrt(torch.exp(log_variance)) - return torch.distributions.normal.Normal(mean,torch.max(std_dev, self.min_std_dev)) #Check here if scale_tril=log_variance ? - -class continuous_given_discrete(nn.Module) : - ''' + return torch.distributions.normal.Normal( + mean, torch.max(std_dev, self.min_std_dev) + ) # Check here if scale_tril=log_variance ? + + +class continuous_given_discrete(nn.Module): + """ Class for p(z_{o}|y_{o}) and p(z_{l}|y_{l}) n_disc :- number of discrete possible values for y_{o/l} distrib_lis[i] :- is the distribution over z , p(z|y=i). Total n_disc distribuitons std_init :- standard deviation is initialized to e^(std_init). And clamped to be >= e^(2*std_init) distribs :- p(z|y) for all y. Can be used to sample n_disc z's {1 from each of the n_disc distribution of prev line}, simultaneously. - ''' - def __init__(self, hparams, n_disc, std_init=-1) : + """ + + def __init__(self, hparams, n_disc, std_init=-1): super(continuous_given_discrete, self).__init__() self.n_disc = n_disc - self.residual_encoding_dim = int(hparams.residual_encoding_dim/2) + self.residual_encoding_dim = int(hparams.residual_encoding_dim / 2) self.std_init = torch.tensor([std_init]).float() - - self.cont_given_disc_mus = nn.Parameter(torch.randn((self.n_disc, self.residual_encoding_dim))) - self.cont_given_disc_sigmas = nn.Parameter(torch.exp(self.std_init)*torch.ones((self.n_disc, self.residual_encoding_dim))) - - self.distrib_lis = self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=True) - self.distribs = self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=False) - - def make_normal_distribs(self, mus, sigmas, make_lis = False) : - if make_lis : - return [torch.distributions.normal.Normal(mus[i], sigmas[i]) for i in range(mus.shape[0])] + + self.cont_given_disc_mus = nn.Parameter( + torch.randn((self.n_disc, self.residual_encoding_dim)) + ) + self.cont_given_disc_sigmas = nn.Parameter( + torch.exp(self.std_init) + * torch.ones((self.n_disc, self.residual_encoding_dim)) + ) + + self.distrib_lis = self.make_normal_distribs( + self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=True + ) + self.distribs = self.make_normal_distribs( + self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=False + ) + + def make_normal_distribs(self, mus, sigmas, make_lis=False): + if make_lis: + return [ + torch.distributions.normal.Normal(mus[i], sigmas[i]) + for i in range(mus.shape[0]) + ] return torch.distributions.normal.Normal(mus, sigmas) - - def after_optim_step(self) : - sigmas = self.cont_given_disc_sigmas.data - sigmas = sigmas.clamp(float(torch.exp(torch.tensor(2.)*self.std_init).data)) + + def after_optim_step(self): + sigmas = self.cont_given_disc_sigmas.data + sigmas = sigmas.clamp(float(torch.exp(torch.tensor(2.0) * self.std_init).data)) self.cont_given_disc_sigmas.data = sigmas self.cont_given_disc_mus.detach_() self.cont_given_disc_sigmas.detach_() - - self.cont_given_disc_mus.requires_grad=True - self.cont_given_disc_sigmas.requires_grad=True - - self.distrib_lis = self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=True) - self.distribs = self.make_normal_distribs(self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=False) - - -class residual_encoders(nn.Module) : - def __init__(self, hparams) : + + self.cont_given_disc_mus.requires_grad = True + self.cont_given_disc_sigmas.requires_grad = True + + 
self.distrib_lis = self.make_normal_distribs( + self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=True + ) + self.distribs = self.make_normal_distribs( + self.cont_given_disc_mus, self.cont_given_disc_sigmas, make_lis=False + ) + + +class residual_encoders(nn.Module): + def __init__(self, hparams): super(residual_encoders, self).__init__() - - #Variational Posteriors - self.q_zl_given_X = residual_encoder(hparams, -2) #q(z_{l}|X) - self.q_zo_given_X = residual_encoder(hparams, -4) #q(z_{o}|X) + + # Variational Posteriors + self.q_zl_given_X = residual_encoder(hparams, -2) # q(z_{l}|X) + self.q_zo_given_X = residual_encoder(hparams, -4) # q(z_{o}|X) self.q_zl_given_X_at_x = None self.q_zo_given_X_at_x = None - + self.residual_encoding_dim = hparams.residual_encoding_dim self.mcn = hparams.mcn - - #Priors + + # Priors self.y_l_probs = nn.Parameter(torch.ones((hparams.dim_yl))) self.y_l_probs.requires_grad = False self.y_l = torch.distributions.categorical.Categorical(self.y_l_probs) self.p_zo_given_yo = continuous_given_discrete(hparams, hparams.dim_yo, -2) self.p_zl_given_yl = continuous_given_discrete(hparams, hparams.dim_yl, -1) - + self.q_yl_given_X = None - - def calc_q_tilde(self, sampled_zl) : - ''' + + def calc_q_tilde(self, sampled_zl): + """ Caculates approximation to q_yl_given_X using monte carlo sampling, for each element in a batch. Supposed to be recalculated for each batch. - ''' + """ K = self.p_zl_given_yl.n_disc - sampled_zl = sampled_zl.repeat_interleave(K,-2) - sampled_zl = sampled_zl.reshape(sampled_zl.shape[0], -1, K, sampled_zl.shape[-1]) - probs = self.p_zl_given_yl.distribs.log_prob(sampled_zl).exp() #[mcn, batch_size, K, residual_encoding_dim/2] -# print(probs.shape, probs.dtype) - p_zl_givn_yl = probs.double().prod(dim=-1) #[mcn, batch_size, K] - ans = p_zl_givn_yl*self.y_l.probs - #if (ans<1e-12).byte().any() : - # ans = torch.pow(10.0,(torch.log10(torch.min(ans))-12))*ans - normalization_consts = ans.sum(dim=-1) #[mcn, batch_size] -# print(normalization_consts) - ans = ans.permute(2,0,1)/(normalization_consts) #+1e-12) #[K, mcn, batch_size] - self.q_yl_given_X = ans.sum(dim=1)/self.mcn #[K, batch_size] - - def forward(self, x) : - ''' + sampled_zl = sampled_zl.repeat_interleave(K, -2) + sampled_zl = sampled_zl.reshape( + sampled_zl.shape[0], -1, K, sampled_zl.shape[-1] + ) + probs = self.p_zl_given_yl.distribs.log_prob( + sampled_zl + ).exp() # [mcn, batch_size, K, residual_encoding_dim/2] + # print(probs.shape, probs.dtype) + p_zl_givn_yl = probs.double().prod(dim=-1) # [mcn, batch_size, K] + ans = p_zl_givn_yl * self.y_l.probs + # if (ans<1e-12).byte().any() : + # ans = torch.pow(10.0,(torch.log10(torch.min(ans))-12))*ans + normalization_consts = ans.sum(dim=-1) # [mcn, batch_size] + # print(normalization_consts) + ans = ans.permute(2, 0, 1) / ( + normalization_consts + ) # +1e-12) #[K, mcn, batch_size] + self.q_yl_given_X = ans.sum(dim=1) / self.mcn # [K, batch_size] + + def forward(self, x): + """ x.shape = [seq_len, batch_size, n_mel_channels] z_l.shape, z_o.shape == [hparams.mcn, batch_size, hparams.residual_encoding_dim/2] returns concatenation of z_{o} and z_{l} sampled from respective distributions - ''' - x = x.transpose(1,0) - self.q_zl_given_X_at_x, self.q_zo_given_X_at_x = self.q_zl_given_X(x), self.q_zo_given_X(x) - z_l, z_o = self.q_zl_given_X_at_x.rsample((self.mcn, )), self.q_zo_given_X_at_x.rsample((self.mcn,)) #[mcn, batch_size, residual_encoding_dim/2] - self.calc_q_tilde(z_l) - return torch.cat([z_l,z_o], dim=-1).reshape(-1, 
self.residual_encoding_dim) - - def redefine_y_l(self) : - '''To be called whenever model is sent to new device''' + """ + x = x.transpose(1, 0) + self.q_zl_given_X_at_x, self.q_zo_given_X_at_x = ( + self.q_zl_given_X(x), + self.q_zo_given_X(x), + ) + z_l, z_o = ( + self.q_zl_given_X_at_x.rsample((self.mcn,)), + self.q_zo_given_X_at_x.rsample((self.mcn,)), + ) # [mcn, batch_size, residual_encoding_dim/2] + self.calc_q_tilde(z_l) + return torch.cat([z_l, z_o], dim=-1).reshape(-1, self.residual_encoding_dim) + + def redefine_y_l(self): + """To be called whenever model is sent to new device""" self.y_l = torch.distributions.categorical.Categorical(self.y_l_probs) - def after_optim_step(self) : - ''' + def after_optim_step(self): + """ The parameters :- cont_given_disc_mus, sigmas, are altered, so their distributions need to be made again. - ''' + """ self.p_zo_given_yo.after_optim_step() self.p_zl_given_yl.after_optim_step() - - def infer(self, y_o_idx, y_l_idx=None) : - if y_l_idx is None : + + def infer(self, y_o_idx, y_l_idx=None): + if y_l_idx is None: y_l_idx = self.y_l.sample() z_l = self.p_zl_given_yl.distrib_lis[y_l_idx].sample() z_o = self.p_zo_given_yo.distrib_lis[y_o_idx].sample() - return torch.cat([z_l,z_o], dim=-1).unsqueeze(dim=0) + return torch.cat([z_l, z_o], dim=-1).unsqueeze(dim=0) diff --git a/speaker_classifier.py b/speaker_classifier.py index f294018..04cecef 100644 --- a/speaker_classifier.py +++ b/speaker_classifier.py @@ -1,29 +1,33 @@ import torch.nn as nn import torch + class speaker_classifier(nn.Module): - - def __init__(self, hparams) : + def __init__(self, hparams): super(speaker_classifier, self).__init__() - self.model = nn.Sequential(nn.Linear(hparams.encoder_embedding_dim, hparams.hidden_sc_dim), - nn.Linear(hparams.hidden_sc_dim, hparams.n_speakers)) - - def parse_outputs(self, out, text_lengths) : - mask = torch.arange(out.size(1), device=out.device).expand(out.size(0), out.size(1)) < text_lengths.unsqueeze(1) - out = out.permute(2,0,1) - out = out*mask - out = out.permute(1,2,0) + self.model = nn.Sequential( + nn.Linear(hparams.encoder_embedding_dim, hparams.hidden_sc_dim), + nn.Linear(hparams.hidden_sc_dim, hparams.n_speakers), + ) + + def parse_outputs(self, out, text_lengths): + mask = torch.arange(out.size(1), device=out.device).expand( + out.size(0), out.size(1) + ) < text_lengths.unsqueeze(1) + out = out.permute(2, 0, 1) + out = out * mask + out = out.permute(1, 2, 0) return out - def forward(self, encoder_outputs, text_lengths) : - ''' + def forward(self, encoder_outputs, text_lengths): + """ input :- encoder_outputs = [batch_size, seq_len, encoder_embedding_size] text_lengths = [batch_size] output :- log probabilities of speaker classification = [batch_size, seq_len, n_speakers] - ''' - out = self.model(encoder_outputs) - out = self.parse_outputs( out, text_lengths ) + """ + out = self.model(encoder_outputs) + out = self.parse_outputs(out, text_lengths) return out diff --git a/stft.py b/stft.py index edfc44a..699f1d1 100644 --- a/stft.py +++ b/stft.py @@ -41,8 +41,10 @@ class STFT(torch.nn.Module): """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - def __init__(self, filter_length=800, hop_length=200, win_length=800, - window='hann'): + + def __init__( + self, filter_length=800, hop_length=200, win_length=800, window="hann" + ): super(STFT, self).__init__() self.filter_length = filter_length self.hop_length = hop_length @@ -53,15 +55,17 @@ def __init__(self, filter_length=800, hop_length=200, win_length=800, 
fourier_basis = np.fft.fft(np.eye(self.filter_length)) cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), - np.imag(fourier_basis[:cutoff, :])]) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) if window is not None: - assert(filter_length >= win_length) + assert filter_length >= win_length # get window and zero center pad it to filter_length fft_window = get_window(window, win_length, fftbins=True) fft_window = pad_center(fft_window, filter_length) @@ -71,8 +75,8 @@ def __init__(self, filter_length=800, hop_length=200, win_length=800, forward_basis *= fft_window inverse_basis *= fft_window - self.register_buffer('forward_basis', forward_basis.float()) - self.register_buffer('inverse_basis', inverse_basis.float()) + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) def transform(self, input_data): num_batches = input_data.size(0) @@ -85,53 +89,64 @@ def transform(self, input_data): input_data = F.pad( input_data.unsqueeze(1), (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), - mode='reflect') + mode="reflect", + ) input_data = input_data.squeeze(1) forward_transform = F.conv1d( input_data, Variable(self.forward_basis, requires_grad=False), stride=self.hop_length, - padding=0) + padding=0, + ) cutoff = int((self.filter_length / 2) + 1) real_part = forward_transform[:, :cutoff, :] imag_part = forward_transform[:, cutoff:, :] - magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable( - torch.atan2(imag_part.data, real_part.data)) + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) return magnitude, phase def inverse(self, magnitude, phase): recombine_magnitude_phase = torch.cat( - [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) inverse_transform = F.conv_transpose1d( recombine_magnitude_phase, Variable(self.inverse_basis, requires_grad=False), stride=self.hop_length, - padding=0) + padding=0, + ) if self.window is not None: window_sum = window_sumsquare( - self.window, magnitude.size(-1), hop_length=self.hop_length, - win_length=self.win_length, n_fft=self.filter_length, - dtype=np.float32) + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) # remove modulation effects approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0]) + np.where(window_sum > tiny(window_sum))[0] + ) window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False) + torch.from_numpy(window_sum), requires_grad=False + ) window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] # scale by hop ratio inverse_transform *= float(self.filter_length) / self.hop_length - inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] - inverse_transform = 
inverse_transform[:, :, :-int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] return inverse_transform diff --git a/text/__init__.py b/text/__init__.py index 02ecf0e..2e6362e 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -9,11 +9,11 @@ _id_to_symbol = {i: s for i, s in enumerate(symbols)} # Regular expression matching text enclosed in curly braces: -_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") def text_to_sequence(text, cleaner_names): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." @@ -24,51 +24,51 @@ def text_to_sequence(text, cleaner_names): Returns: List of integers corresponding to the symbols in the text - ''' - sequence = [] + """ + sequence = [] - # Check for curly braces and treat their contents as ARPAbet: - while len(text): - m = _curly_re.match(text) - if not m: - sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) - break - sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) - sequence += _arpabet_to_sequence(m.group(2)) - text = m.group(3) + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) - return sequence + return sequence def sequence_to_text(sequence): - '''Converts a sequence of IDs back to a string''' - result = '' - for symbol_id in sequence: - if symbol_id in _id_to_symbol: - s = _id_to_symbol[symbol_id] - # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] - result += s - return result.replace('}{', ' ') + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] + result += s + return result.replace("}{", " ") def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception('Unknown cleaner: %s' % name) - text = cleaner(text) - return text + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text def _symbols_to_sequence(symbols): - return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s is not '_' and s is not '~' + return s in _symbol_to_id and s is not "_" and s is not "~" diff --git a/text/cleaners.py b/text/cleaners.py index 9957a86..34527be 100644 --- a/text/cleaners.py +++ b/text/cleaners.py @@ -1,6 +1,6 @@ """ from https://github.com/keithito/tacotron """ -''' 
+""" Cleaners are transformations that run over the input text at both training and eval time. Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" @@ -10,84 +10,87 @@ the Unidecode library (https://pypi.python.org/pypi/Unidecode) 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update the symbols in symbols.py to match your data). -''' +""" import re from unidecode import unidecode from .numbers import normalize_numbers from indictrans import Transliterator -trn = Transliterator(source='hin', target='eng', build_lookup=True) +trn = Transliterator(source="hin", target="eng", build_lookup=True) # Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') +_whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text def expand_numbers(text): - return normalize_numbers(text) + return normalize_numbers(text) def lowercase(text): - return text.lower() + return text.lower() def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) + return re.sub(_whitespace_re, " ", text) def convert_to_ascii(text): - return unidecode(text) + return unidecode(text) def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' - text = lowercase(text) - text = collapse_whitespace(text) - return text + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text def transliteration_cleaners(text): - '''Pipeline for non-English text that transliterates to ASCII.''' - text = convert_to_ascii(text) - text = trn.transform(text) - text = lowercase(text) - text = collapse_whitespace(text) - return text + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = trn.transform(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text def english_cleaners(text): - '''Pipeline for English text, including number and abbreviation expansion.''' - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_numbers(text) - text = expand_abbreviations(text) - text = collapse_whitespace(text) - return text + """Pipeline for English text, including number and abbreviation 
expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/text/cmudict.py b/text/cmudict.py index 62bfef7..f1885ed 100644 --- a/text/cmudict.py +++ b/text/cmudict.py @@ -4,62 +4,137 @@ valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', - 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', - 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', - 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', - 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', - 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', - 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", ] _valid_symbol_set = set(valid_symbols) class CMUDict: - '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' - def __init__(self, file_or_path, keep_ambiguous=True): - if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: - entries = _parse_cmudict(f) - else: - entries = _parse_cmudict(file_or_path) - if not keep_ambiguous: - entries = {word: pron for word, pron in entries.items() if len(pron) == 1} - self._entries = entries - - - def __len__(self): - return len(self._entries) + """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries - def lookup(self, word): - '''Returns list of ARPAbet pronunciations of the given word.''' - return self._entries.get(word.upper()) + def __len__(self): + return len(self._entries) + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) -_alt_re = re.compile(r'\([0-9]+\)') +_alt_re = re.compile(r"\([0-9]+\)") def _parse_cmudict(file): - cmudict = {} - for line in file: - if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) - pronunciation = _get_pronunciation(parts[1]) - if pronunciation: - if word in cmudict: - cmudict[word].append(pronunciation) - else: - cmudict[word] = [pronunciation] - return cmudict + cmudict = {} + for line in file: + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict def _get_pronunciation(s): - parts = s.strip().split(' ') - for part in parts: - if part not in _valid_symbol_set: - return None - return ' '.join(parts) + parts = s.strip().split(" ") + for part in parts: + if part not in _valid_symbol_set: + return None + return " ".join(parts) diff --git a/text/numbers.py b/text/numbers.py index 0d5f7fa..5c30252 100644 --- a/text/numbers.py +++ b/text/numbers.py @@ -5,67 +5,69 @@ _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') -_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") def _expand_dollars(m): - match = m.group(1) - parts = match.split('.') - if len(parts) > 2: - return match + ' dollars' # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) - elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) - else: - return 'zero dollars' + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return 
match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) + return _inflect.number_to_words(m.group(0)) def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return 'two thousand' - elif num > 2000 and num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) - elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") else: - return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') - else: - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num, andword="") def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/text/symbols.py b/text/symbols.py index 5dc6246..0c30157 100644 --- a/text/symbols.py +++ b/text/symbols.py @@ -1,23 +1,29 @@ """ from https://github.com/keithito/tacotron """ -''' +""" Defines the set of symbols used in text input to the model. -The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ from text import cmudict -_pad = '_' -_punctuation = '!\'(),.:;? ' -_special = '-' -_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' -start_tok = '*' -end_tok = '`' -#Loop for adding Hindi characters. -#for i in range(2304, 2432) : +_pad = "_" +_punctuation = "!'(),.:;? " +_special = "-" +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +start_tok = "*" +end_tok = "`" +# Loop for adding Hindi characters. 
+# for i in range(2304, 2432) : # _letters+=chr(i) # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): -_arpabet = ['@' + s for s in cmudict.valid_symbols] +_arpabet = ["@" + s for s in cmudict.valid_symbols] # Export all symbols: -symbols = [_pad, start_tok, end_tok] + list(_special) + list(_punctuation) + list(_letters) + _arpabet +symbols = ( + [_pad, start_tok, end_tok] + + list(_special) + + list(_punctuation) + + list(_letters) + + _arpabet +) diff --git a/train.py b/train.py index 2d2e4ae..9e67fc3 100644 --- a/train.py +++ b/train.py @@ -33,8 +33,12 @@ def init_distributed(hparams, n_gpus, rank, group_name): # Initialize distributed communication dist.init_process_group( - backend=hparams.dist_backend, init_method=hparams.dist_url, - world_size=n_gpus, rank=rank, group_name=group_name) + backend=hparams.dist_backend, + init_method=hparams.dist_url, + world_size=n_gpus, + rank=rank, + group_name=group_name, + ) print("Done initializing distributed") @@ -52,10 +56,16 @@ def prepare_dataloaders(hparams): train_sampler = None shuffle = True - train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle, - sampler=train_sampler, - batch_size=hparams.batch_size, pin_memory=False, - drop_last=True, collate_fn=collate_fn) + train_loader = DataLoader( + trainset, + num_workers=1, + shuffle=shuffle, + sampler=train_sampler, + batch_size=hparams.batch_size, + pin_memory=False, + drop_last=True, + collate_fn=collate_fn, + ) return train_loader, valset, collate_fn @@ -69,18 +79,20 @@ def prepare_directories_and_logger(output_directory, log_directory, rank): logger = None return logger -def anneal_lr(learning_rate, hparams) : - if learning_rate>=hparams.learning_rate : + +def anneal_lr(learning_rate, hparams): + if learning_rate >= hparams.learning_rate: return learning_rate - else : - return learning_rate+(hparams.learning_rate/hparams.anneal) + else: + return learning_rate + (hparams.learning_rate / hparams.anneal) + def load_model(hparams): model = Tacotron2(hparams).cuda() model.decoder.residual_encoder.after_optim_step() model.decoder.residual_encoder.redefine_y_l() if hparams.fp16_run: - model.decoder.attention_layer.score_mask_value = finfo('float16').min + model.decoder.attention_layer.score_mask_value = finfo("float16").min if hparams.distributed_run: model = apply_gradient_allreduce(model) @@ -91,60 +103,83 @@ def load_model(hparams): def warm_start_model(checkpoint_path, model, ignore_layers): assert os.path.isfile(checkpoint_path) print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) - checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') - model_dict = checkpoint_dict['state_dict'] - + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + model_dict = checkpoint_dict["state_dict"] + if len(ignore_layers) > 0: common_parts_dict = {} - for k , v in model_dict.items() : + for k, v in model_dict.items(): should_ignore = False - for elem in ignore_layers : - if k.count(elem)>0: - should_ignore=True + for elem in ignore_layers: + if k.count(elem) > 0: + should_ignore = True break - if not should_ignore : + if not should_ignore: common_parts_dict[k] = v - + dummy_dict = model.state_dict() dummy_dict.update(common_parts_dict) model_dict = dummy_dict - + model.load_state_dict(model_dict) - + return model def load_checkpoint(checkpoint_path, model, optimizer): assert os.path.isfile(checkpoint_path) print("Loading checkpoint '{}'".format(checkpoint_path)) - checkpoint_dict = 
torch.load(checkpoint_path, map_location='cpu') - model.load_state_dict(checkpoint_dict['state_dict']) - optimizer.load_state_dict(checkpoint_dict['optimizer']) - learning_rate = checkpoint_dict['learning_rate'] - iteration = checkpoint_dict['iteration'] - print("Loaded checkpoint '{}' from iteration {}" .format( - checkpoint_path, iteration)) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + model.load_state_dict(checkpoint_dict["state_dict"]) + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + learning_rate = checkpoint_dict["learning_rate"] + iteration = checkpoint_dict["iteration"] + print("Loaded checkpoint '{}' from iteration {}".format(checkpoint_path, iteration)) return model, optimizer, learning_rate, iteration def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): - print("Saving model and optimizer state at iteration {} to {}".format( - iteration, filepath)) - torch.save({'iteration': iteration, - 'state_dict': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'learning_rate': learning_rate}, filepath) - - -def validate(model, criterion, valset, iteration, batch_size, n_gpus, - collate_fn, logger, distributed_run, rank): + print( + "Saving model and optimizer state at iteration {} to {}".format( + iteration, filepath + ) + ) + torch.save( + { + "iteration": iteration, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + filepath, + ) + + +def validate( + model, + criterion, + valset, + iteration, + batch_size, + n_gpus, + collate_fn, + logger, + distributed_run, + rank, +): """Handles all the validation scoring and printing""" model.eval() with torch.no_grad(): val_sampler = DistributedSampler(valset) if distributed_run else None - val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, - shuffle=False, batch_size=batch_size, - pin_memory=False, collate_fn=collate_fn) + val_loader = DataLoader( + valset, + sampler=val_sampler, + num_workers=1, + shuffle=False, + batch_size=batch_size, + pin_memory=False, + collate_fn=collate_fn, + ) val_loss = 0.0 for i, batch in enumerate(val_loader): @@ -164,8 +199,16 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, logger.log_validation(val_loss, model, y, y_pred, iteration) -def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, - rank, group_name, hparams): +def train( + output_directory, + log_directory, + checkpoint_path, + warm_start, + n_gpus, + rank, + group_name, + hparams, +): """Training and validation logging results to tensorboard and stdout Params @@ -185,21 +228,21 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, model = load_model(hparams) learning_rate = 0 if hparams.anneal else hparams.learning_rate - optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, - weight_decay=hparams.weight_decay) + optimizer = torch.optim.Adam( + model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay + ) if hparams.fp16_run: from apex import amp - model, optimizer = amp.initialize( - model, optimizer, opt_level='O2') + + model, optimizer = amp.initialize(model, optimizer, opt_level="O2") if hparams.distributed_run: model = apply_gradient_allreduce(model) criterion = Tacotron2Loss(hparams).cuda() - logger = prepare_directories_and_logger( - output_directory, log_directory, rank) + logger = prepare_directories_and_logger(output_directory, log_directory, rank) train_loader, valset, collate_fn = 
prepare_dataloaders(hparams) @@ -208,11 +251,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, epoch_offset = 0 if checkpoint_path is not None: if warm_start: - model = warm_start_model( - checkpoint_path, model, hparams.ignore_layers) + model = warm_start_model(checkpoint_path, model, hparams.ignore_layers) else: model, optimizer, _learning_rate, iteration = load_checkpoint( - checkpoint_path, model, optimizer) + checkpoint_path, model, optimizer + ) if hparams.use_saved_learning_rate: learning_rate = _learning_rate iteration += 1 # next iteration is iteration + 1 @@ -224,77 +267,113 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, for epoch in range(epoch_offset, hparams.epochs): print("Epoch: {}".format(epoch)) for i, batch in enumerate(train_loader): - start = time.perf_counter() -# with torch.autograd.detect_anomaly() : - for param_group in optimizer.param_groups: - param_group['lr'] = learning_rate -# print("Learning Rate= ", learning_rate) - - model.zero_grad() - x, y = model.parse_batch(batch) - y_pred = model(x) - - loss = criterion(y_pred, y, model.decoder.residual_encoder, x[5]) - if hparams.distributed_run: - reduced_loss = reduce_tensor(loss.data, n_gpus).item() - else: - reduced_loss = loss.item() - if hparams.fp16_run: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() - - if hparams.fp16_run: - grad_norm = torch.nn.utils.clip_grad_norm_( - amp.master_params(optimizer), hparams.grad_clip_thresh) - is_overflow = math.isnan(grad_norm) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), hparams.grad_clip_thresh) - - optimizer.step() - learning_rate = anneal_lr(learning_rate, hparams) - model.decoder.residual_encoder.after_optim_step() - - if not is_overflow and rank == 0: - duration = time.perf_counter() - start - print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( - iteration, reduced_loss, grad_norm, duration)) - logger.log_training( - reduced_loss, grad_norm, learning_rate, duration, iteration) - - if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0): - validate(model, criterion, valset, iteration, - hparams.batch_size, n_gpus, collate_fn, logger, - hparams.distributed_run, rank) - if rank == 0: - checkpoint_path = os.path.join( - output_directory, "checkpoint_{}".format(iteration)) - save_checkpoint(model, optimizer, learning_rate, iteration, - checkpoint_path) - - iteration += 1 - - -if __name__ == '__main__': + start = time.perf_counter() + # with torch.autograd.detect_anomaly() : + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + # print("Learning Rate= ", learning_rate) + + model.zero_grad() + x, y = model.parse_batch(batch) + y_pred = model(x) + + loss = criterion(y_pred, y, model.decoder.residual_encoder, x[5]) + if hparams.distributed_run: + reduced_loss = reduce_tensor(loss.data, n_gpus).item() + else: + reduced_loss = loss.item() + if hparams.fp16_run: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if hparams.fp16_run: + grad_norm = torch.nn.utils.clip_grad_norm_( + amp.master_params(optimizer), hparams.grad_clip_thresh + ) + is_overflow = math.isnan(grad_norm) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), hparams.grad_clip_thresh + ) + + optimizer.step() + learning_rate = anneal_lr(learning_rate, hparams) + model.decoder.residual_encoder.after_optim_step() + + if 
not is_overflow and rank == 0: + duration = time.perf_counter() - start + print( + "Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( + iteration, reduced_loss, grad_norm, duration + ) + ) + logger.log_training( + reduced_loss, grad_norm, learning_rate, duration, iteration + ) + + if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0): + validate( + model, + criterion, + valset, + iteration, + hparams.batch_size, + n_gpus, + collate_fn, + logger, + hparams.distributed_run, + rank, + ) + if rank == 0: + checkpoint_path = os.path.join( + output_directory, "checkpoint_{}".format(iteration) + ) + save_checkpoint( + model, optimizer, learning_rate, iteration, checkpoint_path + ) + + iteration += 1 + + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-o', '--output_directory', type=str, - help='directory to save checkpoints') - parser.add_argument('-l', '--log_directory', type=str, - help='directory to save tensorboard logs') - parser.add_argument('-c', '--checkpoint_path', type=str, default=None, - required=False, help='checkpoint path') - parser.add_argument('--warm_start', action='store_true', - help='load model weights only, ignore specified layers') - parser.add_argument('--n_gpus', type=int, default=1, - required=False, help='number of gpus') - parser.add_argument('--rank', type=int, default=0, - required=False, help='rank of current gpu') - parser.add_argument('--group_name', type=str, default='group_name', - required=False, help='Distributed group name') - parser.add_argument('--hparams', type=str, - required=False, help='comma separated name=value pairs') + parser.add_argument( + "-o", "--output_directory", type=str, help="directory to save checkpoints" + ) + parser.add_argument( + "-l", "--log_directory", type=str, help="directory to save tensorboard logs" + ) + parser.add_argument( + "-c", + "--checkpoint_path", + type=str, + default=None, + required=False, + help="checkpoint path", + ) + parser.add_argument( + "--warm_start", + action="store_true", + help="load model weights only, ignore specified layers", + ) + parser.add_argument( + "--n_gpus", type=int, default=1, required=False, help="number of gpus" + ) + parser.add_argument( + "--rank", type=int, default=0, required=False, help="rank of current gpu" + ) + parser.add_argument( + "--group_name", + type=str, + default="group_name", + required=False, + help="Distributed group name", + ) + parser.add_argument( + "--hparams", type=str, required=False, help="comma separated name=value pairs" + ) args = parser.parse_args() hparams = create_hparams(args.hparams) @@ -308,5 +387,13 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, print("cuDNN Enabled:", hparams.cudnn_enabled) print("cuDNN Benchmark:", hparams.cudnn_benchmark) - train(args.output_directory, args.log_directory, args.checkpoint_path, - args.warm_start, args.n_gpus, args.rank, args.group_name, hparams) + train( + args.output_directory, + args.log_directory, + args.checkpoint_path, + args.warm_start, + args.n_gpus, + args.rank, + args.group_name, + hparams, + ) diff --git a/utils.py b/utils.py index 4fe4624..4200e46 100644 --- a/utils.py +++ b/utils.py @@ -12,21 +12,23 @@ def get_mask_from_lengths(lengths): return mask -def load_wav_to_torch(full_path, use_librosa=False, audio_dtype='np.int16', final_sr=22050): - if audio_dtype!='np.int16' : - audio, sampling_rate = sf.read(full_path, dtype='int16') +def load_wav_to_torch( + full_path, use_librosa=False, audio_dtype="np.int16", 
final_sr=22050 +): + if audio_dtype != "np.int16": + audio, sampling_rate = sf.read(full_path, dtype="int16") audio = librosa.resample(audio.astype(np.float32), sampling_rate, final_sr) data = audio.astype(np.int16) - else : - if use_librosa : + else: + if use_librosa: data, final_sr = librosa.load(full_path, sr=final_sr) - else : + else: final_sr, data = read(full_path) return torch.FloatTensor(data.astype(np.float32)), final_sr def load_filepaths_and_text(filename, split="|"): - with open(filename, encoding='utf-8') as f: + with open(filename, encoding="utf-8") as f: filepaths_and_text = [line.strip().split(split) for line in f] return filepaths_and_text
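
As a minimal usage sketch of the text-processing helpers reformatted above -- purely illustrative, assuming this repository's `text` package and its third-party dependencies (unidecode, inflect, indictrans) are installed and importable -- the cleaned, ARPAbet-aware sequence API is intended to be called like this:

    from text import text_to_sequence, sequence_to_text
    from text.numbers import normalize_numbers

    # english_cleaners converts to ASCII, lowercases, and expands numbers and
    # abbreviations; normalize_numbers alone handles the number-expansion step.
    print(normalize_numbers("$2.50 on Jan 1st"))
    # -> e.g. "two dollars, fifty cents on Jan first"

    # Curly braces mark ARPAbet spans, per the text_to_sequence docstring.
    ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
    print(sequence_to_text(ids))

Exact output strings depend on the installed inflect and unidecode versions; the sketch only illustrates how these functions fit together, and the black reformatting in this patch does not change their behaviour. The formatting itself should be reproducible by running black with its default settings (line length 88) over the repository, subject to the black version in use.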