In [2]:
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler  import OneCycleLR

from src.configs import FastSpeechConfig
from src.configs import TrainConfig

from src.util import BufferDataset
from src.util import download_buffer
from src.util import collate_fn_tensor
from src.util import seed_everything

from src.wandb_writer import WanDBWriter
from src.model import FastSpeech
from src.loss import FastSpeechLoss
from src.train import train

In [None]:
# Seed the RNGs through the project's `seed_everything` helper before anything
# random happens, then build the default model / training configurations.
SEED = 0xbebebe
seed_everything(SEED)

model_config = FastSpeechConfig()
train_config = TrainConfig()

In [None]:
# Fetch the preprocessed training buffer and load it from disk.
# NOTE(review): torch.load unpickles arbitrary objects -- only load buffers from a
# trusted source. On torch>=2.6 the default weights_only=True may reject
# non-tensor payloads (e.g. numpy arrays) -- TODO confirm the buffer's contents.
download_buffer()
buffer = torch.load('saved_buffer.pkl')

# Fixed divisors applied to the energy / pitch targets. They look like
# dataset-wide maxima used to bring the targets to roughly [0, 1] -- TODO confirm.
ENERGY_SCALE = 488
PITCH_SCALE = 862

for buf in buffer:
    # Out-of-place division: it also works when the stored arrays/tensors are
    # integer-typed, where an in-place `/=` would raise a dtype-casting error.
    buf['energy'] = buf['energy'] / ENERGY_SCALE
    buf['pitch'] = buf['pitch'] / PITCH_SCALE

dataset = BufferDataset(buffer)

# Each loader batch holds batch_expand_size * batch_size items; presumably
# collate_fn_tensor splits it into batch_expand_size sub-batches -- TODO confirm.
training_loader = DataLoader(
    dataset,
    batch_size=train_config.batch_expand_size * train_config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_tensor,
    drop_last=True,
    num_workers=0
)

In [None]:
# Acoustic model on the training device, plus its composite loss.
model = FastSpeech(model_config).to(train_config.device)
fastspeech_loss = FastSpeechLoss()
# Not consumed by train() below; kept in case other cells read it.
current_step = 0

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# One optimizer/scheduler step per sub-batch: every loader batch is expanded into
# batch_expand_size chunks (assumes train() steps the scheduler per chunk -- TODO confirm).
steps_per_epoch = len(training_loader) * train_config.batch_expand_size

# One-cycle schedule: LR ramps up for the first 10% of steps, then cosine-anneals.
scheduler = OneCycleLR(
    optimizer,
    steps_per_epoch=steps_per_epoch,
    epochs=train_config.epochs,
    anneal_strategy="cos",
    max_lr=train_config.learning_rate,
    pct_start=0.1,
)

logger = WanDBWriter(train_config)

In [None]:
# Launch the training loop with the config, data, model, loss, optimisation
# objects and logger built in the cells above.
train(
    train_config=train_config,
    training_loader=training_loader,
    model=model,
    fastspeech_loss=fastspeech_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    logger=logger,
)

# Synthesis

In [None]:
import waveglow
import utils

import os
import numpy as np

# FastSpeech checkpoint to synthesize from.
CHECKPOINT_PATH = 'checkpoint_36000.pth.tar'

# Pretrained WaveGlow vocoder used by waveglow.inference.inference below. It is
# kept on the GPU because the synthesis cell feeds it CUDA mel tensors.
WaveGlow = utils.get_WaveGlow()
WaveGlow = WaveGlow.cuda()

# Map the checkpoint onto the device the model actually lives on
# (train_config.device) instead of a hard-coded 'cuda:0'.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=train_config.device)['model'])
model = model.eval()

In [None]:
from IPython.display import Audio

# NOTE(review): this is the project's text front-end package that provides
# `text_to_sequence` (the original code referred to it as `text`). Imported under
# an alias so function parameters named `text` no longer shadow it -- TODO confirm
# the import path (it may live under src/).
import text as text_frontend


def synthesis(model, text, alpha=1.0, palpha=1.0, ealpha=1.0, speaker_id=10):
    """Run FastSpeech inference on one already-encoded utterance.

    Args:
        model: FastSpeech model, expected to be in eval mode.
        text: sequence of symbol ids for a single utterance.
        alpha: speed / duration control forwarded to the model (the caller passes `speed`).
        palpha: presumably the pitch control forwarded to the model.
        ealpha: presumably the energy control forwarded to the model.
        speaker_id: integer speaker index forwarded to the model.

    Returns:
        ``(mel[0].cpu().transpose(0, 1), mel[0].contiguous().transpose(1, 2))``
        where ``mel`` is the model output. The ``[0]`` indexing suggests the model
        returns a tuple whose first item is a (batch, time, n_mels) mel -- TODO
        confirm. The second element is the one handed to WaveGlow.
    """
    device = train_config.device

    # Add a batch dimension of 1 and build 1-based positions for every symbol.
    sequence_np = np.stack([text])
    positions_np = np.stack([np.arange(1, sequence_np.shape[1] + 1)])

    sequence = torch.from_numpy(sequence_np).long().to(device)
    src_pos = torch.from_numpy(positions_np).long().to(device)
    sid = torch.tensor([speaker_id]).to(device)

    with torch.no_grad():
        # Call the module (not .forward) so any registered hooks run.
        mel = model(sequence, src_pos, alpha=alpha, speaker_id=sid, palpha=palpha, ealpha=ealpha)

    return mel[0].cpu().transpose(0, 1), mel[0].contiguous().transpose(1, 2)


def get_data(text):
    """Convert a raw sentence (str) into a list of symbol ids."""
    return list(text_frontend.text_to_sequence(text, train_config.text_cleaners))


# get_data returns the whole id sequence for the sentence, so use it directly:
# indexing it with [0] would keep only the first symbol.
data = get_data(input('Enter your sentence: '))
if not data:
    raise ValueError("The sentence produced no symbols; enter a non-empty sentence.")

# Speaker / prosody controls for this utterance.
speaker_id = 5
ealpha = 1.
palpha = 1.
speed = 1.

_, mel_cuda = synthesis(
    model, data, speed,
    ealpha=ealpha, palpha=palpha,
    speaker_id=speaker_id,
)

os.makedirs("results", exist_ok=True)
name = "sp={}_en={}_p={}_sid={}".format(
    speed, ealpha, palpha, speaker_id
)
# Single source of truth for the output path, shared by the writer and the player.
wav_path = f"results/{name}.wav"

waveglow.inference.inference(
    mel_cuda.to('cuda:0'), WaveGlow,
    wav_path
)

# Inline audio player for the file just written (last expression -> rendered).
Audio(wav_path)