In [None]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

# Training hyper-parameters.
batch_size = 128
epochs = 100

# Seed PyTorch's global RNG so weight initialisation and sampling are repeatable.
seed = 1
torch.manual_seed(seed)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from Tars.distributions import NormalModel, BernoulliModel
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

In [None]:
# Options shared by both loaders. Page-locked (pinned) host memory only speeds
# up host-to-GPU copies, so it is requested only when running on CUDA; asking
# for it on a CPU-only machine just produces a warning.
kwargs = {'num_workers': 1, 'pin_memory': device == "cuda"}

# Training split; downloaded into ../data on first use.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# Held-out split. It is shuffled too, so which images end up in the first
# batch depends on the RNG state.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [None]:
# Sizes shared by the networks defined below.
x_dim = 28 * 28  # flattened image (784 values)
y_dim = 10       # width of the one-hot label vector
z_dim = 64       # size of the latent variable z


# inference model q(z|x,y)
class Inference(NormalModel):
    """Two-hidden-layer MLP that maps (x, y) to the parameters of z.

    The image vector and the label vector are concatenated and passed through
    two 512-unit ReLU layers; two separate linear heads then produce the
    outputs.
    """

    def __init__(self):
        super(Inference, self).__init__(cond_var=["x","y"], var=["z"])

        # Shared trunk operating on the concatenated [x, y] input.
        self.fc1 = nn.Linear(x_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        # Two heads, each of width z_dim.
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x, y):
        """Return a pair of tensors, each with ``z_dim`` columns.

        The first is the raw output of ``fc31``; the second is the output of
        ``fc32`` passed through softplus, hence strictly positive.
        NOTE(review): presumably interpreted by ``NormalModel`` as
        (loc, scale) -- confirm against the Tars API.
        """
        joint_input = torch.cat([x, y], 1)
        hidden = F.relu(self.fc1(joint_input))
        hidden = F.relu(self.fc2(hidden))
        loc = self.fc31(hidden)
        scale = F.softplus(self.fc32(hidden))
        return loc, scale

    
# generative model p(x|z,y)
class Generator(BernoulliModel):
    """Two-hidden-layer MLP that maps (z, y) to a vector of ``x_dim`` values.

    The latent vector and the label vector are concatenated, passed through
    two 512-unit ReLU layers, and projected to ``x_dim`` outputs squashed into
    (0, 1) by a sigmoid.
    """

    def __init__(self):
        super(Generator, self).__init__(cond_var=["z","y"], var=["x"])

        self.fc1 = nn.Linear(z_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z, y):
        """Return a tensor with ``x_dim`` columns and values in (0, 1).

        NOTE(review): presumably used by ``BernoulliModel`` as per-pixel
        probabilities -- confirm against the Tars API.
        """
        h = F.relu(self.fc1(torch.cat([z, y], 1)))
        h = F.relu(self.fc2(h))
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid;
        # the computed values are the same.
        return torch.sigmoid(self.fc3(h))

    
# prior model p(z)
# Scalar parameters (location 0, scale 1) created directly on the chosen device.
loc = torch.tensor(0., device=device)
scale = torch.tensor(1., device=device)
prior = NormalModel(
    loc=loc,
    scale=scale,
    var=["z"],
    dim=z_dim,
)

In [None]:
# Build the generator first and the inference network second: construction
# order determines which random numbers each network's initial weights get.
p, q = Generator(), Inference()

# Move both networks' parameters to the selected device.
p.to(device)
q.to(device)

In [None]:
# Regularisation term: KL divergence between the inference model and the prior.
kl = KullbackLeibler(q, prior)

# Combine encoder, decoder and regulariser; optimised with Adam at lr = 1e-3.
model = VAE(
    q,
    p,
    regularizer=kl,
    optimizer=optim.Adam,
    optimizer_params={"lr":1e-3},
)

