In [1]:
import torch
import torch.nn.functional as F
from torch import nn

In [2]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a diagonal Gaussian posterior.

    The encoder maps an input whose last dimension is ``in_dims`` to a mean
    ``mu`` and a strictly positive standard deviation ``sigma`` of size
    ``latent_dims``. A latent sample is drawn with the reparameterization
    trick and decoded to a tensor whose last dimension is ``out_dims``; the
    decoder ends in ``Tanh``, so reconstructions lie in ``[-1, 1]``.

    Args:
        in_dims: Size of the last dimension of the encoder input.
        out_dims: Size of the last dimension of the decoder output.
        latent_dims: Number of latent variables.
        eps: Constant added to the softplus output so ``sigma`` stays
            strictly positive (and ``log(sigma)`` stays finite).
    """

    def __init__(self, in_dims=128, out_dims=128, latent_dims=3, eps=1e-4):
        super().__init__()

        self.in_dims = in_dims
        self.out_dims = out_dims
        self.latent_dims = latent_dims
        self.eps = eps

        # Shared encoder trunk feeding the two posterior heads below.
        self.encoder = nn.Sequential(
            nn.Linear(self.in_dims, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

        self.mu_layer = nn.Linear(100, self.latent_dims)
        self.sigma_layer = nn.Linear(100, self.latent_dims)

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dims, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, self.out_dims),
            nn.Tanh(),
        )

    def encode_q_z(self, x):
        """Return the parameters ``(mu, sigma)`` of the posterior q(z | x).

        ``sigma`` is a standard deviation (not a log-variance): it is
        ``softplus(raw) + eps`` and therefore strictly positive.
        """
        h = self.encoder(x)
        mu = self.mu_layer(h)
        # Adding the Python scalar (instead of a freshly built CPU tensor)
        # keeps the result on whatever device/dtype the activations use.
        sigma = F.softplus(self.sigma_layer(h)) + self.eps
        return mu, sigma

    def reparametrize(self, mu, sigma):
        """Draw ``z = mu + sigma * noise`` with ``noise ~ N(0, I)``.

        ``sigma`` is the standard deviation returned by ``encode_q_z``.
        """
        # randn_like allocates the noise with sigma's shape, dtype and device,
        # so no manual .cuda() transfer is needed.
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def decode_p_x(self, x):
        """Decode a latent tensor ``x`` (last dimension ``latent_dims``)."""
        return self.decoder(x)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            A tuple ``(recon_x, mu, logvar)`` whose elements match the
            arguments expected by ``loss_function``. ``logvar`` is
            ``log(sigma^2)`` computed from the encoder's standard deviation.
        """
        mu, sigma = self.encode_q_z(x)
        z = self.reparametrize(mu, sigma)
        # loss_function works with the log-variance; sigma > 0 is guaranteed
        # by encode_q_z, so the logarithm is finite.
        logvar = 2.0 * torch.log(sigma)
        return self.decode_p_x(z), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the reconstruction error plus the KL divergence to N(0, I).

        The reconstruction term is the mean squared error between ``recon_x``
        and ``x`` (averaged over all elements); the KL term is summed over
        all elements of ``mu`` / ``logvar``:

            KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        """
        recon_loss = F.mse_loss(recon_x, x)
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kld

In [3]:
# Build the model with its default sizes and show the layer layout.
net = VAE()
print(net)

# Move the parameters to the GPU when one is available; otherwise keep them on the CPU.
net = net.cuda() if torch.cuda.is_available() else net

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=128, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): ReLU()
  )
  (mu_layer): Linear(in_features=100, out_features=3, bias=True)
  (sigma_layer): Linear(in_features=100, out_features=3, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=3, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=128, bias=True)
    (5): Tanh()
  )
)


In [4]:
# Sanity-check the loss on small hand-picked tensors, passing each argument by name.
net.loss_function(
    recon_x=torch.Tensor([0,0]),
    x=torch.Tensor([1,2]),
    mu=torch.Tensor([1, 1.5, 2.0]),
    logvar=torch.Tensor([0.1,0.2,0.3]),
)

tensor(6.1632)