In [1]:
import torch
import torch.nn as nn
from flows import PlanarFlow
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch.optim as optim

In [None]:
class FlowVAE(nn.Module):
    """VAE whose Gaussian approximate posterior is refined by a stack of normalizing flows."""

    def __init__(self, input_size, dim_h, dim_z, n_layer, flows):
        """
        Build the encoder, the posterior heads, the decoder and the flow stack.

        :param input_size: image size indexable as (height, width); the encoder consumes the
            flattened height*width vector
        :param dim_h: dimension of hidden states (currently unused: hidden widths are fixed at 512/256)
        :param dim_z: dimension of latent variable
        :param n_layer: number of encoder and decoder layers (currently unused: depth is fixed)
        :param flows: flow modules applied in order to the base-posterior sample; each is called
            as ``z, logp = flow(z, logp)``
        """
        super().__init__()
        self.dim_z = dim_z
        self.prior = Normal(0., 1.)
        n_pixels = input_size[0] * input_size[1]
        self.encoder = nn.Sequential(nn.Linear(n_pixels, 512), nn.ReLU(True), nn.Linear(512, 256), nn.ReLU(True))
        self.mu = nn.Linear(256, dim_z)
        # Head for the log-variance of the base posterior (attribute name kept so checkpoints still load).
        self.var = nn.Linear(256, dim_z)

        # No output activation: the decoder returns unbounded values.
        self.decoder = nn.Sequential(nn.Linear(dim_z, 256), nn.ReLU(True), nn.Linear(256, 512), nn.ReLU(True), nn.Linear(512, n_pixels))
        self.flows = nn.ModuleList(flows)

    def forward(self, x):
        """
        Encode ``x``, draw one reparameterized latent sample, push it through the flows and decode it.

        :param x: data batch, either already flattened to (batch, height*width) or with trailing
            image dims (e.g. (batch, 1, height, width)), which are flattened here
        :return: (output, kl) where ``output`` is the decoder output of shape (batch, height*width)
            and ``kl`` is a per-example single-sample estimate of KL(q_K(z_K | x) || p(z_K)),
            shape (batch,)
        """
        # Accept unflattened image batches; this is a no-op for inputs that are already 2-D.
        h = self.encoder(x.flatten(start_dim=1))
        mu = self.mu(h)
        log_var = self.var(h)  # previously (wrongly) read from the mu head

        # Reparameterization: z0 = mu + sigma * eps keeps the sample differentiable w.r.t. mu, sigma.
        sigma = torch.exp(0.5 * log_var)
        z = mu + torch.randn_like(mu) * sigma

        # log q0(z0 | x), summed over latent dimensions -> (batch,)
        logq = Normal(mu, sigma).log_prob(z).sum(dim=-1)

        # NOTE(review): assumes each flow returns the log density of its *output*, i.e.
        # logp_new = logp - log|det dz_new/dz| -- confirm against flows.PlanarFlow.
        for flow in self.flows:
            z, logq = flow(z, logq)

        # Single-sample Monte Carlo estimate of the KL term: log q_K(z_K) - log p(z_K).
        kl = logq.reshape(-1) - self.prior.log_prob(z).sum(dim=-1)

        output = self.decoder(z)
        return output, kl

    def sample_img(self, n=1):
        """
        Decode ``n`` latents drawn from the standard-normal prior.

        :param n: number of samples (default 1, matching the previous behaviour)
        :return: decoder output of shape (n, height*width); no output activation is applied
        """
        # Sample on the model's device/dtype so models moved to a GPU still work.
        weight = self.decoder[0].weight
        z = torch.randn(n, self.dim_z, device=weight.device, dtype=weight.dtype)
        return self.decoder(z)

In [None]:
# Placeholder training configuration, read by train() as module-level globals; populate before training.
#   train_loader -- iterable of training batches
#   num_epoch    -- number of passes over train_loader
#   dataset      -- dataset name (not referenced by the visible cells)
train_loader, num_epoch, dataset = None, 0, ''
def train(model, data_name, loader=None, epochs=None, lr=0.001):
    """
    Train ``model`` by minimising the negative ELBO (reconstruction loss + KL term).

    :param model: module whose ``forward(x)`` returns ``(reconstruction_logits, kl)`` with
        ``kl`` either a scalar or one value per example
    :param data_name: 'mnist' (binary cross-entropy on logits) or 'cifar10' (squared error)
    :param loader: iterable of batches (tensors or ``(inputs, targets)`` pairs); defaults to the
        module-level ``train_loader``
    :param epochs: number of passes over ``loader``; defaults to the module-level ``num_epoch``
    :param lr: Adam learning rate
    :return: list with the mean training loss of each epoch
    :raises ValueError: if ``data_name`` is unknown or no training data is available
    """
    # Summed over pixels so that, after dividing by the batch size, the reconstruction term is a
    # per-example quantity on the same scale as the per-example KL.
    # BCEWithLogitsLoss is used because the decoder output is not squashed into [0, 1].
    losses = {
        'mnist': nn.BCEWithLogitsLoss(reduction='sum'),
        'cifar10': nn.MSELoss(reduction='sum'),
    }
    if data_name not in losses:
        raise ValueError("unknown data_name {!r}; expected one of {}".format(data_name, sorted(losses)))
    recon_loss = losses[data_name]

    if loader is None:
        loader = train_loader
    if loader is None:
        raise ValueError("no training data: pass loader= or set train_loader")
    if epochs is None:
        epochs = num_epoch

    model.train()
    device = next(model.parameters()).device
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        epoch_loss = []
        # A fresh progress bar each epoch: a single tqdm iterator is exhausted after one pass.
        progress = tqdm(loader, desc="Epoch {}".format(epoch))
        for batch in progress:
            # Labelled datasets yield (inputs, targets); only the inputs are reconstructed.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device).flatten(start_dim=1)

            optimizer.zero_grad()
            recon, kl = model(x)
            # kl may be per-example; reduce it so backward() receives a scalar.
            loss = recon_loss(recon, x) / x.size(0) + kl.mean()
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            progress.set_description("Loss %f" % loss.item())
        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
        history.append(mean_loss)
        print('Epoch: {} Loss: {:.4f}'.format(epoch, mean_loss))
    return history

def test():
    """Placeholder for an evaluation step; it does nothing yet and returns None.

    TODO(review): presumably meant to evaluate a trained model on held-out data -- implement or remove.
    """