In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

import torchaudio
import torchaudio_augmentations as ta

import lightning as L
from lightning import LightningModule

import auraloss

import numpy as np

import pandas as pd

from typing import Optional

import os

import utility as U


In [None]:
# Sample rates in Hz. VAE_Encoder / VAE_Decoder assert that
# input_sr == latent_sr * 2**n for an integer n (n = number of stride-2 stages).
DEFAULT_INPUT_SR = 16000
DEFAULT_LATENT_SR = 125 # Chosen because 16000 / 2^7 = 125, i.e. exactly seven 0.5x downsamples
DEFAULT_LATENT_CHANNELS = 16 # Seems to be a pretty standard value for this

DEFAULT_1D_KERNEL_SIZE = 7 # This seems to be standard practice for waveforms
DEFAULT_1D_PADDING = 3 # (kernel_size - 1) // 2; with stride 2 this halves even-length inputs exactly

DEFAULT_MAX_CHANNELS = 256 # Cap on encoder channel width; also the fixed channel width of the decoder

DEFAULT_AUDIO_DUR = 10 # In seconds
MAX_SEQ_LEN = 20000 # NOTE(review): not referenced anywhere in the visible notebook - confirm it is still needed

In [None]:
class ELBO_Loss(nn.Module):
    """
    Reconstruction + KL objective for the VAE.

    The reconstruction term is auraloss' multi-resolution STFT loss; the
    regulariser is the closed-form KL divergence between the diagonal Gaussian
    described by (mu, logvar) and a standard normal, summed over every element
    and scaled by ``KL_weight``.

    Args:
        KL_weight (float): Multiplier applied to the KL term
    """
    def __init__(self, KL_weight=1e-3):
        super().__init__()
        self.stft = auraloss.freq.MultiResolutionSTFTLoss()
        self.KL = KL_weight

    def forward(self, recon_x: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Args:
            recon_x (torch.Tensor): Reconstructed waveform
            x (torch.Tensor): Target waveform
            mu (torch.Tensor): Latent means
            logvar (torch.Tensor): Latent log variances

        Returns:
            torch.Tensor: STFT loss + KL_weight * KL divergence
        """
        reconstruction_term = self.stft(recon_x, x)
        # Element-wise integrand of the Gaussian KL; reduced with a plain sum (not a mean)
        kl_integrand = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_integrand)
        return reconstruction_term + self.KL * kl_term

In [None]:
class SelfAttention(nn.Module):
    def __init__(self,
                 channels: int,
                 n_heads: int):
        """
        Multiheaded self-attention (with residual connections)

        Args:
            channels (int): Channels for input sequence
            n_heads (int): Number of attention heads
        """
        super(SelfAttention, self).__init__()
        self.dim = channels
        self.attn = nn.MultiheadAttention(channels, n_heads, batch_first=True)


    def _posn_encoding(self,
                       seq_len: int,
                       device: Optional[torch.device] = None,
                       dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Sinusoidal positional encoding.

        Channels 2i and 2i+1 share the frequency 10000^(-2i / dim); the even
        channel holds the sine and the odd channel the cosine.

        The encoding is returned, not stored on the module: registering it as a
        buffer would put a sequence-length-dependent tensor into the state_dict.

        Args:
            seq_len (int): Sequence length
            device (torch.device, optional): Device to build the encoding on (default: CPU)
            dtype (torch.dtype, optional): Output dtype (default: float32)

        Returns:
            torch.Tensor: Positional encoding of shape [1, seq_len, dim]
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(-1)  # [T, 1]
        # One frequency per (sin, cos) pair; pair_index already equals 2i
        pair_index = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
        inv_freq = torch.pow(10000.0, -pair_index / self.dim)  # [ceil(dim / 2)]
        angles = position * inv_freq  # [T, ceil(dim / 2)]

        pe = torch.zeros((1, seq_len, self.dim), device=device, dtype=torch.float32)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd dim there is one fewer cosine channel than sine channels
        pe[0, :, 1::2] = torch.cos(angles[:, : self.dim // 2])
        return pe if dtype is None else pe.to(dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for self-attention (with residual connections)

        B = batch
        C = channels
        T = time
        Args:
            x (torch.Tensor): Sequence tensor of shape [B, C, T]

        Returns:
            torch.Tensor: Tensor of shape [B, C, T]
        """
        x = x.permute(0, 2, 1)  # [B, T, C]
        T = x.shape[1]

        # Built directly on x's device/dtype; broadcasts over the batch dimension
        embeddings = self._posn_encoding(T, device=x.device, dtype=x.dtype)  # [1, T, C]
        attn_in = x + embeddings

        attn_out, _ = self.attn(attn_in, attn_in, attn_in, need_weights=False)
        out = x + attn_out  # Residual connection (around the un-encoded input)
        return out.permute(0, 2, 1)  # Back to [B, C, T]

class DownsampleLayer(nn.Module):
    """
    Stride-2 Conv1d followed by GroupNorm and an activation.

    Args:
        in_channels (int): Channels of the incoming sequence
        out_channels (int): Channels produced by the convolution
        activation (str): Name handed to ``U.get_activation``
    """
    def __init__(self, in_channels: int, out_channels: int, activation: str='gelu'):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            DEFAULT_1D_KERNEL_SIZE,
            stride=2,
            padding=DEFAULT_1D_PADDING,
        )
        # Four channels per normalisation group
        n_groups = out_channels // 4
        self.norm = nn.GroupNorm(n_groups, out_channels)
        self.activation = U.get_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> norm -> activation to ``x`` and return the result."""
        return self.activation(self.norm(self.conv(x)))

class VAE_Encoder(nn.Module):
    def __init__(self,
                 input_channels: int,
                 latent_channels: int=DEFAULT_LATENT_CHANNELS,
                 input_sr: int=DEFAULT_INPUT_SR,
                 latent_sr: int=DEFAULT_LATENT_SR):
        """
        Variational Autoencoder Encoder
        Args:
            input_channels (int): Number of channels for input audio waveforms (ex. stereo vs. mono)
            latent_channels (int): Number of channels for latent audio waveforms
            input_sr (int): Input audio waveform sample rate (16000Hz default)
            latent_sr (int): Latent audio sample rate (125Hz default); input_sr must equal
                latent_sr * 2**n for an integer n, otherwise construction fails an assertion
        """
        super(VAE_Encoder, self).__init__()

        # Plain int (a stray trailing comma here previously turned this into a 1-tuple)
        self.input_channels = input_channels
        self.latent_channels = latent_channels
        self.input_sr = input_sr
        self.latent_sr = latent_sr
        self.activation = 'gelu'

        # Input dimension must be some power of 2 multiple of latent dim
        self.n_downsamples = int(np.ceil(np.log2(self.input_sr / self.latent_sr)))
        assert (2 ** self.n_downsamples) * latent_sr == self.input_sr, \
            "input_sr must be latent_sr times a power of two"

        starter_channels = 16
        layers = [
            nn.Conv1d(input_channels, starter_channels, DEFAULT_1D_KERNEL_SIZE, stride=1, padding=DEFAULT_1D_PADDING),
            U.get_activation_module('gelu'),
        ]

        # Channels go from 16 -> 32 -> 64 -> DEFAULT_MAX_CHANNELS ... n_downsamples layers
        in_ch = starter_channels
        for _ in range(self.n_downsamples):
            out_ch = min(in_ch * 2, DEFAULT_MAX_CHANNELS)
            layers.append(DownsampleLayer(in_ch, out_ch))  # Downsample by factor of 2
            in_ch = out_ch

        # Attention runs once, at the fully downsampled rate
        layers.append(SelfAttention(in_ch, 4))

        self.layers = nn.Sequential(*layers)

        # Two parallel heads with identical structure: one for the mean, one for the log variance
        self.mu_proj = nn.Sequential(
            nn.Conv1d(in_ch, in_ch, kernel_size=DEFAULT_1D_KERNEL_SIZE, padding=DEFAULT_1D_PADDING),
            U.get_activation_module('gelu'),
            nn.Conv1d(in_ch, latent_channels, kernel_size=1)
        )

        self.logvar_proj = nn.Sequential(
            nn.Conv1d(in_ch, in_ch, kernel_size=DEFAULT_1D_KERNEL_SIZE, padding=DEFAULT_1D_PADDING),
            U.get_activation_module('gelu'),
            nn.Conv1d(in_ch, latent_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        VAE encoder forward pass, input waveforms are expected to be of the correct sample rate (from VAE encoder constructor)

        B = batch size
        C = channels
        T = timesteps

        L = latent space channels
        T' = T * latent_sr / input_sr
        Args:
            x (torch.Tensor): Batch of waveforms of shape [B, C, T]
        Returns:
            Parameters to a diagonal Gaussian in latent space
            torch.Tensor: Latent space tensor of shape [B, L, T'] mean
            torch.Tensor: Latent space tensor of shape [B, L, T'] log variances
        """
        x = self.layers(x)
        return self.mu_proj(x), self.logvar_proj(x)