In [None]:
def train(epoch):
    """Run one training pass over ``train_loader`` and report the loss.

    Args:
        epoch: Epoch number; used only in the printed progress line.

    Returns:
        The sum of the per-batch losses returned by ``model.train``, rescaled
        by ``batch_size / len(dataset)``.
    """
    # Identity matrix used as a one-hot lookup table; built once per call
    # instead of once per batch.
    one_hot = torch.eye(y_dim)

    train_loss = 0
    for x_data, y_data in tqdm(train_loader):
        x_data = x_data.to(device)
        y_data = one_hot[y_data].to(device)
        # The first returned value (lower bound) is not used here.
        _, loss = model.train({"x": x_data.view(-1, x_dim), "y": y_data})
        train_loss += loss

    # NOTE(review): this rescaling yields a per-sample average only if
    # model.train returns a batch-mean loss -- confirm against Tars.
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [None]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and report the loss.

    Args:
        epoch: Accepted for symmetry with ``train``; not used in the body.

    Returns:
        The sum of the per-batch losses returned by ``model.test``, rescaled
        by ``batch_size / len(dataset)``.
    """
    # Identity matrix used as a one-hot lookup table; built once per call
    # instead of once per batch.
    one_hot = torch.eye(y_dim)

    test_loss = 0
    # Evaluation never back-propagates, so skip autograd bookkeeping.
    with torch.no_grad():
        for x_data, y_data in test_loader:
            x_data = x_data.to(device)
            y_data = one_hot[y_data].to(device)
            # The first returned value (lower bound) is not used here.
            _, loss = model.test({"x": x_data.view(-1, x_dim), "y": y_data})
            test_loss += loss

    # NOTE(review): this rescaling yields a per-sample average only if
    # model.test returns a batch-mean loss -- confirm against Tars.
    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [None]:
def plot_reconstrunction(x_data, y_data):
    """Stack input images on top of their reconstructions.

    The images are flattened to 784 columns, encoded with ``q`` together with
    ``y_data``, decoded with ``p.sample_mean`` and reshaped to (-1, 1, 28, 28).

    Returns:
        A CPU tensor holding ``x_data`` followed by the reconstructions along
        the first dimension.
    """
    with torch.no_grad():
        flat_x = x_data.view(-1, 784)
        latent = q.sample({"x": flat_x, "y": y_data})
        reconstruction = p.sample_mean(latent).view(-1, 1, 28, 28)

        stacked = torch.cat([x_data, reconstruction])
        return stacked.cpu()
    
def plot_image_from_latent(z_sample, y_sample):
    """Decode the given latent vectors and labels into images.

    Returns:
        The output of ``p.sample_mean`` reshaped to (-1, 1, 28, 28) and moved
        to the CPU.
    """
    with torch.no_grad():
        decoder_input = {"z": z_sample, "y": y_sample}
        decoded = p.sample_mean(decoder_input)
        images = decoded.view(-1, 1, 28, 28)
        return images.cpu()
    
def plot_reconstrunction_changing_y(x_data, y_data):
    """Decode each input's latent code once per possible label.

    The inputs are encoded with ``q`` using their given labels; the resulting
    latent code is then decoded with ``p.sample_mean`` under every one-hot
    label in turn.

    Args:
        x_data: Image batch; flattened to 784 columns before encoding.
        y_data: One-hot label batch; its second dimension gives the number of
            labels to sweep over.

    Returns:
        A CPU tensor holding ``x_data`` followed by one decoded batch per
        label, concatenated along the first dimension.
    """
    with torch.no_grad():
        z = q.sample({"x": x_data.view(-1, 784), "y": y_data})["z"]

        x_all = [x_data]
        # Take the label count from y_data instead of hard-coding 10.
        for label in range(y_data.size(1)):
            # zeros_like already matches y_data's device and dtype, so no
            # extra transfer to a global device is needed.
            y_change = torch.zeros_like(y_data)
            y_change[:, label] = 1
            recon_batch = p.sample_mean({"z": z, "y": y_change}).view(-1, 1, 28, 28)
            x_all.append(recon_batch)

        return torch.cat(x_all).cpu()

In [None]:
writer = SummaryWriter()

# Label used for the images generated from random latent vectors, and how
# many of those images to draw.
plot_number = 1
n_samples = 64

# Fixed latent vectors and labels, reused every epoch so the generated images
# are comparable across epochs.
z_sample = 0.5 * torch.randn(n_samples, z_dim).to(device)
y_sample = torch.eye(y_dim)[[plot_number] * n_samples].to(device)

# One fixed test batch for the reconstruction previews. The built-in next()
# is used because iterator objects have no .next() method in Python 3.
x_original, y_original = next(iter(test_loader))
x_original = x_original.to(device)
y_original = torch.eye(y_dim)[y_original].to(device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)

        recon = plot_reconstrunction(x_original[:8], y_original[:8])
        sample = plot_image_from_latent(z_sample, y_sample)
        recon_change_y = plot_reconstrunction_changing_y(x_original[:8], y_original[:8])

        writer.add_scalar('train_loss', train_loss.item(), epoch)
        writer.add_scalar('test_loss', test_loss.item(), epoch)

        # add_image expects a single CHW image, so each (N, 1, 28, 28) batch
        # is tiled into one image grid first.
        writer.add_image('Image_from_latent', make_grid(sample), epoch)
        writer.add_image('Image_reconstrunction', make_grid(recon), epoch)
        writer.add_image('Image_reconstrunction_change_y', make_grid(recon_change_y), epoch)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()