Radtts 1.13 plus (NVIDIA#5457) (NVIDIA#5471)
* [TTS] Fixing RADTTS training - removing view buffer and fixing accuracy issue (NVIDIA#5358)
* Fixing RADTTS training - removing view buffer and fixing accuracy issue
* Fixes for Torchscript/Triton
* Added autocast to radtts UT
* using cuda() for training example

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Boris Fomitchev <borisfom@users.noreply.github.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
4 people authored and Hainan Xu committed Nov 29, 2022
1 parent dfbe987 commit 08bc700
Showing 5 changed files with 47 additions and 55 deletions.
4 changes: 2 additions & 2 deletions examples/tts/radtts.py
@@ -61,14 +61,14 @@ def prepare_model_weights(model, unfreeze_modules):
 def main(cfg):
     trainer = pl.Trainer(**cfg.trainer)
     exp_manager(trainer, cfg.get('exp_manager', None))
-    model = RadTTSModel(cfg=cfg.model, trainer=trainer)
+    model = RadTTSModel(cfg=cfg.model, trainer=trainer).cuda()
     if cfg.model.load_from_checkpoint:
         model.maybe_init_from_pretrained_checkpoint(cfg=cfg.model)
         prepare_model_weights(model, cfg.model.trainerConfig.unfreeze_modules)
     lr_logger = pl.callbacks.LearningRateMonitor()
     epoch_time_logger = LogEpochTimeCallback()
     trainer.callbacks.extend([lr_logger, epoch_time_logger])
-    trainer.fit(model)
+    trainer.fit(model.cuda())


 if __name__ == '__main__':
50 changes: 15 additions & 35 deletions nemo/collections/tts/modules/common.py
@@ -118,39 +118,29 @@ def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral"):
             lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse')
         self.bilstm.flatten_parameters()

     @torch.jit.export
     def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> Tuple[Tensor, Tensor]:
         seq = nn.utils.rnn.pack_padded_sequence(
             context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
         )
-        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
-            self.bilstm.flatten_parameters()
-        if hasattr(self.bilstm, 'forward'):
-            ret, _ = self.bilstm.forward(seq)
-        else:
-            ret, _ = self.bilstm.forward_1(seq)
-        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
+        return self.lstm_sequence(seq)

     @torch.jit.export
     def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
         if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
             self.bilstm.flatten_parameters()
-        if hasattr(self.bilstm, 'forward'):
-            ret, _ = self.bilstm.forward(seq)
-        elif hasattr(self.bilstm, 'forward_1'):
-            ret, _ = self.bilstm.forward_1(seq)
+        ret, _ = self.bilstm(seq)
         return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

-    @torch.jit.export
-    def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
         context, lens_sorted, unsort_ids = sort_tensor(context, lens)
-        seq = nn.utils.rnn.pack_padded_sequence(
-            context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
-        )
-        return self.lstm_sequence(seq)[0][unsort_ids]
+        dtype = context.dtype
+        # this is only needed for Torchscript to run in Triton
+        # (https://github.com/pytorch/pytorch/issues/89241)
+        with torch.cuda.amp.autocast(enabled=False):
+            ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True)
+        return ret[0].to(dtype=dtype)[unsort_ids]
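Note on the hunk above: the float32 round-trip inside the autocast-disabled region is the crux of both the Triton workaround and the accuracy fix. A minimal runnable sketch of the same pattern, with a toy nn.LSTM standing in for self.bilstm and illustrative shapes:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(8, 4, batch_first=True, bidirectional=True)

    def lstm_fp32(x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        # pin the LSTM to float32 no matter what autocast region encloses the call
        with torch.cuda.amp.autocast(enabled=False):
            out, _ = lstm(x.to(dtype=torch.float32))
        return out.to(dtype=dtype)  # hand the caller's dtype back

    print(lstm_fp32(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])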


-class ConvLSTMLinear(BiLSTM):
+class ConvLSTMLinear(nn.Module):
     def __init__(
         self,
         in_dim=None,
@@ -162,7 +152,8 @@ def __init__(
         use_partial_padding=False,
         norm_fn=None,
     ):
-        super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
+        super(ConvLSTMLinear, self).__init__()
+        self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
         self.convolutions = nn.ModuleList()

         if n_layers > 0:
@@ -193,27 +184,16 @@ def __init__(
         if out_dim is not None:
             self.dense = nn.Linear(n_channels, out_dim)

-    def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
         mask = get_mask_from_lengths_and_val(lens, context)
         mask = mask.to(dtype=context.dtype).unsqueeze(1)
         for conv in self.convolutions:
             context = self.dropout(F.relu(conv(context, mask)))

         context = context.transpose(1, 2)
-        seq = torch.nn.utils.rnn.pack_padded_sequence(
-            context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
-        )
-        return seq
-
-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
-        context, lens, unsort_ids = sort_tensor(context, lens)
-        seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
-        context, _ = self.lstm_sequence(seq)
-        context = context[unsort_ids]
-
+        # Apply Bidirectional LSTM
+        context = self.bilstm(context, lens)
         if self.dense is not None:
             context = self.dense(context).permute(0, 2, 1)

         return context


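For reference, the sort / pack / run / unsort dance that BiLSTM.forward now owns (and that ConvLSTMLinear used to spell out) looks roughly like this as standalone code. sort_by_len is a hypothetical stand-in for NeMo's sort_tensor helper, and the sizes are illustrative:

    import torch
    import torch.nn as nn

    def sort_by_len(x: torch.Tensor, lens: torch.Tensor):
        lens_sorted, ids = torch.sort(lens, descending=True)  # pack_padded_sequence wants descending lengths
        return x[ids], lens_sorted, torch.argsort(ids)        # argsort(ids) is the inverse permutation

    x = torch.randn(3, 7, 8)  # (batch, time, features)
    lens = torch.tensor([5, 7, 3])
    x_s, lens_s, unsort_ids = sort_by_len(x, lens)

    packed = nn.utils.rnn.pack_padded_sequence(x_s, lens_s.cpu(), batch_first=True, enforce_sorted=True)
    out, _ = nn.LSTM(8, 4, batch_first=True, bidirectional=True)(packed)
    padded, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
    print(padded[unsort_ids].shape)  # torch.Size([3, 7, 8]), original batch order restored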
8 changes: 3 additions & 5 deletions nemo/collections/tts/modules/radtts.py
@@ -345,9 +345,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg):
             context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)

         unfolded_out_lens = out_lens // self.n_group_size
-        context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor(
-            context_w_spkvec.transpose(1, 2), unfolded_out_lens
-        )
+        context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens)
         context_w_spkvec = context_lstm_padded_output.transpose(1, 2)

         if not self.context_lstm_w_f0_and_energy:
@@ -772,8 +770,8 @@ def input_example(self, max_batch=1, max_dim=256):
"""
par = next(self.parameters())
sz = (max_batch, max_dim)
inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
inputs = {
'text': inp,
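A note on the new ranges: torch.randint(low, high, size) samples from the half-open interval [low, high), so the dummy token ids now sit in [16, 32) instead of [0, 16) (presumably steering clear of special low-valued ids), and the dummy lengths land in [max_dim // 4, max_dim // 2) instead of running all the way up to max_dim. A sketch of just the sampling, with the device handling omitted:

    import torch

    max_batch, max_dim = 1, 256
    inp = torch.randint(16, 32, (max_batch, max_dim), dtype=torch.int64)             # ids in [16, 32)
    lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), dtype=torch.int)  # lengths in [64, 128)
    speaker = torch.randint(0, 1, (max_batch,), dtype=torch.int64)                   # always speaker 0
    print(inp.shape, int(lens[0]), int(speaker[0]))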
36 changes: 24 additions & 12 deletions nemo/utils/export_utils.py
@@ -15,7 +15,7 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Callable, Dict, List, Optional, Type
+from typing import Callable, Dict, Optional, Type

 import onnx
 import torch
@@ -135,14 +135,16 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):


 def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01):
-    ts_model = torch.jit.load(output)
-
     all_good = True
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
         output_example = model.forward(*input_list, **input_dict)
-
-        all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
+        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
+        with torch.cuda.amp.autocast(enabled=False):
+            ts_model = torch.jit.load(output)
+            all_good = all_good and run_ts_and_compare(
+                ts_model, input_list, input_dict, output_example, check_tolerance
+            )
     status = "SUCCESS" if all_good else "FAIL"
     logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
     return all_good
@@ -183,9 +185,15 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):
         if torch.is_tensor(expected):
             tout = out.to('cpu')
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
-                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch, and that may be OK
+                this_good = False
+            if not this_good:
+                logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
                 all_good = False
     return all_good
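The new try/except exists because torch.allclose raises a RuntimeError rather than returning False when the two tensors cannot be broadcast together, and a size mismatch between the eager and TorchScript outputs may be acceptable here. A quick illustration:

    import torch

    a, b = torch.zeros(2, 3), torch.zeros(2, 4)
    try:
        torch.allclose(a, b)
    except RuntimeError as err:  # shapes (2, 3) and (2, 4) are not broadcastable
        print("allclose raised:", err)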


@@ -199,9 +207,15 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
         if torch.is_tensor(expected):
             tout = torch.from_numpy(out)
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
-                all_good = False
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch, and that may be OK
+                this_good = False
+            if not this_good:
                 logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                all_good = False
     return all_good


@@ -387,8 +401,7 @@ def replace_modules(


 def script_module(m: nn.Module):
-    m1 = torch.jit.script(m)
-    return m1
+    return torch.jit.script(m)


 default_replacements = {
@@ -399,7 +412,6 @@ def script_module(m: nn.Module):

 script_replacements = {
     "BiLSTM": script_module,
-    "ConvLSTMLinear": script_module,
 }


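With "ConvLSTMLinear" dropped from the table, only BiLSTM is pre-converted with torch.jit.script during export; after this commit ConvLSTMLinear is an ordinary nn.Module that merely contains a scripted BiLSTM. A toy example of why scripting (rather than tracing) matters for such modules; script preserves data-dependent control flow that a trace would freeze to one path:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 1:  # a shape-dependent branch
                return x * 2
            return x

    scripted = torch.jit.script(Toy())
    print(scripted(torch.ones(2, 3))[0, 0].item())  # 2.0 (branch taken)
    print(scripted(torch.ones(1, 3))[0, 0].item())  # 1.0 (branch skipped)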
4 changes: 3 additions & 1 deletion tests/collections/tts/test_tts_exportables.py
@@ -15,6 +15,7 @@
 import tempfile

 import pytest
+import torch
 from omegaconf import OmegaConf

 from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
@@ -79,4 +80,5 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model):
         model = radtts_model.cuda()
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'rad.ts')
-            model.export(output=filename, verbose=True, check_trace=True)
+            with torch.cuda.amp.autocast(enabled=True):
+                model.export(output=filename, verbose=True, check_trace=True)
