In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import models
import numpy as np
import torch.nn.functional as F
import librosa
import numpy as np
from torchvision import models
import os
import soundfile as sf
import librosa
import matplotlib.pyplot as plt
from preprocess import *

In [None]:
# Load up to 10 s of 'kangaroo.wav', convert it back to audio and save the result.
# NOTE(review): audio_to_spectrogram / spectrogram_to_audio / save_audio_as_wav come from
# `preprocess` via the star import; their return values and file outputs are not visible here.
# NOTE(review): the second return value `y` (not `spec`) is passed to spectrogram_to_audio.
# If `y` is the raw waveform rather than a spectrogram, this was probably meant to be `spec` — confirm.
spec,y,sr = audio_to_spectrogram('kangaroo.wav',duration=10)
s2a = spectrogram_to_audio(y)
save_audio_as_wav(s2a)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def GramMatrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    ``input`` is unpacked as (batch, maps, height, width). Each (batch, map)
    pair is flattened into one row of length ``height * width``. The result is
    the square ``(batch*maps, batch*maps)`` matrix of pairwise row inner
    products, divided by the total number of elements in ``input``.
    """
    n_batch, n_maps, height, width = input.size()

    # One row per feature map; .view requires a contiguous tensor (as before).
    rows = input.view(n_batch * n_maps, height * width)
    gram = rows @ rows.t()

    # Scale by the element count so the magnitude does not grow with tensor size.
    return gram / (n_batch * n_maps * height * width)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 2-D inputs.

    The encoder halves the spatial size three times (128 -> 64 -> 32 -> 16).
    The decoder mirrors it back (16 -> 32 -> 64 -> 128), so inputs must have
    shape ``(batch, 1, 128, 128)``. Reconstructions have the same shape.
    NOTE(review): confirm this matches the spectrogram size fed to the model.

    Args:
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        flat_dim = 128 * 16 * 16  # channels * height * width at the bottleneck

        # padding=1 makes each stride-2 conv exactly halve the spatial size.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        self.decoder_fc = nn.Linear(latent_dim, flat_dim)

        # padding=1 + output_padding=1 makes each stride-2 transposed conv exactly
        # double the spatial size, mirroring the encoder. The last layer emits
        # 1 channel so the reconstruction matches the input.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # Output in [0, 1]; NOTE(review): assumes targets are normalized to that range.
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mean, logvar)`` of q(z|x), each of shape ``(batch, latent_dim)``."""
        h = self.encoder(x)
        # (B, 128, 16, 16) -> (B, 128*16*16) so the linear heads see one vector per sample.
        h = torch.flatten(h, start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        # logvar is log(sigma^2), so sigma = exp(logvar / 2).
        std = torch.exp(0.5 * logvar)
        # randn_like already matches mu's device and dtype; no global `device` needed.
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` to reconstructions ``(batch, 1, 128, 128)``."""
        h = self.decoder_fc(z)
        h = h.view(-1, 128, 16, 16)  # un-flatten back to the bottleneck feature map
        return self.decoder(h)

    def forward(self, x):
        """Encode, sample ``z``, and decode. Returns ``(x_hat, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar

In [None]:
def loss_fn(recon, original, mu, logvar):
    """Return the two VAE loss terms as ``(reconstruction_loss, kl_divergence)``.

    - reconstruction_loss: summed squared error between ``recon`` and ``original``.
    - kl_divergence: analytic KL(N(mu, exp(logvar)) || N(0, I)), summed over
      all elements.

    The terms are returned separately; combining or weighting them is up to
    the caller.
    """
    reconstruction = F.mse_loss(recon, original, reduction='sum')
    kl = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    return reconstruction, kl

In [None]:
# Build the VAE with a 64-dimensional latent space, then move its parameters to `device`.
model = VAE(latent_dim=64)
model = model.to(device)  # nn.Module.to returns the same module

In [None]:
# Adam over every trainable parameter of the VAE, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)