In [None]:
from argparse import ArgumentParser
from math import prod
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import datasets, transforms

In [None]:
class MnistVAE(pl.LightningModule):
    """Convolutional variational auto-encoder trained with PyTorch Lightning.

    Every constructor keyword argument is stored by ``save_hyperparameters()``
    and read back through ``self.hparams``. The keys this class reads are:

    * ``num_channels`` -- channels of the input (and reconstructed) images.
    * ``hidden_dim`` -- ``(C, H, W)`` shape of the last encoder feature map.
      ``C`` must be 32 (the width of the last encoder convolution) and
      ``H``/``W`` must equal what the three encoder convolutions produce.
    * ``latent_dim`` -- size of the latent vector ``z``.
    * ``conv_kernel``, ``conv_padding``, ``conv_stride`` -- shared by every
      (transposed) convolution.
    * ``lr`` -- Adam learning rate.
    * ``recon_loss`` -- ``"mse"``, ``"l1"`` or ``"bce"`` (case-insensitive).
    * ``vae_beta`` -- weight of the KL term in the total loss.

    Any other keys (e.g. the DAFX / dataset options declared in
    ``add_model_specific_args``) are stored but not read by this class.

    NOTE(review): with ``conv_stride > 1`` a ``ConvTranspose2d`` without
    ``output_padding`` does not exactly invert the matching ``Conv2d``, so the
    reconstruction may not have the input's spatial size. Stride 1 (as used in
    this notebook) round-trips exactly.
    """

    # =========== MAGIC METHODS =============
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # (C, H, W) shape the decoder un-flattens into, and its flattened size,
        # which feeds the mu / log_var heads. int() also accepts CLI strings.
        self.hidden_dim_dec = tuple(int(d) for d in self.hparams.hidden_dim)
        if len(self.hidden_dim_dec) != 3 or self.hidden_dim_dec[0] != 32:
            raise ValueError(
                "hidden_dim must be (32, H, W) to match the last encoder convolution, "
                f"got {tuple(self.hparams.hidden_dim)}"
            )
        self.hidden_dim_enc = prod(self.hidden_dim_dec)

        self._build_model()

    # =========== PRIVATE METHODS =============
    def _build_model(self):
        """Create encoder then decoder modules (order kept stable for reproducible init)."""
        self._build_encoder()
        self._build_decoder()

    def _conv_layer(self, in_channels, out_channels, transpose=False):
        """Return a Conv2d (or ConvTranspose2d) using the shared kernel/padding/stride hparams."""
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        return conv_cls(in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=self.hparams.conv_kernel,
                        padding=self.hparams.conv_padding,
                        stride=self.hparams.conv_stride)

    def _conv_block(self, in_channels, out_channels, transpose=False):
        """Conv -> ReLU -> BatchNorm, the layout shared by every hidden conv layer."""
        return nn.Sequential(
            self._conv_layer(in_channels, out_channels, transpose),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def _build_encoder(self):
        """num_channels -> 8 -> 16 -> 32 feature maps, then linear mu / log_var heads."""
        self.enc_conv1 = self._conv_block(self.hparams.num_channels, 8)
        self.enc_conv2 = self._conv_block(8, 16)
        self.enc_conv3 = self._conv_block(16, 32)

        self.mu = nn.Linear(self.hidden_dim_enc, self.hparams.latent_dim)
        self.log_var = nn.Linear(self.hidden_dim_enc, self.hparams.latent_dim)

    def _build_decoder(self):
        """Mirror of the encoder: linear -> un-flatten -> 32 -> 16 -> 8 -> num_channels."""
        self.dec_hidden = nn.Sequential(
            nn.Linear(in_features=self.hparams.latent_dim, out_features=self.hidden_dim_enc),
            nn.ReLU())

        self.dec_conv1 = self._conv_block(32, 16, transpose=True)
        self.dec_conv2 = self._conv_block(16, 8, transpose=True)

        # Last layer has no activation or normalisation: its output is the raw
        # reconstruction (treated as logits by the "bce" loss).
        self.dec_conv3 = nn.Sequential(
            self._conv_layer(8, self.hparams.num_channels, transpose=True),
        )

    @staticmethod
    def _calculate_kl_loss(mu, log_var):
        """KL divergence between N(mu, exp(log_var)) and N(0, I).

        Summed over the latent dimension (dim=1) and averaged over the batch.
        NOTE: the reconstruction losses below average over every element, so the
        two terms are on different scales; ``vae_beta`` absorbs that.
        """
        kld_batch = -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var), dim=1)
        return torch.mean(kld_batch)

    def _calculate_reconstruction_loss(self, x, x_hat):
        """Mean reconstruction loss between target ``x`` and decoder output ``x_hat``.

        Raises:
            NotImplementedError: if ``hparams.recon_loss`` is not mse/l1/bce.
        """
        loss_name = self.hparams.recon_loss.lower()
        if loss_name == "mse":
            return F.mse_loss(x_hat, x, reduction="mean")
        if loss_name == "l1":
            return F.l1_loss(x_hat, x, reduction="mean")
        if loss_name == "bce":
            # The decoder has no output activation, so x_hat is treated as logits.
            # Prediction first, target (expected in [0, 1]) second. Apply
            # torch.sigmoid to the model output to get pixel intensities.
            return F.binary_cross_entropy_with_logits(x_hat, x, reduction="mean")
        raise NotImplementedError(f"Unknown reconstruction loss: {self.hparams.recon_loss!r}")

    def encode(self, x):
        """Map an image batch to the (mu, log_var) parameters of q(z | x)."""
        x = self.enc_conv1(x)
        x = self.enc_conv2(x)
        x = self.enc_conv3(x)

        # Flatten per sample (keeps the batch dimension) so a hidden_dim that
        # does not match the conv output fails in the Linear layer instead of
        # being silently re-batched by view(-1, ...).
        x = torch.flatten(x, start_dim=1)

        mu = self.mu(x)
        log_var = self.log_var(x)

        return mu, log_var

    def decode(self, z):
        """Map latent vectors back to image space (raw, un-activated output)."""
        x = self.dec_hidden(z)

        x = x.view(-1, *self.hidden_dim_dec)

        x = self.dec_conv1(x)
        x = self.dec_conv2(x)
        x = self.dec_conv3(x)

        return x

    @staticmethod
    def reparameterise(mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Encode, sample a latent and decode; returns (reconstruction, mu, log_var)."""
        mu, log_var = self.encode(x)
        z = self.reparameterise(mu, log_var)
        out = self.decode(z)

        return out, mu, log_var

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def common_paired_step(
            self,
            batch: Tuple,
            batch_idx: int,
            train: bool = False,
    ):
        """Shared train/val step: reconstruct the batch, log the loss terms and return the total.

        ``batch`` is a 2-tuple whose second element is ignored (the MNIST label
        in this notebook).
        """
        x, _ = batch

        x_hat, x_mu, x_log_var = self(x)

        recon_loss = self._calculate_reconstruction_loss(x, x_hat)
        kld = self._calculate_kl_loss(x_mu, x_log_var)

        # beta-VAE objective: reconstruction plus weighted KL term.
        loss = recon_loss + (self.hparams.vae_beta * kld)

        prefix = "train" if train else "val"
        self.log(prefix + "_loss/loss", loss)
        self.log(prefix + "_loss/reconstruction_loss", recon_loss)
        self.log(prefix + "_loss/kl_divergence", kld)

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; logs under the ``train_loss/`` prefix."""
        return self.common_paired_step(batch, batch_idx, train=True)

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs under the ``val_loss/`` prefix."""
        return self.common_paired_step(batch, batch_idx, train=False)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's CLI options.

        Only the training and VAE groups are read by this class; the DAFX and
        dataset options are accepted but not used here.
        """

        def str2bool(value):
            # argparse's ``type=bool`` calls bool() on the raw string, so
            # "--flag False" would give True; parse the usual spellings instead.
            if isinstance(value, bool):
                return value
            normalised = value.strip().lower()
            if normalised in ("1", "true", "t", "yes", "y"):
                return True
            if normalised in ("0", "false", "f", "no", "n"):
                return False
            raise ValueError(f"expected a boolean, got {value!r}")

        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # -------- Training -----------
        parser.add_argument("--batch_size", type=int, default=8)
        parser.add_argument("--lr", type=float, default=1e-5)
        parser.add_argument("--recon_loss", type=str, default="mse")
        parser.add_argument("--vae_beta", type=float, default=1.)

        # --------- DAFX ------------
        parser.add_argument("--dafx_file", type=str, default="src/dafx/mda.vst3")
        parser.add_argument("--dafx_names", nargs="*")
        parser.add_argument("--dafx_param_names", nargs="*", default=None)

        # --------- VAE -------------
        parser.add_argument("--num_channels", type=int, default=1)
        # type=int so prod() / view() receive numbers, not strings.
        parser.add_argument("--hidden_dim", nargs="*", type=int, default=(32, 9, 257))
        parser.add_argument("--latent_dim", type=int, default=1024)
        parser.add_argument("--conv_kernel", type=int, default=3)
        parser.add_argument("--conv_padding", type=int, default=1)
        parser.add_argument("--conv_stride", type=int, default=2)

        # ------- Dataset  -----------
        parser.add_argument("--audio_dir", type=str, default="src/audio")
        parser.add_argument("--ext", type=str, default="wav")
        parser.add_argument("--input_dirs", nargs="+", default=['musdb18_24000', 'vctk_24000'])
        parser.add_argument("--buffer_reload_rate", type=int, default=1000)
        parser.add_argument("--buffer_size_gb", type=float, default=1.0)
        parser.add_argument("--sample_rate", type=int, default=24_000)
        parser.add_argument("--dsp_sample_rate", type=int, default=24_000)
        parser.add_argument("--shuffle", type=str2bool, default=True)
        parser.add_argument("--random_effect_threshold", type=float, default=0.75)
        parser.add_argument("--train_length", type=int, default=131_072)
        parser.add_argument("--train_frac", type=float, default=0.9)
        parser.add_argument("--effect_audio", type=str2bool, default=True)
        parser.add_argument("--half", type=str2bool, default=False)
        parser.add_argument("--train_examples_per_epoch", type=int, default=10_000)
        parser.add_argument("--val_length", type=int, default=131_072)
        parser.add_argument("--val_examples_per_epoch", type=int, default=100)
        parser.add_argument("--num_workers", type=int, default=4)
        parser.add_argument("--dummy_setting", type=str2bool, default=False)

        return parser


In [None]:
# Data: MNIST train/test splits loaded through ToTensor(); the test split is
# also used as the validation set during training.
SEED = 42
# Seed python / numpy / torch so weight init and shuffling are reproducible.
pl.seed_everything(SEED)

bs = 100
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
# download=True is a no-op once the files exist, so this line also works on its own.
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False, num_workers=4)

