In [3]:
import os
import torch
import numpy as np
from torch.optim import Adam
from torch import nn, utils, Tensor
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset,DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import lightning as L

In [4]:
np.random.seed(0)
dim = 50
actual_dim = 25
latent_dim = 50

In [5]:
class VectorDataset(Dataset):
    def __init__(self, n, idxs, dim=dim, actual_dim=actual_dim):
        self.x = torch.ones(n,dim)*10
        self.x[:,idxs] = 5#torch.randn(n,actual_dim)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]

In [9]:
idxs = np.random.choice(dim,actual_dim)
train_ds,test_ds = VectorDataset(10000,idxs), VectorDataset(100,idxs)

In [11]:
class Encoder(nn.Module):
    
    def __init__(self,input_dim=dim,latent_dim=latent_dim):
        super(Encoder, self).__init__()
        self.l1 = nn.Linear(50, 25)
        self.l2 = nn.Linear(25, 50)
        self.latent_mu = nn.Linear(25,50)
        self.latent_log_sigma = nn.Linear(25,50)
        self.act = nn.GELU()
    def forward(self,x):
        x = self.act(self.l1(x))
        return self.l2(x),self.latent_mu(x),self.latent_log_sigma(x)

e = Encoder()
e(train_ds[:100])[0].shape

torch.Size([100, 50])

In [13]:
class Decoder(nn.Module):
    def __init__(self,latent_dim=latent_dim,output_dim=dim):
        super().__init__()
        self.l1 = nn.Linear(50,50)
        self.l2 = nn.Linear(50,50)
        self.mu = nn.Linear(50,output_dim)
        self.log_sigma = nn.Linear(50,output_dim)
        self.act1 = nn.GELU()

    def forward(self,x):
        x = self.act1(self.l1(x))
        return self.l2(x), self.mu(x), self.log_sigma(x)

e = Encoder()
zs = e(train_ds[:100])[0]
print(zs.shape)
d = Decoder()
d(zs[0])[0].shape

torch.Size([100, 50])


torch.Size([50])

In [84]:
class VAE(L.LightningModule):
    def __init__(self,input_dim=dim,latent_dim=latent_dim):
        super().__init__()
        self.z_dim = latent_dim
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)

    def sample(self,z_mu,z_log_sigma):
        epsilon = torch.randn(self.z_dim)
        return z_mu + torch.exp(0.5 * z_log_sigma) * epsilon
        
    def loss_fn_mse(self,x,x_recon):
        return  nn.MSELoss()(x,x_recon)
        
    def loss_fn_mse_kl(self,x,x_recon,z_mu,z_log_sigma):
        kl = -0.5*(1. + z_log_sigma  - torch.square(z_mu) - torch.exp(z_log_sigma))
        mse = nn.MSELoss()
        mse_loss = mse(x,x_recon)
        kl_loss = torch.mean(torch.sum(kl,axis=1))
        loss = mse_loss+kl_loss
        return  loss,mse_loss,kl_loss

    def loss_fn_gaussian_vae(self,x,x_mu,x_log_sigma,z_mu,z_log_sigma):
        kl = -0.5*(1. + z_log_sigma  - torch.square(z_mu) - torch.exp(z_log_sigma))
        log_2_pi = torch.log(torch.sqrt(torch.Tensor([2*torch.pi])))
        neg_log_like = torch.mean(-torch.sum(-(x - x_mu)**2/(2*torch.exp(x_log_sigma)) - log_2_pi - 0.5*x_log_sigma,axis=1))
        kl = torch.mean(torch.sum(kl,axis=1))
        return  neg_log_like+kl,neg_log_like,kl 

    def forward(self,x):
        z,z_mu,z_log_sigma = self.encoder(x)
        sample_z = self.sample(z_mu,z_log_sigma)
        x_recon, x_mu, x_log_sigma = self.decoder(sample_z)
        return x,x_recon,x_mu,x_log_sigma,z_mu,z_log_sigma
        
    def training_step(self,batch,batch_idx):
        x,x_recon,x_mu,x_log_sigma,z_mu,z_log_sigma = self.forward(batch)
        # loss,mse,kl = self.loss_fn_mse_kl(x,x_recon,z_mu,z_log_sigma)
        # self.log("mse",mse)
        # self.log("kl",kl)
        # self.log("total loss", loss)
        # return loss
        
        loss,ll,kl = self.loss_fn_gaussian_vae(batch,x_mu,x_log_sigma,z_mu,z_log_sigma)
        self.log("ll",ll)
        self.log("kl",kl)
        self.log("total loss", loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
        
v = VAE()
v(train_ds[:100])[0].shape

torch.Size([100, 50])

This loss doesnt seem to work at all even though its the one given in the paper for real valued outputs

In [97]:
vae = VAE()
trainer = L.Trainer(max_epochs=100)
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
trainer.fit(model=vae,train_dataloaders=train_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 5.2 K 
1 | decoder | Decoder | 10.2 K
------------------------------------
15.4 K    Trainable params
0         Non-trainable params
15.4 K    Total params
0.061     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [98]:
vae.eval()
with torch.no_grad():
    xs = train_ds[:5]
    print(xs[0])
    print(vae(xs)[1][0])

tensor([ 5.,  5., 10.,  5., 10., 10.,  5., 10., 10.,  5., 10., 10.,  5.,  5.,
        10., 10., 10.,  5., 10.,  5., 10.,  5., 10.,  5.,  5.,  5., 10., 10.,
        10., 10., 10., 10., 10., 10., 10., 10.,  5.,  5.,  5.,  5., 10., 10.,
        10., 10.,  5., 10.,  5.,  5., 10., 10.])
tensor([-0.6757,  0.6260,  0.8456,  0.8918,  2.0560, -3.7391,  0.3053,  0.8659,
         1.5347,  2.6437, -0.6409,  2.0119, -0.9946, -0.9982, -0.2103, -1.5236,
         0.8755, -0.6488,  0.0203,  0.6154,  0.8579, -0.9848,  0.9171, -1.1888,
         0.5390, -0.3237,  0.3431,  0.5804, -1.7417, -0.8796, -1.1886,  0.6928,
        -1.0472, -1.5047, -0.0160, -0.3569,  1.6718, -1.0586,  0.7914, -0.4584,
        -0.7353, -0.3448, -1.2833,  0.5357,  1.0551,  1.3067, -0.4196, -1.1429,
        -2.1582, -0.5758])


In [27]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [47]:
%tensorboard --logdir lightning_logs

Reusing TensorBoard on port 6007 (pid 34356), started 0:03:43 ago. (Use '!kill 34356' to kill it.)