In [1]:
from pathlib import Path
import argparse
import wandb
import json

import warnings
warnings.filterwarnings("ignore")

from torch.utils.data import DataLoader, random_split
import torch

from nv.spectrogram import MelSpectrogram
from nv.collate_fn import LJSpeechCollator
from nv.datasets import LJSpeechDataset
from nv.trainer import *
from nv.models import *
from nv.utils import *

In [2]:
config_path = "configs/config_v1.json"

with open(f"{config_path}") as file:
    config = AttrDict(json.load(file))

In [3]:
if config.use_wandb:
    wandb.init(project=config.wandb_project_name)

fix_seed(config)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if config.verbose:
    print(f"The training process will be performed on {device}.")
    print("Downloading and splitting the data.")

dataset = LJSpeechDataset(config.path_to_data)
train_size = int(config.train_ratio * len(dataset))
train_dataset, val_dataset = random_split(
    dataset, 
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(config.seed)
)

train_dataloader = DataLoader(
    train_dataset, 
    collate_fn=LJSpeechCollator(),
    batch_size=config.batch_size, 
    #num_workers=config.num_workers
)

val_dataloader = DataLoader(
    val_dataset,
    collate_fn=LJSpeechCollator(),
    batch_size=config.batch_size, 
    #num_workers=config.num_workers
)

melspectrogramer = MelSpectrogram(config, for_loss=False).to(device)
melspectrogramer_for_loss = MelSpectrogram(config, for_loss=True).to(device)

if config.verbose:
    print("Initializing discriminator, generator, optimizers and lr_schedulers.")

generator = HiFiGenerator(config).to(device)
trainable_params_generator = filter(
    lambda param: param.requires_grad, generator.parameters()
)
optimizer_generator = torch.optim.AdamW(
    trainable_params_generator, 
    betas=(config.adam_beta_1, config.adam_beta_2), 
    weight_decay=config.weight_decay, 
    lr=config.learning_rate
) 
scheduler_generator = torch.optim.lr_scheduler.ExponentialLR(
    optimizer_generator, 
    gamma=config.gamma
) 

discriminator = HiFiDiscriminator(config).to(device) 
trainable_params_discriminator = filter(
    lambda param: param.requires_grad, discriminator.parameters()
)
optimizer_discriminator = torch.optim.AdamW(
    trainable_params_discriminator, 
    betas=(config.adam_beta_1, config.adam_beta_2), 
    weight_decay=config.weight_decay, 
    lr=config.learning_rate
) 
scheduler_discriminator = torch.optim.lr_scheduler.ExponentialLR(
    optimizer_discriminator, 
    gamma=config.gamma
) 

The training process will be performed on cpu.
Downloading and splitting the data.
Initializing discriminator, generator, optimizers and lr_schedulers.


In [4]:
from nv.trainer import *
from nv.losses import *

def train(
    config, 
    train_dataloader, 
    val_dataloader,
    generator, 
    optimizer_generator, 
    scheduler_generator, 
    discriminator, 
    optimizer_discriminator, 
    scheduler_discriminator, 
    melspectrogramer, 
    melspectrogramer_for_loss
):  
    history_val_melspec_loss = []
    epoch = 0

    #for epoch in tqdm(range(config.num_epoch)):
    while True:
        epoch += 1

        train_melspec_loss = train_epoch(
            config, train_dataloader,
            generator, optimizer_generator, scheduler_generator, 
            discriminator, optimizer_discriminator, scheduler_discriminator, 
            melspectrogramer, melspectrogramer_for_loss
        )

        val_melspec_loss = validate_epoch(
            config, val_dataloader,
            generator, optimizer_generator, scheduler_generator, 
            discriminator, optimizer_discriminator, scheduler_discriminator, 
            melspectrogramer, melspectrogramer_for_loss
        )

        history_val_melspec_loss.append(val_melspec_loss)

        if config.use_wandb:             
            wandb.log({
                "Epoch": epoch,
                "Global Train Melspectrogram Loss": train_melspec_loss,
                "Global Validation Melspectrogram Loss": val_melspec_loss
            })  
        
        #if val_melspec_loss <= min(history_val_melspec_loss):
        state = {
            "generator": generator.state_dict(),
            "generator_arch": type(generator).__name__,
            "optimizer_generator": optimizer_generator.state_dict(),
            "discriminator": discriminator.state_dict(),
            "discriminator_arch": type(discriminator).__name__,
            "optimizer_discriminator": optimizer_generator.state_dict(),
            "config": config
        }
        torch.save(state, config.path_to_save + "/best.pt")

In [None]:
def train_epoch(
    config, 
    train_dataloader,
    generator, 
    optimizer_generator, 
    scheduler_generator, 
    discriminator, 
    optimizer_discriminator, 
    scheduler_discriminator, 
    melspectrogramer, 
    melspectrogramer_for_loss
):
    generator.train()
    discriminator.train()

    adversarial_loss = AdversarialLoss()
    feature_loss = FeatureMatchingLoss()
    melspec_loss = MelSpectrogramLoss()

    discriminator_loss = DiscriminatorLoss()

    for batch in train_dataloader:
        batch = prepare_batch(batch, melspectrogramer, melspectrogramer_for_loss, device, for_training=True)

In [1]:
import torch, torchaudio
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
import json

In [2]:
import numpy as np
from typing import *
from torch.nn.utils.rnn import pad_sequence
import random

In [19]:
for batch in train_dataloader:
    break
    batch = prepare_batch(batch, melspectrogramer, aligner, config, device)

In [20]:
batch = prepare_batch(batch, melspectrogramer, melspectrogramer_loss, device, for_training=True)

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


In [37]:
adversarial_loss = AdversarialLoss()
feature_loss = FeatureMatchingLoss()
melspec_loss = MelSpectrogramLoss()

In [38]:
wav_real, melspec_real = batch.waveform, batch.melspec

wav_fake = generator(melspec_real)
melspec_fake = melspectrogramer_loss(wav_fake)

out_discr = discriminator(wav_real, wav_fake)

In [39]:
loss_adv = adversarial_loss(out_discr["outs_fake"])
loss_fm = feature_loss(out_discr["feature_maps_real"], out_discr["feature_maps_fake"])
loss_mel = melspec_loss(melspec_real, melspec_fake)

In [148]:
loss_generator = loss_adv + 2 * loss_fm + 45 * loss_mel
loss_generator.backward()

In [151]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f88f9217fd0>

In [152]:
loss_generator.backward()

  Variable._execution_engine.run_backward(


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
(batch.melspec_loss, melspec_fake)