In [1]:
from models.modules import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import librosa
import torchaudio
import torchaudio.transforms as T
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader

import models
from models.modules import get_extra_padding_for_conv1d
from models.unet import DiffusionUNet, create_diffusion_model
from models.utils import LogMelSpectrogram, count_parameters, get_padding_sample
from models.discriminator import EnsembleDiscriminator
from training.dataset_vctk import DenoiserDataset, collate_fn_latents
from models.lldm_architecture import WaveLLDM

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import soundfile as sf

In [2]:
device = "cuda"

backbone = models.ConvNeXtEncoder(
    input_channels=160,
    depths=[3, 3, 9, 3],
    dims=[128, 256, 384, 512],
    drop_path_rate=0.2,
    kernel_size=7
).to(device)

head = models.HiFiGANGenerator(
    hop_length=512,
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_kernel_sizes=[16, 16, 4, 4, 4],
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    num_mels=512,
    upsample_initial_channel=512,
    pre_conv_kernel_size=13,
    post_conv_kernel_size=13
).to(device)

quantizer = models.DownsampleFSQ(
    input_dim=512,
    n_groups=8,
    n_codebooks=8,
    levels=[8, 8, 8, 6, 5],
    downsample_factor=[2, 2]
).to(device)

spec_trans = LogMelSpectrogram(
    sample_rate=44100,
    n_mels=160,
    n_fft=2048,
    hop_length=512,
    win_length=2048
).to(device)

ffgan = models.FireflyArchitecture(
    backbone=backbone,
    head=head,
    quantizer=quantizer,
    spec_transform=spec_trans
)

unet = create_diffusion_model(
    in_channels=1024,
    base_channels=32,
    out_channels=512
).to(device)

In [4]:
# Load the 1st-stage model state dicts
ffgan.load_state_dict(torch.load("./pretrained_models/generator_step_142465.pth"))

  ffgan.load_state_dict(torch.load("./pretrained_models/generator_step_142465.pth"))


<All keys matched successfully>

In [5]:
spec_trans_cpu = LogMelSpectrogram(
    sample_rate=44100,
    n_mels=160,
    n_fft=2048,
    hop_length=512,
    win_length=2048
).to("cpu")

encoder_cpu = models.ConvNeXtEncoder(
    input_channels=160,
    depths=[3, 3, 9, 3],
    dims=[128, 256, 384, 512],
    drop_path_rate=0.2,
    kernel_size=7
).to("cpu")

quantizer_cpu = models.DownsampleFSQ(
    input_dim=512,
    n_groups=8,
    n_codebooks=8,
    levels=[8, 8, 8, 6, 5],
    downsample_factor=[2, 2]
).to("cpu")

decoder_cpu = models.HiFiGANGenerator(
    hop_length=512,
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_kernel_sizes=[16, 16, 4, 4, 4],
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    num_mels=512,
    upsample_initial_channel=512,
    pre_conv_kernel_size=13,
    post_conv_kernel_size=13
).to("cpu")

encoder_cpu.load_state_dict(ffgan.backbone.state_dict())
quantizer_cpu.load_state_dict(ffgan.quantizer.state_dict())
decoder_cpu.load_state_dict(ffgan.head.state_dict())

<All keys matched successfully>

In [6]:
train_ds = DenoiserDataset(
    "./data/voicebank_demand_56spk/clean_speech_audios/train/",
    "./data/voicebank_demand_56spk/noisy_speech_audios/train/",
    True,
    stage=3,
    spec_trans=spec_trans_cpu,
    encoder=encoder_cpu,
    quantizer=quantizer_cpu,
    device="cuda"
)

val_ds = DenoiserDataset(
    "./data/voicebank_demand_56spk/clean_speech_audios/test/",
    "./data/voicebank_demand_56spk/noisy_speech_audios/test/",
    True,
    stage=3,
    spec_trans=spec_trans_cpu,
    encoder=encoder_cpu,
    quantizer=quantizer_cpu,
    device="cuda"
)

In [7]:
wavelldm = WaveLLDM(
    p_estimator=unet,
    learn_logvar=False,
    encoder=ffgan.backbone,
    quantizer=ffgan.quantizer,
    decoder=ffgan.head,
    beta_scheduler="cosine"
).to(device)

In [8]:
latent_train_dataloader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn_latents
)

latent_val_dataloader = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn_latents
)

In [None]:
for i, data in enumerate(latent_train_dataloader):
    # print(data["melspec_lengths"][0])
    # loss, loss_dict = wavelldm(data)
    # print("Loss: ", loss.item())
    # print(loss_dict)

    wavelldm.log_reconstruction(Summary)

    break

168
Loss:  1.2705508470535278
{'train/loss_simple': tensor(1.2627, device='cuda:0', grad_fn=<MeanBackward0>), 'train/loss_vlb': tensor(0.0079, device='cuda:0', grad_fn=<MeanBackward0>), 'train/loss': tensor(1.2706, device='cuda:0', grad_fn=<AddBackward0>)}


In [9]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import gc

class WaveLLDMTrainer:
    def __init__(
        self,
        model: WaveLLDM,
        train_dataloader,
        val_dataloader=None,
        epochs: int = 300,
        save_dir: str = "./checkpoints",
        log_dir: str = "./logs",
        save_every: int = 5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.epochs = epochs
        self.save_dir = save_dir
        self.log_dir = log_dir
        self.save_every = save_every
        self.device = device

        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir)
    
    def train(self):
        for epoch in range(self.epochs):
            self.model.train()
            train_loss = 0.0
            train_steps = 0

            with tqdm(total=len(self.train_dataloader), desc=f"Epoch {epoch+1}/{self.epochs}", unit="batch") as pbar:
                for idx, batch in enumerate(self.train_dataloader):
                    loss, loss_dict = self.model.train_step(batch)
                    train_loss += loss.item()
                    train_steps += 1

                    avg_train_loss_on_fly = train_loss / (idx + 1)

                    pbar.set_postfix({
                        'Loss': avg_train_loss_on_fly
                    })
                    pbar.update(1)

                    if idx % 100 == 0:
                        torch.cuda.empty_cache()
                        gc.collect()
                    
                    self.model.log_to_tensorboard(
                        self.writer, loss_dict, train_steps + epoch * len(self.train_dataloader), prefix="train", batch=batch
                    )
            
            avg_train_loss = train_loss / train_steps
            print(f"Epoch {epoch+1}/{self.epochs}, Average Train Loss: {avg_train_loss:.4f}")

            if self.val_dataloader is not None:
                val_loss = self.model.validate(epoch, self.val_dataloader, self.writer)
                print(f"Epoch {epoch+1}/{self.epochs}, Average Val Loss: {val_loss:.4f}")
            
            if (epoch + 1) % self.save_every == 0:
                self.model.save_checkpoint(epoch + 1)
        
        self.writer.close()

In [10]:
trainer_wavelldm = WaveLLDMTrainer(
    model=wavelldm,
    train_dataloader=latent_train_dataloader,
    val_dataloader=latent_val_dataloader,
    epochs=300,
    save_dir="./checkpoints",
    log_dir="./logs/wavelldm",
    save_every=5,
    device=device
)

In [11]:
trainer_wavelldm.train()

Epoch 1/300:   0%|          | 0/722 [00:00<?, ?batch/s]

KeyboardInterrupt: 