class UpsampleLayer(nn.Module):
    """
    Stride-2 transposed convolution followed by a residual Conv1d, GroupNorm
    and an activation.

    Args:
        in_channels (int): Channels of the incoming sequence
        out_channels (int): Channels after upsampling
        activation (str): Name handed to ``U.get_activation``
    """
    def __init__(self, in_channels: int, out_channels: int, activation: str='gelu'):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.upsample = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            DEFAULT_1D_KERNEL_SIZE,
            stride=2,
            padding=DEFAULT_1D_PADDING,
            output_padding=1,
        )
        self.conv = nn.Conv1d(
            out_channels,
            out_channels,
            DEFAULT_1D_KERNEL_SIZE,
            stride=1,
            padding=DEFAULT_1D_PADDING,
        )
        # Four channels per normalisation group
        n_groups = out_channels // 4
        self.norm = nn.GroupNorm(n_groups, out_channels)
        self.activation = U.get_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, refine it with a residual convolution, then normalise and activate."""
        upsampled = self.activation(self.upsample(x))
        # Residual connection around the refinement convolution
        refined = self.conv(upsampled) + upsampled
        return self.activation(self.norm(refined))

class VAE_Decoder(nn.Module):
    def __init__(self,
                 input_channels: int,
                 latent_channels: int=DEFAULT_LATENT_CHANNELS,
                 input_sr: int=DEFAULT_INPUT_SR,
                 latent_sr: int=DEFAULT_LATENT_SR):
        """
        Decoder half of the VAE: maps latent sequences back to waveforms.

        Args:
            input_channels (int): Channel count of the reconstructed audio (ex. stereo vs. mono)
            latent_channels (int): Channel count of the latent sequence
            input_sr (int): Sample rate of the reconstructed audio (16000Hz default)
            latent_sr (int): Sample rate of the latent sequence (125Hz default); asserted to
                satisfy input_sr == latent_sr * 2**n_upsamples
        """
        super().__init__()

        self.input_channels = input_channels
        self.latent_channels = latent_channels
        self.input_sr = input_sr
        self.latent_sr = latent_sr
        self.activation = 'gelu'

        # The sample-rate ratio has to be an exact power of two
        sr_ratio = self.input_sr / self.latent_sr
        self.n_upsamples = np.ceil(np.log2(sr_ratio)).astype(np.int32)
        assert (2 ** self.n_upsamples) * latent_sr == self.input_sr

        width = DEFAULT_MAX_CHANNELS
        # Modules are constructed in this exact order: stem conv, activation, attention,
        # the upsampling stack, then the 1x1 output projection.
        stack = [
            nn.Conv1d(latent_channels, width, DEFAULT_1D_KERNEL_SIZE, stride=1, padding=DEFAULT_1D_PADDING),
            U.get_activation_module('gelu'),
            SelfAttention(width, 4),
            *[UpsampleLayer(width, width) for _ in range(self.n_upsamples)],
            nn.Conv1d(width, input_channels, kernel_size=1),
        ]

        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode latent sequences into waveforms in the original input space.

        B = batch size
        Z = latent channels
        T = timesteps
        Args:
            x (torch.Tensor): Batch of latent waveforms of shape [B, Z, T']

        Returns:
            torch.Tensor: Reconstruction of waveforms in input space of shape [B, C, T]
        """
        reconstruction = self.layers(x)
        return reconstruction

