# In [None]:
# dep: git clone https://github.com/NVIDIA/waveglow
# from waveglow import glow
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader, Dataset
import os

class WaveGlowDataset(Dataset):
    """Dataset of ``(mel_spectrogram, waveform)`` pairs built from a directory of ``.wav`` files.

    Args:
        wav_dir: Directory that is scanned (non-recursively) for files whose
            name ends in ``.wav``.
        transform: Optional callable applied to the loaded waveform to produce
            the spectrogram. When ``None``, a default
            ``torchaudio.transforms.MelSpectrogram()`` is used.
    """

    def __init__(self, wav_dir, transform=None):
        self.wav_dir = wav_dir
        # Sorted so the index -> file mapping is the same on every run.
        self.wav_files = sorted(f for f in os.listdir(wav_dir) if f.endswith(".wav"))
        self.transform = transform
        # Fallback transform, built on first use and then reused instead of
        # being re-instantiated for every item.
        self._default_transform = None

    def __len__(self):
        """Return the number of ``.wav`` files found in ``wav_dir``."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(mel_spec, waveform)``.

        NOTE(review): clips of different lengths yield tensors of different
        sizes, which the default ``DataLoader`` collate function cannot stack;
        confirm that the data is fixed-length or supply a ``collate_fn``.
        """
        wav_path = os.path.join(self.wav_dir, self.wav_files[idx])
        waveform, _sample_rate = torchaudio.load(wav_path)

        # Explicit ``None`` check: a user-supplied transform that happens to
        # be falsy (e.g. an empty ``nn.Sequential``) must still be honoured.
        transform = self.transform
        if transform is None:
            if self._default_transform is None:
                self._default_transform = torchaudio.transforms.MelSpectrogram()
            transform = self._default_transform

        return transform(waveform), waveform

class Invertible1x1Conv(nn.Module):
    """Learned channel-mixing layer: a kernel-size-1 convolution with a square weight.

    Every time step of the input is multiplied by the same
    ``(num_channels, num_channels)`` matrix, so the layer can be undone by
    convolving with the inverse matrix.

    Args:
        num_channels: Number of input (and output) channels.
    """

    def __init__(self, num_channels):
        super().__init__()
        # The Q factor of a random Gaussian matrix is orthogonal, which gives
        # an invertible, well-conditioned starting point.
        # ``torch.linalg.qr`` replaces the deprecated ``torch.qr``.
        w_init, _ = torch.linalg.qr(torch.randn(num_channels, num_channels))
        self.weight = nn.Parameter(w_init)

    def forward(self, z):
        """Mix the channels of ``z`` (``(batch, num_channels, length)``) with ``weight``."""
        # A trailing singleton dim turns the matrix into a kernel-size-1 conv filter.
        return torch.nn.functional.conv1d(z, self.weight.unsqueeze(-1))

    def inverse(self, z):
        """Undo :meth:`forward` by convolving with the inverse of ``weight``."""
        weight_inv = torch.inverse(self.weight)
        return torch.nn.functional.conv1d(z, weight_inv.unsqueeze(-1))

