In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distrib
import numpy as np
import lmdb
import torchaudio
import librosa
from udls.generated import AudioExample
import IPython.display as ipd
import matplotlib.pyplot as plt

In [2]:
class AE(nn.Module):
    def __init__(self, encoder, decoder, encoding_dim):
        super(AE, self).__init__()
        self.encoding_dims = encoding_dim
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [41]:
class VAE(AE):
    
    def __init__(self, encoder, decoder, encoding_dims, latent_dims):
        super(VAE, self).__init__(encoder, decoder, encoding_dims)
        self.latent_dims = latent_dims
        self.mu = nn.Sequential(nn.Linear(self.encoding_dims, self.latent_dims))
        self.sigma = nn.Sequential(nn.Linear(self.encoding_dims, self.latent_dims), nn.Softplus())
        
    def encode(self, x):
        
        x_encoded = self.encoder(x)

        mu = self.mu(x_encoded)
        sigma = self.sigma(x_encoded)

        return (mu, sigma), x_encoded.shape
    
    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        # Encode the inputs
        z_params, x_encoded_shape = self.encode(x)
        print(f"{x_encoded_shape=}")
        # Obtain latent samples and latent loss
        mu, _ = z_params
        print(mu.shape)
        z_tilde, kl_div = self.latent(z_params)
        # Decode the samples
        z_tilde = z_tilde.reshape(x_encoded_shape)
        x_tilde = self.decode(z_tilde)
        return x_tilde.reshape(x.shape), kl_div
    
    def latent(self, z_params):
        
        normal = distrib.Normal(loc=0., scale=1.)
        mu, sigma = z_params
        kl_div = torch.sum(1 + torch.log(sigma**2) - mu**2 - 2*(sigma**2))/2
        z = mu + sigma * normal.sample(sigma.shape)
    
        return z, kl_div

In [55]:
def construct_encoder_decoder(nin, n_latent = 16, n_hidden = 512, n_classes = 1):
    # Encoder network
    encoder = nn.Sequential(
        nn.Conv1d(128, 64,7,2), nn.LeakyReLU(), nn.BatchNorm1d(64),
        nn.Conv1d(64, 32,7,2), nn.LeakyReLU(), nn.BatchNorm1d(32),
        nn.Conv1d(32, 16,7,4), nn.LeakyReLU(), nn.BatchNorm1d(16),
        nn.Flatten()


    )
    # Decoder network
    decoder = nn.Sequential(
        nn.ConvTranspose1d(16,32,7,4), nn.LeakyReLU(), nn.BatchNorm1d(32),
        nn.ConvTranspose1d(32,64,7,4), nn.LeakyReLU(), nn.BatchNorm1d(64),
        nn.ConvTranspose1d(64,128,7,4), nn.LeakyReLU(), nn.BatchNorm1d(128),
    )
    return encoder, decoder

In [56]:
test = torch.ones((64, 128, 1024))

In [62]:
# Using Bernoulli or Multinomial loss
num_classes = 1
# Number of hidden and latent
n_hidden = 992
n_latent = 16
# Compute input dimensionality
nin = test.shape[1] * test.shape[2]
# Construct encoder and decoder
encoder, decoder = construct_encoder_decoder(nin, n_hidden = n_hidden, n_latent = n_latent, n_classes = num_classes)
# Build the VAE model
model = VAE(encoder, decoder, n_hidden, n_latent)
# Construct the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [66]:
model(test)

x_encoded_shape=torch.Size([64, 992])
torch.Size([64, 16])


RuntimeError: shape '[64, 992]' is invalid for input of size 1024

In [67]:
epochs = 1
for epoch in range(1, epochs + 1):
    full_loss = torch.Tensor([0])
    # Forward pass: compute predicted y by passing x to the model.
    for i, (x, _) in enumerate(train_loader):
        full_loss += train_step(model, x, optimizer, beta = betas[epoch])
    #for i, (x, _) in enumerate(valid_loader):
    #    train_step(model, x, optimizer)
    if (epoch % 10 == 0):
        print('Epoch: {}, Test set ELBO: {}'.format(epoch, full_loss))

NameError: name 'train_loader' is not defined