class VAE(nn.Module):
    """
    Waveform VAE: a VAE_Encoder / VAE_Decoder pair built with their default
    latent size and sample rates.

    Args:
        audio_channels (int): Number of channels of the audio waveforms (ex. stereo vs. mono)
    """
    def __init__(self, audio_channels: int):
        super(VAE, self).__init__()
        self.channels = audio_channels
        self.encoder = VAE_Encoder(audio_channels)
        self.decoder = VAE_Decoder(audio_channels)
        self.latent_dim = self.decoder.latent_channels
        self.latent_sr = self.decoder.latent_sr

    def _sample(self, mu: torch.Tensor, log_var: torch.Tensor):
        """
        Reparameterised draw from N(mu, exp(log_var)): mu + eps * std with eps ~ N(0, I).
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def generate(self, n_samples: int=1) -> torch.Tensor:
        """
        Decode latents drawn from the standard normal prior.

        The latents are created on the decoder's own device and dtype, so this
        works whether the model lives on CPU or on an accelerator.

        Args:
            n_samples (int): Number of clips to generate

        Returns:
            torch.Tensor: Decoder output for n_samples latents of length
                latent_sr * DEFAULT_AUDIO_DUR
        """
        reference = next(self.decoder.parameters())
        n_latent_frames = self.latent_sr * DEFAULT_AUDIO_DUR
        z = torch.randn(
            [n_samples, self.latent_dim, n_latent_frames],
            device=reference.device,
            dtype=reference.dtype,
        )
        audio = self.decoder(z)
        return audio

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full VAE encoder + decoder forward pass

        Args:
            input (torch.Tensor): Batch of waveforms of shape [B, C, T]

        Returns:
            torch.Tensor: Reconstruction of waveforms in input space of shape [B, C, T]
            torch.Tensor: Mean of Gaussian distribution over latent space
            torch.Tensor: Log variance of Gaussian distribution over latent space

        """
        mu, log_var = self.encoder(input)
        sample = self._sample(mu, log_var)
        reconstruction = self.decoder(sample)
        return reconstruction, mu, log_var

In [None]:
class Discriminator(nn.Module):
    """
    1D convolutional waveform discriminator.

    Three stride-4 Conv1d + LeakyReLU(0.2) stages widen the signal to
    32, 64 and 128 channels; a final 3-tap convolution produces one logit per
    remaining time step, and the logits are averaged over time.

    Args:
        input_channels (int): Number of channels of the input waveform
    """
    def __init__(self, input_channels: int):
        super().__init__()
        widths = (input_channels, 32, 64, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv1d(c_in, c_out, 15, stride=4, padding=7))
            stages.append(nn.LeakyReLU(0.2))
        # Per-timestep logit head
        stages.append(nn.Conv1d(128, 1, kernel_size=3, stride=1, padding=1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return time-averaged logits for ``x``."""
        per_step_logits = self.net(x)
        return per_step_logits.mean(dim=-1)  # shape [B, 1]

class AudioVAEGAN(LightningModule):
    """
    VAE trained with an additional adversarial loss, using Lightning manual
    optimization (one optimizer for the VAE, one for the discriminator).

    Args:
        channels (int): Number of audio channels
        kl_weight (float): Weight of the KL term inside ELBO_Loss
        adv_weight (float): Weight of the adversarial term in the generator loss
        lr (float): Learning rate of the VAE optimizer (the discriminator uses 2 * lr)
    """
    def __init__(self, channels: int, kl_weight: float = 1e-3, adv_weight: float = 1.0, lr=1e-4):
        super().__init__()
        self.save_hyperparameters()
        # Both optimizers are stepped by hand inside training_step
        self.automatic_optimization = False

        self.vae = VAE(channels)
        self.discriminator = Discriminator(channels)

        self.recon_loss = ELBO_Loss(kl_weight)
        self.adv_weight = adv_weight

    def forward(self, x):
        """Run the VAE; returns (reconstruction, mu, log_var)."""
        return self.vae(x)

    def adversarial_loss(self, pred, target_is_real=True):
        """BCE-with-logits of ``pred`` against an all-ones (real) or all-zeros (fake) target."""
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        return F.binary_cross_entropy_with_logits(pred, target)

    def training_step(self, batch, batch_idx):
        """
        One generator (VAE) update followed by one discriminator update.

        Args:
            batch (torch.Tensor): Batch of real waveforms
            batch_idx (int): Index of the batch (unused)
        """
        real = batch

        opt_vae, opt_disc = self.optimizers()

        ### === Train Generator (VAE Decoder) === ###
        self.toggle_optimizer(opt_vae)
        recon, mu, logvar = self.vae(real)

        elbo = self.recon_loss(recon, real, mu, logvar)
        d_fake = self.discriminator(recon)
        # The generator is rewarded when the discriminator labels its output as real
        adv_loss = self.adversarial_loss(d_fake, target_is_real=True)
        total_gen_loss = elbo + self.adv_weight * adv_loss

        self.manual_backward(total_gen_loss)
        opt_vae.step()
        opt_vae.zero_grad()
        self.untoggle_optimizer(opt_vae)

        ### === Train Discriminator === ###
        self.toggle_optimizer(opt_disc)
        # Reuse the reconstruction from the generator step, cut from the graph,
        # instead of paying for a second full VAE forward pass.
        recon_detached = recon.detach()

        d_real = self.discriminator(real)
        d_fake = self.discriminator(recon_detached)

        real_loss = self.adversarial_loss(d_real, target_is_real=True)
        fake_loss = self.adversarial_loss(d_fake, target_is_real=False)
        d_loss = 0.5 * (real_loss + fake_loss)

        self.manual_backward(d_loss)
        opt_disc.step()
        opt_disc.zero_grad()
        self.untoggle_optimizer(opt_disc)

        ### === Logging === ###
        self.log_dict({
            "gen_elbo": elbo,
            "gen_adv": adv_loss,
            "gen_total": total_gen_loss,
            "disc_loss": d_loss
        }, prog_bar=True, on_step=True, on_epoch=True)

    def generate(self, n_samples: int):
        """Generate ``n_samples`` clips by decoding latents drawn from the prior."""
        return self.vae.generate(n_samples)

    def configure_optimizers(self):
        """Return [VAE optimizer, discriminator optimizer]; training_step relies on this order."""
        opt_g = torch.optim.Adam(self.vae.parameters(), lr=self.hparams.lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr * 2)
        return [opt_g, opt_d]


In [14]:
class AudioVAE(LightningModule):
    """
    Lightning wrapper that trains the plain VAE with ELBO_Loss.

    Args:
        channels (int): Number of audio channels
        kl_weight (float): Weight of the KL term inside ELBO_Loss
        lr (float): Adam learning rate
    """
    def __init__(self, channels: int, kl_weight: float = 1e-3, lr=1e-4):
        super().__init__()
        self.vae = VAE(channels)
        self.loss = ELBO_Loss(kl_weight)
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped VAE; what comes back is its (reconstruction, mu, log_var) tuple."""
        vae_outputs = self.vae(x)
        return vae_outputs

    def training_step(self, batch, batch_idx=None, dataloader_idx=None) -> torch.Tensor:
        """Compute, log and return the ELBO loss for one batch of waveforms."""
        recon, latent_mean, latent_logvar = self.vae(batch)
        elbo = self.loss(recon, batch, latent_mean, latent_logvar)
        self.log('training_elbo_loss', elbo, prog_bar=True)
        return elbo

    def generate(self, n_samples: int):
        """Generate ``n_samples`` clips by decoding latents drawn from the prior."""
        return self.vae.generate(n_samples)

    def configure_optimizers(self):
        """Single Adam optimizer over the VAE parameters."""
        vae_params = self.vae.parameters()
        return optim.Adam(vae_params, self.lr)

#### Dataset

In [None]:
class GTZANAudioDataset(Dataset):
    # Sample rate the pre-built resampler converts from
    _NATIVE_SR = 22050
    _DOWNLOAD_URL = 'https://www.kaggle.com/api/v1/datasets/download/andradaolteanu/gtzan-dataset-music-genre-classification'

    def __init__(self, path: str, sample_rate=16000, duration=10):
        """
        Fixed-length waveform clips from the GTZAN dataset.

        Expects a path to .../Data - i.e. the path should end with "Data".
        If the path does not exist the dataset archive is downloaded and extracted first.

        Args:
            path (str): Dataset root containing features_30_sec.csv and genres_original/
            sample_rate (int): Sample rate of the returned clips
            duration (int | float): Clip length in seconds
        """
        super().__init__()
        self.path = path
        if not os.path.exists(path):
            self._download_and_extract(path)

        self.df = pd.read_csv(os.path.join(path, 'features_30_sec.csv'))
        self.min_audio_len = self.df['length'].min()
        self.sample_rate = sample_rate

        # Filter out the known-malformed file
        self.df = self.df[self.df['filename'] != 'jazz.00054.wav']
        self.resampler = torchaudio.transforms.Resample(self._NATIVE_SR, sample_rate)

        self.sr = sample_rate
        # Frame counts must be integers even when duration is given as a float
        self.target_frames = int(self.sr * duration)

        self.crop = ta.RandomResizedCrop(self.target_frames)

    @classmethod
    def _download_and_extract(cls, path: str) -> None:
        """Download the GTZAN archive to ./gtzan.zip and unpack it next to ``path``."""
        import requests
        import zipfile
        # Stream the download to avoid loading the whole file into memory
        with requests.get(cls._DOWNLOAD_URL, stream=True) as r:
            r.raise_for_status()  # Raise an error on bad status
            with open('gtzan.zip', 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Ensure the target directory exists
        os.makedirs(path, exist_ok=True)

        # Extract into the parent directory.
        # NOTE(review): assumes the archive's top-level folder is "Data" - confirm against the zip.
        extract_root = os.path.dirname(os.path.abspath(path))
        with zipfile.ZipFile('gtzan.zip', 'r') as zip_ref:
            zip_ref.extractall(extract_root)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        """
        Load one track and return a clip of exactly ``target_frames`` samples.

        Short tracks are zero-padded at the end; longer ones are randomly cropped.

        Returns:
            torch.Tensor: Waveform as returned by torchaudio.load, resampled and
                trimmed/padded along dimension 1
        """
        row = self.df.iloc[index]
        audio_path = os.path.join(self.path, 'genres_original', row['label'], row['filename'])
        waveform, sr = torchaudio.load(audio_path, normalize=True)

        if sr != self.sr:
            if sr == self._NATIVE_SR:
                waveform = self.resampler(waveform)
            else:
                # The cached resampler only matches the native rate; handle anything else explicitly
                waveform = torchaudio.functional.resample(waveform, sr, self.sr)

        n_frames = waveform.size(1)
        if n_frames < self.target_frames:
            waveform = F.pad(waveform, (0, self.target_frames - n_frames))
        else:
            waveform = self.crop(waveform)

        return waveform

In [8]:
class MusicDataModule(L.LightningDataModule):
    """
    LightningDataModule serving GTZANAudioDataset clips for training.

    Args:
        data_dir (str): Dataset root passed to GTZANAudioDataset
        batch_size (int): Clips per batch
        num_workers (int): DataLoader worker processes
        target_sr (int): Sample rate of the clips
        clip_duration (float): Clip length in seconds
    """
    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 1,
        num_workers: int = 1,
        target_sr: int = 16000,
        clip_duration: float = 10.0,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.target_sr = target_sr
        self.clip_duration = clip_duration

    def setup(self, stage: Optional[str] = None):
        """Build the training dataset from the configured sample rate and clip duration."""
        duration = self.clip_duration
        # Whole-number durations are passed as int so the dataset's frame count stays an integer
        if float(duration).is_integer():
            duration = int(duration)
        self.train_dataset = GTZANAudioDataset(
            self.data_dir,
            self.target_sr,
            duration
        )

    def _collate_fn(self, batch):
        """Stack the equally sized clips of a batch along a new leading dimension."""
        return torch.stack(list(batch))

    def train_dataloader(self):
        """Shuffled DataLoader over the training dataset."""
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, collate_fn=self._collate_fn
        )



### Train

In [9]:
# Dataset root. Set the GTZAN_DATA_DIR environment variable to run this notebook on
# another machine; the fallback is the path the notebook was originally written with.
DATA_DIR = os.environ.get("GTZAN_DATA_DIR", "/home/benjx/cs_wsl/school/y4/cse253/LDMG/Data")
dm = MusicDataModule(DATA_DIR, batch_size=1, num_workers=1)
dm.setup()

In [15]:
# Plain VAE on single-channel audio with the default KL weight and learning rate.
vae = AudioVAE(1)

# No max_epochs is given, so Lightning falls back to 1000 epochs (see the warning in this cell's output).
trainer = L.Trainer()
trainer.fit(vae, datamodule=dm)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/benjx/anaconda3/envs/ldmg/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name | Type      | Params | Mode 
-------------------------------------------
0 | vae  | VAE       | 10.1 M | train
1 | loss | ELBO_Loss | 0      | train
-------------------------------------------
10.1 M    Trainable params
0         Non-trainable params
10.1 M    Total params
40.490    Total estimated model params size (MB)
80        Modules in train mode
0         Modules in eval mode
/home/benjx/anaconda3/envs/ldmg/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The '

Epoch 0:   0%|          | 0/999 [00:00<?, ?it/s] 

: 

In [None]:
from IPython.display import Audio, display

def show_wav(waveform, sample_rate):
    """Render an inline audio player for ``waveform`` played back at ``sample_rate`` Hz."""
    player = Audio(waveform, rate=sample_rate)
    display(player)

In [None]:
# First training clip with a batch dimension added. Falls back to CPU when no GPU is present.
tmp = dm.train_dataset[0].unsqueeze(0).to(
    'cuda:0' if torch.cuda.is_available() else 'cpu',
    dtype=torch.float32,
)

In [None]:
# Put the model wherever the input clip lives instead of assuming a GPU is available
vae = vae.to(device=tmp.device)
with torch.no_grad():
    out, _, _ = vae(tmp)
# First batch item, first channel, as a NumPy array
out = out[0][0].detach().cpu().numpy()
print(np.max(out))

# Reconstructed waveform and its sample rate
sample_rate = 16000  # in Hz
waveform = out

show_wav(waveform, sample_rate)

In [None]:
# Play the original input clip for comparison with the reconstruction.
# Note: this rebinds `out`, `waveform` and `sample_rate` from the previous cell.
out = tmp[0][0].detach().cpu().numpy()
print(np.max(out))

# Example waveform and sample rate
sample_rate = 16000  # in Hz
waveform = out

show_wav(waveform, sample_rate)