In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

from config import Config
from dataset import load_ljspeech_dataset
from mbmelgan_model import Discriminator, MultiBandGenerator, FullBandGenerator
from trainers import MBMelGANTrainer, FBMelGANTrainer
from utils import weights_init_xavier_uniform

In [2]:
# Single configuration object read by every later cell (data-loader settings,
# learning rates, Adam betas, checkpoint paths/intervals, epoch counts).
# `Config` is imported in the top import cell alongside the other project modules.
config = Config()

In [3]:
# `load_ljspeech_dataset` returns a pair; only the first element (the training set) is used here.
trainset, _ = load_ljspeech_dataset(config)

# Pinned host memory only speeds up host->GPU transfers, so request it only when CUDA
# is available (PyTorch warns and ignores pinning when no accelerator is present).
train_loader = DataLoader(
    trainset,
    batch_size=config.batch_size_fb,  # the `_fb` (full-band) batch size, matching FBMelGANTrainer below
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=torch.cuda.is_available(),
)

In [4]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


In [5]:
# Build the full-band generator (conditioned on `config.n_mels` mel channels) and the discriminator,
# then re-initialise both with Xavier-uniform weights via the project helper.
G = FullBandGenerator(config.n_mels)
D = Discriminator()
for net in (G, D):
    weights_init_xavier_uniform(net)

In [6]:
def lr_halving_factor(base_lr, min_lr=1e-6, halving_steps=100000):
    """Build an ``lr_lambda`` for ``LambdaLR`` implementing the paper's LR schedule.

    'The learning rate of all models is halved every 100K steps until 1e-6' (from the paper).
    ``LambdaLR`` multiplies the optimizer's initial LR by the returned factor, so the
    absolute floor ``min_lr`` is converted into a factor relative to ``base_lr``.

    NOTE(review): the counter passed to the lambda is the number of ``scheduler.step()``
    calls, so this is a per-iteration schedule only if the trainer steps the scheduler
    once per batch -- TODO confirm in FBMelGANTrainer.

    Args:
        base_lr: the optimizer's initial learning rate.
        min_lr: absolute lower bound for the effective learning rate.
        halving_steps: number of scheduler steps between successive halvings.

    Returns:
        A callable mapping the scheduler step count to a multiplicative LR factor.
    """
    min_factor = min_lr / base_lr

    def factor(step):
        return max(min_factor, 0.5 ** (step // halving_steps))

    return factor


def g_optimizer_builder(model):
    """Create the generator's Adam optimizer over the parameters of ``model``."""
    # Use the module handed in by the trainer rather than the global `G`.
    return torch.optim.Adam(model.parameters(), lr=config.g_lr, betas=config.adam_betas)


def d_optimizer_builder(model):
    """Create the discriminator's Adam optimizer over the parameters of ``model``."""
    # Use the module handed in by the trainer rather than the global `D`.
    return torch.optim.Adam(model.parameters(), lr=config.d_lr, betas=config.adam_betas)


def g_scheduler_builder(opt):
    """Wrap the generator optimizer ``opt`` in the halve-every-100K-steps schedule."""
    return torch.optim.lr_scheduler.LambdaLR(optimizer=opt, lr_lambda=lr_halving_factor(config.g_lr))


def d_scheduler_builder(opt):
    """Wrap the discriminator optimizer ``opt`` in the halve-every-100K-steps schedule."""
    return torch.optim.lr_scheduler.LambdaLR(optimizer=opt, lr_lambda=lr_halving_factor(config.d_lr))

In [7]:
# Full-band MelGAN trainer: the two models, their optimizer/scheduler builders,
# the adversarial-loss weight and where checkpoints are written.
trainer = FBMelGANTrainer(
    G,
    D,
    g_optimizer_builder,
    d_optimizer_builder,
    g_scheduler_builder,
    d_scheduler_builder,
    lambda_adv=config.lambda_adv,
    checkpoint_dir=config.checkpoint_dir,
)

# Presumably switches the trainer's progress bars to tqdm's notebook variant -- confirm in trainers.py.
trainer.set_tqdm_for_notebook(True)

In [8]:
# Resume from the checkpoint named by `config.train_after` when it is set.
# `reset` is forwarded to `trainer.train`: True only when nothing was loaded.
resume_from = config.train_after
if resume_from is not None:
    trainer.load_data(resume_from)
reset = resume_from is None

In [None]:
# Run training; `train_only_g_till` comes from `config.train_generator_until`
# (presumably a generator-only warm-up before the discriminator joins -- confirm in trainers.py).
# Checkpoints use the full-band filename template every `config.checkpoint_interval`.
result = trainer.train(
    train_loader,
    device=device,
    epochs=config.epochs,
    reset=reset,
    train_only_g_till=config.train_generator_until,
    cp_filename=config.checkpoint_file_template_fb,
    cp_interval=config.checkpoint_interval,
)

In [None]:
# Plot both losses on one explicit Axes. Set the axis labels ourselves: seaborn would
# otherwise label the shared y-axis with a single column name ('g loss' / 'd loss').
fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='g loss', data=result, label='Generator Loss', ax=ax)
sns.lineplot(x='epoch', y='d loss', data=result, label='Discriminator Loss', ax=ax)
ax.set(title='Full-band MelGAN training losses', xlabel='Epoch', ylabel='Loss')
plt.show()