# Trade some float32 matmul precision for speed on hardware that supports it.
torch.set_float32_matmul_precision('medium')

In [None]:
# Hyper-parameters for MnistVAE. With kernel 3, padding 1 and stride 1 each
# convolution keeps the 28x28 input size ((28 + 2*1 - 3) / 1 + 1 = 28), so the
# last encoder feature map is 32 channels x 28 x 28, hence hidden_dim below.
args = dict(
    conv_kernel=3,
    conv_padding=1,
    conv_stride=1,
    latent_dim=256,
    hidden_dim=[32, 28, 28],
    num_channels=1,
    dafx_names=[],
    lr=1e-4,
    recon_loss="mse",
    vae_beta=1,
)

In [None]:
# Instantiate the VAE from the hyper-parameter dict above; MnistVAE.__init__
# calls save_hyperparameters(), and the model reads these values via self.hparams.
model = MnistVAE(**args)

In [None]:
# Keep the checkpoint with the lowest validation loss ("val_loss/loss" is the
# key MnistVAE.common_paired_step logs during validation).
checkpoint_callback = ModelCheckpoint(monitor="val_loss/loss", mode="min")

# "auto" uses a GPU when one is available and falls back to CPU otherwise, so
# the notebook no longer crashes on machines without CUDA.
MAX_EPOCHS = 100
trainer = pl.Trainer(accelerator="auto", max_epochs=MAX_EPOCHS, callbacks=[checkpoint_callback])