class AffineCoupling(nn.Module):
    """Affine coupling layer: the first half of the channels parameterises an
    element-wise affine transform of the second half.

    ``forward`` computes ``z2 * exp(log_s) + t`` and ``inverse`` computes
    ``(z2 - t) / exp(log_s)``, where ``log_s`` and ``t`` are predicted from the
    untouched first half ``z1``.

    Args:
        num_channels: Total number of channels of the input; must be even so
            the input splits into two equal halves.

    Raises:
        ValueError: If ``num_channels`` is odd.
    """

    def __init__(self, num_channels):
        super().__init__()
        if num_channels % 2 != 0:
            raise ValueError(
                f"num_channels must be even to split into two halves, got {num_channels}"
            )
        half = num_channels // 2
        self.net = nn.Sequential(
            nn.Conv1d(half, num_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            # The output is chunked into log_s and t, each of which must have
            # as many channels as z2 (= half), hence 2 * half output channels.
            nn.Conv1d(num_channels, 2 * half, kernel_size=3, padding=1),
        )

    def _scale_and_shift(self, z1):
        """Predict ``(s, t)`` for the second half from the first half ``z1``."""
        log_s, t = self.net(z1).chunk(2, 1)
        return torch.exp(log_s), t

    def forward(self, z):
        """Transform the second channel half of ``z``; the first half passes through."""
        z1, z2 = z.chunk(2, 1)
        s, t = self._scale_and_shift(z1)
        return torch.cat([z1, s * z2 + t], 1)

    def inverse(self, z):
        """Invert :meth:`forward`; works because ``z1`` is unchanged by it."""
        z1, z2 = z.chunk(2, 1)
        s, t = self._scale_and_shift(z1)
        return torch.cat([z1, (z2 - t) / s], 1)

class WN(nn.Module):
    """Stack of six dilated 1-D convolutions (dilation 1, 2, 4, ..., 32), each
    followed by a ReLU.

    Args:
        num_channels: Number of input and output channels of every layer.
    """

    def __init__(self, num_channels):
        super().__init__()
        # For kernel_size=3 a padding equal to the dilation keeps the sequence
        # length unchanged. A fixed padding of 1 would drop 2 * (dilation - 1)
        # steps per layer (114 in total) and fail on short inputs.
        self.layers = nn.ModuleList([
            nn.Conv1d(
                num_channels,
                num_channels,
                kernel_size=3,
                padding=2 ** i,
                dilation=2 ** i,
            )
            for i in range(6)
        ])

    def forward(self, x):
        """Apply every dilated conv + ReLU in order; output has the shape of ``x``."""
        for conv in self.layers:
            x = torch.nn.functional.relu(conv(x))
        return x

class WaveGlow(nn.Module):
    """Chain of ``num_flows`` repetitions of
    ``Invertible1x1Conv -> AffineCoupling -> WN``.

    Args:
        num_channels: Channel count of the ``(batch, num_channels, length)``
            tensors passed through the model.
        num_flows: Number of times the three-layer group is repeated.

    NOTE(review): ``WN`` ends in a ReLU and has no inverse, so ``inverse`` can
    only re-apply it in the forward direction. As long as ``WN`` sits in the
    chain, ``inverse`` is not an exact inverse of ``forward`` and the returned
    log-determinant covers the invertible layers only.
    """

    def __init__(self, num_channels, num_flows):
        super().__init__()
        self.num_flows = num_flows
        self.flows = nn.ModuleList()

        for _ in range(num_flows):
            self.flows.append(Invertible1x1Conv(num_channels))
            self.flows.append(AffineCoupling(num_channels))
            self.flows.append(WN(num_channels))

    @staticmethod
    def _coupling_forward(coupling, z):
        """Apply an ``AffineCoupling`` step and also return its ``log_s``.

        Mirrors ``AffineCoupling.forward`` so the log-scale, which that method
        does not expose, is available without running the network twice.
        """
        z1, z2 = z.chunk(2, 1)
        log_s, t = coupling.net(z1).chunk(2, 1)
        return torch.cat([z1, torch.exp(log_s) * z2 + t], 1), log_s

    def forward(self, z):
        """Run ``z`` through every layer.

        Returns:
            Tuple ``(z, log_det_jacobian)`` where ``log_det_jacobian`` is a
            scalar tensor: the log|det| contributions of all 1x1 convolutions
            and affine couplings, summed over the whole batch.
        """
        # Scalar tensor on z's device/dtype (instead of a Python int) so the
        # result can always be used in tensor arithmetic and backpropagated.
        log_det_jacobian = z.new_zeros(())
        for flow in self.flows:
            if isinstance(flow, Invertible1x1Conv):
                # The same matrix multiplies each of the batch * length
                # positions, so its log|det| counts once per position.
                batch_size, _, length = z.size()
                log_abs_det = torch.linalg.slogdet(flow.weight).logabsdet
                log_det_jacobian = log_det_jacobian + batch_size * length * log_abs_det
                z = flow(z)
            elif isinstance(flow, AffineCoupling):
                # The Jacobian is triangular with exp(log_s) on the diagonal.
                z, log_s = self._coupling_forward(flow, z)
                log_det_jacobian = log_det_jacobian + log_s.sum()
            elif isinstance(flow, WN):
                # Not invertible: no log-det term (see class NOTE).
                z = flow(z)
        return z, log_det_jacobian

    def inverse(self, z):
        """Walk the layers in reverse, inverting those that support it."""
        for flow in reversed(self.flows):
            if isinstance(flow, (Invertible1x1Conv, AffineCoupling)):
                z = flow.inverse(z)
            elif isinstance(flow, WN):
                # WN defines no inverse; it is applied in the forward direction.
                z = flow(z)
        return z

def train_waveglow(wav_dir, num_epochs=10, batch_size=16, learning_rate=1e-4, device='cuda',
                   num_channels=256, num_flows=12):
    """Train a ``WaveGlow`` model on mel spectrograms of the ``.wav`` files in ``wav_dir``.

    The trained weights are written to ``waveglow.pth`` in the current
    working directory.

    Args:
        wav_dir: Directory containing the training ``.wav`` files.
        num_epochs: Number of passes over the dataset.
        batch_size: Number of clips per optimisation step.
        learning_rate: Adam learning rate.
        device: Device the model and batches are moved to.
        num_channels: Channel count the model is built for. It must equal
            ``audio_channels * n_mels`` of the spectrograms (with the default
            ``MelSpectrogram`` of 128 mel bins: 128 for mono, 256 for stereo).
        num_flows: Number of flow groups in the model.

    Raises:
        ValueError: If ``wav_dir`` contains no ``.wav`` files.
    """
    dataset = WaveGlowDataset(wav_dir=wav_dir, transform=torchaudio.transforms.MelSpectrogram())
    if len(dataset) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the epoch average is computed over an empty loader.
        raise ValueError(f"No .wav files found in {wav_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    waveglow = WaveGlow(num_channels=num_channels, num_flows=num_flows).to(device)
    optimizer = optim.Adam(waveglow.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        waveglow.train()
        running_loss = 0.0

        # The raw waveforms are not used by the loss, so they stay on the CPU.
        for spectrograms, _ in dataloader:
            spectrograms = spectrograms.to(device)
            # torchaudio yields (audio_channels, n_mels, frames) per clip, so a
            # batch is 4-D, while the model works on (batch, channels, frames).
            # Fold the audio-channel and mel axes into one channel axis.
            if spectrograms.dim() == 4:
                spectrograms = spectrograms.flatten(1, 2)

            z, log_det_jacobian = waveglow(spectrograms)

            # Loss: Gaussian prior term minus the log-determinant reported by the model.
            loss = 0.5 * torch.sum(z**2) - log_det_jacobian

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader):.4f}')

    torch.save(waveglow.state_dict(), 'waveglow.pth')
    print("Training complete")

# Start training only when this file is executed directly (as a script or a
# notebook cell), not as a side effect of importing it from another module.
if __name__ == "__main__":
    train_waveglow('../unpacked_data', num_epochs=10, batch_size=16)

