## 1. Create loss function
## 2. Create forward function
## 3. Return MSE + KLD (Kullback-Leibler divergence loss)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class VAELoss(nn.Module):
    """Loss module returning a summed squared error plus a KL-divergence term.

    Both terms are computed directly from the reconstruction and the target
    passed to ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Squared error summed over every element (no averaging).
        self.mse_loss = nn.MSELoss(reduction="sum")
        # KL divergence summed over all elements, then divided by the size of dim 0.
        self.kld_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(self, x_reconstructed, x):
        """Return the scalar ``MSE + KLD`` for a reconstruction and its target.

        Args:
            x_reconstructed: Tensor holding the reconstruction.
            x: Target tensor (presumably the same shape as ``x_reconstructed``).

        Returns:
            A 0-dim tensor with the combined loss.
        """
        # nn.KLDivLoss takes log-probabilities as input and probabilities as
        # target, so both tensors are normalised with (log-)softmax first.
        # NOTE(review): the softmax runs along dim 0, which is usually the
        # batch axis - confirm this is intended rather than the feature axis.
        predicted_log_probs = F.log_softmax(x_reconstructed, dim=0)
        target_probs = F.softmax(x, dim=0)

        reconstruction_term = self.mse_loss(x_reconstructed, x)
        divergence_term = self.kld_loss(predicted_log_probs, target_probs)
        return reconstruction_term + divergence_term


## 4. Use the reconstructed array and the original data to test the loss function

In [2]:
# Sample reconstruction used to exercise the loss: a literal NumPy array with
# 5 rows of 24 values each.
x_reconstructed = np.array([[-5.4997e-01,  9.6568e-01, -5.8461e-01,  1.6792e-01,  9.1045e-02,
         -6.3715e-01, -2.3070e-01,  1.5938e-01,  1.2707e+00, -2.7932e-01,
          6.0861e-01, -4.5742e-01, -3.0209e-01, -4.8071e-01, -3.0515e-01,
         -2.5727e-01, -1.4367e-01, -1.0434e-01, -8.6332e-02,  4.5512e-09,
         -6.2267e-02,  8.7327e-02, -9.6045e-02,  2.2917e-04],
        [-1.2184e+00,  7.9449e-01,  3.5487e-01,  1.9401e-01,  3.7659e-01,
          4.9783e-01, -2.2033e-01, -3.5635e-01, -4.0869e-01, -1.9539e-01,
         -3.9908e-01, -3.1146e-01,  2.6880e-01, -4.6869e-01, -2.4191e-01,
         -2.2887e-01, -2.4918e-01, -8.9690e-02, -7.0067e-01,  4.5512e-09,
          1.0222e-01,  2.4068e-01, -4.5194e-01,  2.2917e-04],
        [ 6.8352e-01, -4.1366e-01,  3.3967e-01,  1.9611e-01,  7.5186e-01,
          4.9786e-01, -2.2223e-02, -3.2298e-01, -4.8018e-03, -2.3531e-01,
         -3.2660e-01,  4.3518e-02, -1.9960e-01, -2.2237e-01, -3.4907e-01,
         -1.5650e-01, -1.8371e-01,  1.0963e+00, -3.1199e-01,  4.5512e-09,
          2.7832e-01, -3.3218e-01, -3.5350e-01,  2.2917e-04],
        [ 1.4130e-01,  6.3954e-01, -1.2712e-01, -1.3359e-01, -1.0345e-01,
         -3.4473e-01, -1.5453e-01,  1.2371e+00, -5.1910e-01, -1.3152e-01,
          3.3541e-01,  2.8345e-01, -2.3065e-01,  5.3266e-02, -2.9733e-01,
         -5.5412e-02,  3.5350e-01, -1.6218e-01, -4.1126e-01,  4.5512e-09,
         -5.9868e-02, -3.9497e-01,  2.0334e-01,  2.2917e-04],
        [ 6.4592e-01, -1.0789e-01, -4.0128e-01, -1.2244e+00, -2.1323e-01,
         -4.1193e-01, -2.3616e-01,  1.0052e+00, -1.7424e-01, -2.7838e-01,
         -3.7816e-01,  4.2149e-02, -6.6972e-02,  1.0228e+00, -3.8965e-01,
          1.0374e+00,  1.2995e-01,  1.7991e-01, -8.2993e-02,  4.5512e-09,
         -2.3409e-01, -2.1431e-01,  3.9755e-01,  2.2917e-04]])

# Sample target values compared against the reconstruction: a literal NumPy
# array with 5 rows of 24 values each.
data = np.array([[-2.4966e-01,  2.2019e+00,  1.2155e+00,  2.4365e-01, -4.8599e-01,
         -8.5299e-01, -4.0511e-01, -3.0383e-01, -4.8721e-01, -2.4032e-01,
          2.3985e-01,  2.2014e-01,  8.5676e-01, -8.2279e-01, -2.3188e-01,
         -1.6660e-01, -1.1348e-01, -1.4504e-01,  1.3448e+00,  0.0000e+00,
         -1.6408e-01,  3.1328e-01,  1.0924e+00,  0.0000e+00],
        [-1.6242e+00,  1.6726e-01,  1.1472e+00,  2.4365e-01,  1.8850e+00,
          4.4536e-01, -4.0511e-01, -3.0383e-01, -4.8721e-01, -2.4032e-01,
          7.5943e-01,  9.7630e-01,  3.9133e-01, -6.2471e-02, -2.2974e-01,
         -1.6411e-01, -1.2436e-01, -1.4504e-01, -1.3013e+00,  0.0000e+00,
         -1.7758e-01, -6.1999e-01,  9.1251e-01,  0.0000e+00],
        [ 4.1433e-01, -5.1314e-01,  9.6484e-01,  2.4365e-01, -8.3350e-02,
          6.3540e-01, -4.0511e-01, -3.0383e-01,  2.4122e+00, -2.4032e-01,
          5.9482e-01, -2.2929e+00, -1.9047e-01,  3.4309e+00, -2.3104e-01,
         -1.7033e-01, -1.8304e-01, -1.4504e-01, -6.2815e-01,  0.0000e+00,
         -1.6071e-01, -6.1378e-01, -4.1287e-01,  0.0000e+00],
        [-1.3773e+00,  8.0657e-02, -5.3026e-01,  2.4365e-01, -8.8385e-01,
         -4.7867e-01, -4.0511e-01, -3.0383e-01, -1.6506e-01, -2.4032e-01,
         -3.0710e-01,  5.9555e-01, -4.2319e-01,  3.2629e+00, -2.2266e-01,
         -1.6121e-01, -1.8514e-01, -1.4504e-01, -1.5798e+00,  0.0000e+00,
         -1.4834e-01, -2.0111e-01,  1.5353e+00,  0.0000e+00],
        [ 4.5911e-01,  5.3135e-01,  5.5121e-01,  2.4365e-01,  2.0276e+00,
         -4.1786e-01, -4.0511e-01, -3.0383e-01, -3.2613e-01, -2.4032e-01,
          1.7496e+00,  5.0435e-01,  2.6022e+00,  1.7045e+00, -2.3405e-01,
          5.7271e+00, -2.2456e-01, -1.4504e-01, -1.6262e+00,  0.0000e+00,
         -1.8995e-01,  1.0040e+00,  4.3751e-01,  0.0000e+00]])

In [3]:
# Build the loss module and evaluate it on the two sample arrays.
loss_function = VAELoss()
# torch.as_tensor accepts NumPy arrays and existing tensors alike, so this
# cell can be re-run after the names have already been rebound to tensors
# (torch.from_numpy would raise TypeError on the second run).
x_reconstructed = torch.as_tensor(x_reconstructed)
data = torch.as_tensor(data)
loss = loss_function(x_reconstructed, data)
# Display the loss as a plain Python float.
loss.item()

123.27132861809629