# Create TTS VITS model and TorchScript model

- 이 노트북은 VITS 모델을 생성 및  TorchScript 형태로 변화하고 추론 테스트까지를 합니다.

## 1.  Setup environment
사용하는 패키지는 import 시점에 다시 재로딩 합니다.

In [1]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append(os.path.abspath("./vits"))
# sys.path.append('./vits')
for i in sys.path:
    print(i)

/home/ec2-user/SageMaker/.cs/conda/envs/vits-conda-py310/lib/python310.zip
/home/ec2-user/SageMaker/.cs/conda/envs/vits-conda-py310/lib/python3.10
/home/ec2-user/SageMaker/.cs/conda/envs/vits-conda-py310/lib/python3.10/lib-dynload

/home/ec2-user/SageMaker/.cs/conda/envs/vits-conda-py310/lib/python3.10/site-packages
/home/ec2-user/SageMaker/lab/00-trition-tts-vits/02-tts-vits-docker-trition/vits


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

from scipy.io.wavfile import write




## 2. Load LJ Speech Model
- [Important] Downlaod models to vits/models/
    - https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2

In [16]:
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

hps = utils.get_hparams_from_file("vits/configs/ljs_base.json")

In [4]:
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model).cuda()
_ = net_g.eval()

_ = utils.load_checkpoint("vits/models/pretrained_ljs.pth", net_g, None)

## 3. Create input text, inference and play it

In [5]:
stn_tst = get_text("Claude is AI for all of us. Whether you're brainstorming alone or building with a team of thousands, Claude is here to help", hps)
# stn_tst = get_text("VITS is Awesome!", hps)
with torch.no_grad():
    x_tst = stn_tst.cuda().unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
    result = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)
    audio = result[0][0,0].data.cpu().float().numpy()    
    # audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()

ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))

## 4. Analyze output

In [6]:
len(result), result[0].shape, result[1].shape, result[2].shape

(4,
 torch.Size([1, 1, 171264]),
 torch.Size([1, 1, 669, 257]),
 torch.Size([1, 1, 669]))

In [7]:
result[0][0,0].data.cpu().float().numpy()    

array([-0.00093722, -0.00093612, -0.00129202, ..., -0.00119417,
       -0.00102345, -0.00069893], dtype=float32)

## 5. Prepare Trace model

### Create Wrapped Model

In [18]:
class WrappedModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=0.667,
        length_scale=1.0,
        noise_scale_w=0.8,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
        )[0]

In [19]:
model = WrappedModel(net_g)
model = model.eval()


### Create Dummpy inputs 

In [20]:
x = torch.randint(low=1, high=10, size=(10,), dtype=torch.int64)
x = x.unsqueeze(0).cuda()

x_length = torch.tensor([x.shape[1]], dtype=torch.int64).cuda()

print("x: \n", x)
print("x: ", x.shape)
print("x_length shape: ", x_length.shape)
print("x_length: ", x_length)


noise_scale = torch.tensor([1], dtype=torch.float32).cuda()
length_scale = torch.tensor([1], dtype=torch.float32).cuda()
noise_scale_w = torch.tensor([1], dtype=torch.float32).cuda()

x: 
 tensor([[2, 4, 6, 1, 3, 1, 8, 7, 8, 1]], device='cuda:0')
x:  torch.Size([1, 10])
x_length shape:  torch.Size([1])
x_length:  tensor([10], device='cuda:0')


## 6. Create Trace Model 

In [22]:
dummy_input = (x, x_length, noise_scale, length_scale, noise_scale_w)
trace_model = torch.jit.trace(model, dummy_input)


The values for attribute 'shape' do not match: torch.Size([1, 1, 6912]) != torch.Size([1, 1, 8448]).
  _check_trace(


### Save model and test loading

In [23]:
workspace_folder = "workspace"
os.makedirs(workspace_folder, exist_ok=True)
model_path = "trace_vits.pt"
model_path = os.path.join(workspace_folder, model_path)

trace_model.save(model_path)
print("#### Load Test ####")    
loaded_m = torch.jit.load(model_path)    
# print(loaded_m.code)    


#### Load Test ####


### Inference on the loaded model

In [26]:
result = loaded_m(x_tst, x_tst_lengths, noise_scale, length_scale, noise_scale_w)
print("Result : \n", result.shape)
result = result[0,0].data.cpu().float().numpy()
print("Output value : \n", result.shape)

Result : 
 torch.Size([1, 1, 177408])
Output value : 
 (177408,)


### Play output array data

In [27]:
ipd.display(ipd.Audio(result, rate=hps.data.sampling_rate, normalize=False))