In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Normal, Bernoulli

In [2]:
class FullyConnected(nn.Sequential):
    """
    Fully connected multi-layer network with ELU activations.

    Args:
        sizes: Layer widths ``[in, hidden..., out]``. One ``nn.Linear`` is
            built per consecutive pair, each followed by an ``nn.ELU``.
        final_activation: If False, the ELU after the last linear layer is
            removed so the network ends in a plain linear layer.
    """
    def __init__(self, sizes, final_activation=True):
        layers = []
        for in_size, out_size in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.ELU())
        # Fewer than two sizes produces no layers at all; only pop when there
        # is a trailing ELU to remove, otherwise list.pop raises IndexError.
        if layers and not final_activation:
            layers.pop()
        super().__init__(*layers)

    def append(self, layer):
        """Add ``layer`` at the end of the sequence and return ``self``."""
        assert isinstance(layer, nn.Module)
        self.add_module(str(len(self)), layer)
        # Returning self mirrors nn.Sequential.append and allows chaining.
        return self

In [61]:
class BernNet(nn.Module):
    """
    Bernoulli network.

    Maps an input to a single logit that parameterizes a Bernoulli
    distribution.

    Args:
        sizes: Layer widths ``[in, hidden...]``; an output width of 1 is
            appended for the logit.
    """
    def __init__(self, sizes):
        super().__init__()
        # ELU is bounded below by -1, so a trailing ELU would keep every
        # logit above -1. The output layer must therefore stay linear.
        self.fc = FullyConnected(list(sizes) + [1], final_activation=False)

    def forward(self, x):
        """Return the Bernoulli logits computed from ``x``."""
        return self.fc(x)  # logits

    def sample(self, logits):
        """Draw a Bernoulli sample for each entry of ``logits``."""
        return Bernoulli(logits=logits).sample()

In [62]:
class GaussianNet(nn.Module):
    """
    Gaussian network.

    Maps an input to the mean and log-variance of a Normal distribution.

    Args:
        sizes: Layer widths ``[in, hidden...]``; an output width of 2 is
            appended (one unit for the mean, one for the log-variance).
    """
    def __init__(self, sizes):
        super().__init__()
        # The output layer stays linear: a trailing ELU would bound both the
        # mean and the log-variance below by -1.
        self.fc = FullyConnected(list(sizes) + [2], final_activation=False)

    def forward(self, x):
        """Return ``(mu, logvar)``, each with the last axis of size 2 removed."""
        out = self.fc(x)
        # Split on the last (feature) axis. Tuple-unpacking the tensor would
        # iterate over its first axis, which is the batch for batched input.
        mu = out[..., 0]
        logvar = out[..., 1]
        return mu, logvar

    def sample(self, mu, logvar):
        """Draw a sample from ``Normal(mu, sqrt(exp(logvar)))``."""
        # Normal takes the standard deviation, not the variance:
        # std = sqrt(exp(logvar)) = exp(0.5 * logvar).
        std = (0.5 * logvar).exp()
        return Normal(mu, std).sample()

In [None]:
class CEVAE(nn.Module):
    """
    CEVAE model (work in progress).

    Args:
        input_size: Width of the input features.
        latent_size: Width of the latent code.
        hidden_size: Widths of the hidden layers.

    NOTE(review): ``forward`` and ``sample`` use ``self.encoder`` and
    ``self.decoder``, but ``__init__`` only defines ``self.x1``. Both
    methods raise ``AttributeError`` until those submodules are added.
    """
    def __init__(self, input_size, latent_size=20, hidden_size=[256,256]):
        super().__init__()
        # Decoder
        # list(...) builds a fresh list, so the shared default is never
        # mutated and a tuple of widths is accepted as well.
        self.x1 = BernNet([input_size] + list(hidden_size) + [latent_size])

    def forward(self, x):
        """Encode ``x`` to ``z``, decode it, and return ``(x_hat, z)``."""
        # TODO: self.encoder / self.decoder are not defined in __init__ yet.
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z

    def sample(self, z):
        """Decode the latent code ``z`` and return the reconstruction."""
        x_hat = self.decoder(z)
        return x_hat