# 1. Setup

In [65]:
import sys
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import Trainer
from torch.optim import Adam
from pytorch_lightning.core.lightning import LightningModule
from torch.utils.data import Dataset, DataLoader

# 2. Load Dataset

In [66]:
# Samples per waveform window, and dimensionality of the VAE latent space.
wave_size, latent_dims = 256, 16

In [67]:
# Source recording: a single WAV file, sliced below into fixed-size windows.
wav_dir, wav_file_name = "./WAV/Various-ups-downs/", "VARIOUS_.WAV"
wav_file = os.path.join(wav_dir, wav_file_name)

# Pre-allocated, uninitialised storage: one column per window of `wave_size`
# samples, with room for at most 64 windows. Only the first `num_waves`
# columns are ever filled.
waves = torch.empty(wave_size, 64)

# torchaudio.load returns the audio tensor and its sample rate.
wavebank, sample_rate = torchaudio.load(wav_file)

# Number of windows written into `waves` so far (also the next free column).
num_waves = 0

# Normalizes data per for each wave
def normalize(wave):
    return (wave - wave.min()) / (wave.max() - wave.min())

# Slice the first channel of the recording into consecutive, non-overlapping
# windows of `wave_size` samples. Each window that is not constant is
# min-max normalised to [0, 1] and stored as the next column of `waves`.
# Constant (silent / DC) windows are skipped because they cannot be scaled.
#
# Only read as many windows as `waves` has columns for AND the recording
# actually contains. A hard-coded 64 would produce truncated or empty slices
# for a short file, which then fail on assignment or on `.max()`.
max_windows = min(waves.shape[1], wavebank.shape[1] // wave_size)
for i in range(max_windows):
    start = i * wave_size
    waveform = wavebank[0, start:start + wave_size]
    if waveform.max() != waveform.min():
        waveform = normalize(waveform)
        waves[:, num_waves] = waveform
        num_waves += 1

# 3. Variational Autoencoder Implementation

In [77]:
# Convolution hyper-parameters (also the defaults used by conv1DOutWidth).
capacity = 16  # channel multiplier: layer 1 has `capacity`, layer 2 twice that
kernel, stride, padding = 4, 2, 1

# Output channels produced by each convolutional layer
conv1_out_channels, conv2_out_channels = capacity, capacity * 2

# Finds the width of the convolution output
def conv1DOutWidth(in_w, kernel_w=kernel, stride_w=stride, padding_w=padding):
    new_w = (in_w + 2*padding_w - kernel_w)/(stride_w) + 1
    return int(new_w)

# Width of the signal after each encoder convolution
# (256 -> 128 -> 64 for wave_size=256, kernel=4, stride=2, padding=1).
conv1_out_w = conv1DOutWidth(wave_size)
conv2_out_w = conv1DOutWidth(conv1_out_w)

class Encoder(nn.Module):
    """Convolutional encoder mapping a waveform to Gaussian posterior parameters.

    Input: ``(batch, 1, wave_size)``. This is the layout produced by
    ``LitVariationalAutoencoder.training_step`` via ``unsqueeze(1)``.
    Output: ``(mu, logvar)``, each of shape ``(batch, latent_dims)``.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        # Build the layers from the shared hyper-parameters instead of
        # repeating literals. The flattened size below uses `conv2_out_w`,
        # which is derived from the same `kernel`/`stride`/`padding`, so the
        # two must agree or `fc_mu`/`fc_logvar` get the wrong input size.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=conv1_out_channels,
                               kernel_size=kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv1d(in_channels=conv1_out_channels, out_channels=conv2_out_channels,
                               kernel_size=kernel, stride=stride, padding=padding)
        flat_features = conv2_out_channels * conv2_out_w
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=latent_dims)
        self.fc_logvar = nn.Linear(in_features=flat_features, out_features=latent_dims)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch of single-channel waveforms."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten (batch, channels, width) -> (batch, features)
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar
    
class Decoder(nn.Module):
    """Mirror of ``Encoder``: latent vector -> reconstructed waveform.

    Input: ``(batch, latent_dims)``.
    Output: ``(batch, 1, width)`` with values in (0, 1) (sigmoid).

    ``width`` equals ``wave_size`` when the transposed convolutions exactly
    invert the encoder's convolutions. That holds for kernel=4, stride=2,
    padding=1 (64 -> 128 -> 256).
    """
    def __init__(self):
        super(Decoder, self).__init__()
        # Use the shared hyper-parameters, matching `forward`, which already
        # reshapes with `conv2_out_channels` / `conv2_out_w`.
        self.fc = nn.Linear(in_features=latent_dims, out_features=conv2_out_channels * conv2_out_w)
        self.conv2 = nn.ConvTranspose1d(in_channels=conv2_out_channels, out_channels=conv1_out_channels,
                                        kernel_size=kernel, stride=stride, padding=padding)
        self.conv1 = nn.ConvTranspose1d(in_channels=conv1_out_channels, out_channels=1,
                                        kernel_size=kernel, stride=stride, padding=padding)

    def forward(self, x):
        """Decode a batch of latent vectors into single-channel waveforms."""
        x = self.fc(x)
        x = x.view(x.size(0), conv2_out_channels, conv2_out_w)
        x = F.relu(self.conv2(x))
        # Sigmoid keeps outputs in (0, 1), as required by the BCE reconstruction loss.
        x = torch.sigmoid(self.conv1(x))
        return x

class LitVariationalAutoencoder(LightningModule):
    """Lightning module that trains the ``Encoder``/``Decoder`` pair as a VAE.

    Attributes:
        encoder, decoder: the convolutional networks.
        loss: per-step training losses, stored detached from the autograd graph.
        recon_loss: list kept for callers; not populated by this class.
        learning_rate: Adam learning rate used by ``configure_optimizers``.
    """
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.loss = []
        self.recon_loss = []
        self.learning_rate = 1e-3

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(x_recon, latent_mu, latent_logvar)``.
        """
        latent_mu, latent_logvar = self.encoder(x)
        latent = self.latent_sample(latent_mu, latent_logvar)
        x_recon = self.decoder(latent)
        return x_recon, latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        """Sample ``z ~ N(mu, exp(logvar))`` in training mode; return ``mu`` in eval mode.

        Uses the reparameterisation trick ``z = mu + eps * std`` with
        ``eps ~ N(0, I)``, so gradients flow through ``mu`` and ``logvar``.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def training_step(self, wave_batch, batch_idx):
        """Compute, log and return the VAE loss for one batch.

        ``wave_batch`` is presumably ``(batch, wave_size)`` from the DataLoader
        (TODO confirm). A channel axis is inserted to get Conv1d's
        ``(batch, channels, width)`` layout.
        """
        wave_batch = wave_batch.unsqueeze(1)
        wave_batch_recon, latent_mu, latent_logvar = self(wave_batch)
        loss = vae_loss(wave_batch_recon, wave_batch, latent_mu, latent_logvar)
        self.log('training loss', loss, on_step=True, on_epoch=True, logger=True)
        # Store a detached copy. Appending the raw loss would keep a reference
        # to its autograd graph (grad_fn) for every training step.
        self.loss.append(loss.detach())
        return loss

    def configure_optimizers(self):
        """Adam with a small L2 weight decay."""
        return Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)

def vae_loss(recon_x, x, mu, logvar, variational_beta=0.1, tv_weight=100.01):
    """VAE training loss: reconstruction + beta * KL divergence + total variation.

    Args:
        recon_x: decoder output with values in [0, 1]. It is indexed as
            ``(batch, channels, width)`` for the total-variation term.
        x: target batch with the same number of elements as ``recon_x`` and
            values in [0, 1].
        mu, logvar: encoder posterior mean and log-variance.
        variational_beta: weight of the KL-divergence term (default 0.1).
        tv_weight: weight of the total-variation smoothness penalty
            (default 100.01).

    Returns:
        Scalar loss tensor.

    Note:
        The reconstruction term is a per-element *mean*, while the KL and TV
        terms are *sums* over the batch. Their relative weight therefore
        changes with batch size and wave length. This is kept as-is to
        preserve existing training behaviour.
    """
    # With reduction='mean', BCE averages over all elements, so only the
    # element pairing matters, not the shape. Flattening both tensors removes
    # the dependency on the global `wave_size`.
    recon_loss = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='mean')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Total variation: squared differences between neighbouring samples along
    # the last axis, penalising jagged reconstructions.
    tv_loss = torch.sum(torch.pow(recon_x[:, :, :-1] - recon_x[:, :, 1:], 2))
    return recon_loss + variational_beta * kldivergence + tv_weight * tv_loss
    

# 4. Set up dataset for training

In [83]:
class UnlabeledTensorDataset(Dataset):
    """Dataset wrapping an unlabeled data tensor (autoencoders do not need labels).

    Each sample is one column of ``data_tensor``: ``dataset[i]`` returns
    ``data_tensor[:, i]``.

    Arguments:
        data_tensor (Tensor): tensor holding one sample per index of dim 1.
        num_waves (int, optional): number of leading columns exposed as
            samples. Defaults to every column of ``data_tensor``.

    Raises:
        ValueError: if ``num_waves`` is negative or larger than the number of
            columns. Otherwise the error would only surface as an IndexError
            partway through an epoch.
    """
    def __init__(self, data_tensor, num_waves=None):
        available = data_tensor.size(1)
        if num_waves is None:
            num_waves = available
        if not 0 <= num_waves <= available:
            raise ValueError(
                f"num_waves={num_waves} is outside [0, {available}] "
                "(number of columns in data_tensor)"
            )
        self.data_tensor = data_tensor
        self.samples = list(range(num_waves))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Index through `samples` so only the advertised columns are reachable.
        return self.data_tensor[:, self.samples[index]]

# Only the first `num_waves` columns of `waves` hold real windows.
train_dataset = UnlabeledTensorDataset(waves[:, :num_waves], num_waves)

# One waveform per optimisation step, visited in a new random order each epoch.
batch_size = 1
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [84]:
# Train on a single GPU when one is available. Otherwise fall back to the CPU
# so the notebook still runs end-to-end on machines without CUDA.
vae = LitVariationalAutoencoder()
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = Trainer(accelerator=accelerator, devices=1, max_epochs=50)
trainer.fit(vae, train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 67.7 K
1 | decoder | Decoder | 36.9 K
------------------------------------
104 K     Trainable params
0         Non-trainable params
104 K     Total params
0.419     Total estimated model params size (MB)


Epoch 49: 100%|██████████| 64/64 [00:00<00:00, 84.42it/s, loss=1.41, v_num=13]  

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 64/64 [00:00<00:00, 83.18it/s, loss=1.41, v_num=13]


# 5. Export as ONNX

In [86]:
import torch.onnx

# Export only the decoder (latent vector -> waveform) to ONNX.
vae.decoder.eval()
# Trace with an example input on the same device and dtype as the decoder's
# weights. A CPU tensor fed to a decoder that was left on the GPU would make
# tracing fail.
decoder_weight = next(vae.decoder.parameters())
dummy_input = torch.randn(1, latent_dims, device=decoder_weight.device, dtype=decoder_weight.dtype)
torch.onnx.export(
    vae.decoder,
    dummy_input,
    "./test_model.onnx",
    verbose=True,
    input_names=['modelInput'],
    output_names=['modelOutput'],
    opset_version=11,
    export_params=True,
)

Exported graph: graph(%modelInput : Float(1, 16, strides=[16, 1], requires_grad=0, device=cpu),
      %fc.weight : Float(2048, 16, strides=[16, 1], requires_grad=1, device=cpu),
      %fc.bias : Float(2048, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(32, 16, 4, strides=[64, 4, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %conv1.weight : Float(16, 1, 4, strides=[4, 4, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(1, strides=[1], requires_grad=1, device=cpu)):
  %/fc/Gemm_output_0 : Float(1, 2048, strides=[2048, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/fc/Gemm"](%modelInput, %fc.weight, %fc.bias), scope: __main__.Decoder::/torch.nn.modules.linear.Linear::fc # /home/jeremy/miniconda3/lib/python3.9/site-packages/torch/nn/modules/linear.py:114:0
  %/Constant_output_0 : Long(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[val