In [None]:
# Train on the MNIST train split, validating on the test split each epoch.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

In [None]:
# Single-image loader for the reconstruction figure below.
# NOTE: this rebinds `test_loader`. Re-run the data cell before calling
# `trainer.fit` again, otherwise validation would run with batch_size=1.
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# num_workers=0: only a handful of images are read, so spawning worker
# processes costs more than it saves.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Plot a few test images next to their reconstructions and save the figure.
# Reconstructions sample z, so they vary slightly from run to run.
NUM_EXAMPLES = 5
FIGURE_PATH = Path("./figures/mnist_reconstruction.svg")

# eval(): BatchNorm uses its running statistics instead of those of a
# single-image batch (and leaves them untouched); no_grad(): no autograd graph.
model.eval()

fig, axes = plt.subplots(NUM_EXAMPLES, 2, figsize=(8, 5))

with torch.no_grad():
    for i, (x, _) in enumerate(test_loader):
        if i >= NUM_EXAMPLES:
            break

        # Put the input on whatever device the model ended up on after training.
        x = x.to(model.device)
        x_hat, _, _ = model(x)

        mse = F.mse_loss(x_hat, x)

        axes[i, 0].imshow(x.squeeze().cpu().numpy())
        axes[i, 1].imshow(x_hat.squeeze().cpu().numpy())

        axes[i, 0].set_title(f"Original {i+1}", fontsize=10)
        axes[i, 1].set_title(f"Reconstruction {i+1} (MSE: {mse.item():.4f})", fontsize=10)

        for ax in axes[i]:
            ax.set_xticks([])
            ax.set_yticks([])

fig.tight_layout()
# savefig does not create missing directories.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH)