In [49]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, Subset
import os
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from helper_plot import hdr_plot_style
hdr_plot_style()
from torch.utils.tensorboard import SummaryWriter

In [50]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [51]:
class TechnoDataset(Dataset):

    def __init__(self,
                 dat_location,
                 size=2**15) -> None:
        super().__init__()

        self.samples = np.memmap(
            dat_location,
            dtype="float32",
            mode="r",
        )
        self.samples = self.samples[:size * (len(self.samples) // size)]
        self.samples = self.samples.reshape(-1, 1, size)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return torch.from_numpy(np.copy(self.samples[index])).float()

dataset = TechnoDataset("./data/TECHNO/techno_resampled.dat")

In [52]:
class AE(nn.Module):
    def __init__(self, encoder, decoder, encoding_dim):
        super(AE, self).__init__()
        self.encoding_dims = encoding_dim
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [53]:
class VAE(AE):
    
    def __init__(self, encoder, decoder, encoding_dims, latent_dims):
        super(VAE, self).__init__(encoder, decoder, encoding_dims)
        self.latent_dims = latent_dims
        self.mu = nn.Sequential(nn.Linear(self.encoding_dims, self.latent_dims))
        self.sigma = nn.Sequential(nn.Linear(self.encoding_dims, self.latent_dims), nn.Softplus())
        
    def encode(self, x):
        
        x = self.encoder(x)
        mu = self.mu(x)
        sigma = self.sigma(x)
        
        return mu, sigma
    
    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        # Encode the inputs
        z_params = self.encode(x)
        # Obtain latent samples and latent loss
        z_tilde, kl_div = self.latent(x, z_params)
        # Decode the samples
        x_tilde = self.decode(z_tilde)
        return x_tilde.reshape(-1, 1, 28, 28), kl_div
    
    def latent(self, x, z_params):
        
        mu, sigma = z_params

        var = sigma * sigma
        log_var = torch.log(var)

        z = torch.rand_like(mu) * sigma + mu
        kl_div = torch.sum(mu * mu + var - log_var - 1)

        return z, kl_div

In [54]:
class Reshape(torch.nn.Module):
    def __init__(self, outer_shape):
        super(Reshape, self).__init__()
        self.outer_shape = outer_shape
    def forward(self, x):
        return x.view(x.size(0), *self.outer_shape)

In [55]:
def construct_encoder_decoder(nin=1, n_latent = 16, n_hidden = 64, n_params = 0, n_classes = 1):
    # Encoder network
    encoder = nn.Sequential(
          nn.Conv1d(nin, n_hidden, 4,  padding=1), nn.ReLU(),
          nn.Conv1d(n_hidden, 2*n_hidden, 4,  padding=1), nn.ReLU(),
          nn.Flatten(),
          nn.Linear(7*7*2*n_hidden, 1024),nn.ReLU(),
          nn.Linear(1024, n_latent)#, nn.ReLU(),
        )

    # Decoder network
    decoder = nn.Sequential(
          nn.Linear(n_latent,1024), nn.ReLU(),
          nn.Linear(1024, 7*7*2*n_hidden), nn.ReLU(),
          Reshape((2*n_hidden,7,7,)),
          nn.ConvTranspose1d(2*n_hidden, n_hidden, 4,  padding=1), nn.ReLU(),
          nn.ConvTranspose1d(n_hidden, nin*n_classes, 4,  padding=1), nn.Sigmoid()
        )
    return encoder, decoder

In [61]:
#o33552384 - (7*7*2*n_hidden)

33502208

In [56]:
# Reconstruction criterion
recons_criterion = torch.nn.BCELoss(reduction='sum')

def compute_loss(model, x):
    model = model#.to(device)    
    x = x#.to(device)    
    beta = 0.1
    y_pred,kl_div = model(x)
    recons_loss = recons_criterion(y_pred, x)
    full_loss = recons_loss + kl_div*beta
    #print(kl_div,recons_loss,full_loss)
    
    return full_loss,kl_div,recons_loss

def train_step(model, x, optimizer):
    model = model#.to(device)
    x = x#.to(device)
    # Compute the loss.
    loss,kl_div,recons_loss = compute_loss(model, x)
    # Before the backward pass, zero all of the network gradients
    optimizer.zero_grad()
    # Backward pass: compute gradient of the loss with respect to parameters
    loss.backward()
    # Calling the step function to update the parameters
    optimizer.step()
    return loss,kl_div,recons_loss

In [57]:
dataset_dir = './data'
valid_ratio = 0.999999
# Load the dataset for the training/validation sets
train_valid_dataset =  dataset
# Split it into training and validation sets
nb_train = int((1.0 - valid_ratio) * len(train_valid_dataset) +1)
nb_valid =  int(valid_ratio * len(train_valid_dataset))

train_dataset, valid_dataset = torch.utils.data.dataset.random_split(train_valid_dataset, [nb_train, nb_valid])


# Prepare 
num_threads = 0     # Loading the dataset is using 4 CPU threads
batch_size  = 128   # Using minibatches of 128 samples
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_threads)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads)


In [58]:
# Using Bernoulli or Multinomial loss
num_classes = 1
# Number of hidden and latent
n_hidden = 512
n_latent = 2
# Compute input dimensionality
nin = 1
# Construct encoder and decoder
encoder, decoder = construct_encoder_decoder(nin, n_hidden = n_hidden, n_latent = n_latent, n_classes = num_classes)
encoder = encoder#.to(device)
decoder = decoder#.to(device)
# Build the VAE model
model = VAE(encoder, decoder, n_hidden, n_latent)#.to(device)
# Construct the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [59]:


def train_vae(model, epochs):
    
    for epoch in range(1, epochs + 1):
        full_loss = torch.Tensor([0])#.to(device)
        kl_div = torch.Tensor([0])#.to(device)
        recons_loss = torch.Tensor([0])#.to(device)
        # Forward pass: compute predicted y by passing x to the model.
        for i, x in enumerate(train_loader):
            #x=x#.to(device)
            #data_loader = data_loader#.to(device)
            #model=model#.to(device)
            
            full_loss_add,kl_div_add,recons_loss_add = train_step(model, x, optimizer)
            full_loss += full_loss_add
            kl_div += kl_div_add
            recons_loss += recons_loss_add
            

        print('Epoch: {}, Test set ELBO: {}, Kl : {}, recons :{}'.format(epoch, full_loss,kl_div,recons_loss))

In [60]:
epochs=50

train_vae(model,epochs)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x33552384 and 50176x1024)