In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli, Categorical
from Tars.losses import KullbackLeibler
from Tars.models import VAE

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x,y)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x","y"], var=["z"])

        self.fc1 = nn.Linear(x_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], 1)))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# inference model q(z|x)
class InferenceX(Normal):
    def __init__(self):
        super(InferenceX, self).__init__(cond_var=["x"], var=["z"])

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# inference model q(z|y)
class InferenceY(Normal):
    def __init__(self):
        super(InferenceY, self).__init__(cond_var=["y"], var=["z"])

        self.fc1 = nn.Linear(y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, y):
        h = F.relu(self.fc1(y))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z)    
class GeneratorX(Bernoulli):
    def __init__(self):
        super(GeneratorX, self).__init__(cond_var=["z"], var=["x"])

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.sigmoid(self.fc3(h))}
    
# generative model p(y|z)    
class GeneratorY(Categorical):
    def __init__(self):
        super(GeneratorY, self).__init__(cond_var=["z"], var=["y"])

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.softmax(self.fc3(h), dim=1)}

    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

In [5]:
p_x = GeneratorX().to(device)
p_y = GeneratorY().to(device)

q = Inference().to(device)
q_x = InferenceX().to(device)
q_y = InferenceY().to(device)

In [6]:
p = p_x * p_y
print(p.prob_text, p.prob_factorized_text)

p(y,x|z) p(y|z)p(x|z)


In [7]:
kl = KullbackLeibler(q, prior)
kl_x = KullbackLeibler(q, q_x)
kl_y = KullbackLeibler(q, q_y)

regularizer = kl + kl_x + kl_y

model = VAE(q, p, other_distributions=[q_x, q_y],
            regularizer=regularizer, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [8]:
def train(epoch):
    train_loss = 0
    for batch_idx, (x_data, y_data) in enumerate(tqdm(train_loader)):
        x_data = x_data.to(device)
        y_data = torch.eye(10)[y_data].to(device)        
        lower_bound, loss = model.train({"x": x_data.view(-1, 784), "y": y_data})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [9]:
def test(epoch):
    test_loss = 0
    for i, (x_data, y_data) in enumerate(test_loader):
        x_data = x_data.to(device)
        y_data = torch.eye(10)[y_data].to(device)
        lower_bound, loss = model.test({"x": x_data.view(-1, 784), "y": y_data})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [10]:
def plot_reconstrunction_missing(x_data):
    with torch.no_grad():
        z = q_x.sample({"x": x_data.view(-1, 784)}, return_all=False)
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x_data, recon_batch]).cpu()
        return comparison
    
def plot_image_from_label(x_data, y_data):
    with torch.no_grad():
        x_all = [x_data]
        for i in range(7):
            z = q_y.sample({"y": y_data}, return_all=False)
            recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
            x_all.append(recon_batch)
    
        comparison = torch.cat(x_all).cpu()
        return comparison

def plot_reconstrunction(x_data, y_data):
    with torch.no_grad():
        z = q.sample({"x": x_data.view(-1, 784), "y": y_data}, return_all=False)
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x_data, recon_batch]).cpu()
        return comparison

In [11]:
writer = SummaryWriter()

plot_number = 1

x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = torch.eye(10)[y_original].to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(x_original[:8], y_original[:8])
    sample = plot_image_from_label(x_original[:8], y_original[:8])
    recon_missing = plot_reconstrunction_missing(x_original[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      

    writer.add_image('Image_from_label', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)    
    writer.add_image('Image_reconstrunction_missing', recon_missing, epoch)
    
writer.close()

100%|██████████| 469/469 [00:09<00:00, 47.00it/s]

Epoch: 1 Train loss: 187.1016



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 150.9254


100%|██████████| 469/469 [00:09<00:00, 48.70it/s]


Epoch: 2 Train loss: 135.1162


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 126.5083


100%|██████████| 469/469 [00:09<00:00, 48.88it/s]

Epoch: 3 Train loss: 120.2283



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.4259


100%|██████████| 469/469 [00:09<00:00, 46.99it/s]

Epoch: 4 Train loss: 112.8219



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 111.0504


100%|██████████| 469/469 [00:09<00:00, 48.93it/s]


Epoch: 5 Train loss: 108.7513


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.4728


100%|██████████| 469/469 [00:09<00:00, 48.39it/s]


Epoch: 6 Train loss: 106.1588


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.7631


100%|██████████| 469/469 [00:09<00:00, 49.01it/s]


Epoch: 7 Train loss: 104.4046


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.6237


100%|██████████| 469/469 [00:09<00:00, 48.83it/s]

Epoch: 8 Train loss: 102.8904



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.1322


100%|██████████| 469/469 [00:09<00:00, 48.81it/s]

Epoch: 9 Train loss: 101.7981



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.9587


100%|██████████| 469/469 [00:09<00:00, 48.02it/s]


Epoch: 10 Train loss: 100.9334


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3035


100%|██████████| 469/469 [00:09<00:00, 48.90it/s]

Epoch: 11 Train loss: 100.1859



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.4606


100%|██████████| 469/469 [00:09<00:00, 50.00it/s]

Epoch: 12 Train loss: 99.5060



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5966


100%|██████████| 469/469 [00:09<00:00, 48.68it/s]


Epoch: 13 Train loss: 98.9987


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 99.7362


100%|██████████| 469/469 [00:09<00:00, 49.79it/s]

Epoch: 14 Train loss: 98.5388



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 99.2703


100%|██████████| 469/469 [00:09<00:00, 48.26it/s]


Epoch: 15 Train loss: 98.0679


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 99.3250


100%|██████████| 469/469 [00:09<00:00, 48.79it/s]


Epoch: 16 Train loss: 97.7188


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.9253


100%|██████████| 469/469 [00:10<00:00, 46.43it/s]


Epoch: 17 Train loss: 97.3542


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.6657


100%|██████████| 469/469 [00:09<00:00, 48.40it/s]

Epoch: 18 Train loss: 97.0262



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.5665


100%|██████████| 469/469 [00:09<00:00, 49.25it/s]


Epoch: 19 Train loss: 96.7406


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.1340


100%|██████████| 469/469 [00:09<00:00, 48.61it/s]


Epoch: 20 Train loss: 96.5522


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.0411


100%|██████████| 469/469 [00:09<00:00, 49.48it/s]

Epoch: 21 Train loss: 96.2881



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.0973


100%|██████████| 469/469 [00:09<00:00, 49.66it/s]

Epoch: 22 Train loss: 96.0790



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.7282


100%|██████████| 469/469 [00:09<00:00, 49.90it/s]


Epoch: 23 Train loss: 95.8639


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.1494


100%|██████████| 469/469 [00:09<00:00, 48.45it/s]

Epoch: 24 Train loss: 95.6620



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.6332


100%|██████████| 469/469 [00:09<00:00, 50.55it/s]


Epoch: 25 Train loss: 95.4658


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.4641


100%|██████████| 469/469 [00:09<00:00, 47.68it/s]


Epoch: 26 Train loss: 95.3702


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.6078


100%|██████████| 469/469 [00:09<00:00, 49.45it/s]

Epoch: 27 Train loss: 95.1595



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.6879


100%|██████████| 469/469 [00:09<00:00, 48.76it/s]

Epoch: 28 Train loss: 95.0351



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.0330


100%|██████████| 469/469 [00:09<00:00, 48.29it/s]

Epoch: 29 Train loss: 94.8886



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.0218


100%|██████████| 469/469 [00:09<00:00, 50.73it/s]


Epoch: 30 Train loss: 94.7849


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.0652


100%|██████████| 469/469 [00:07<00:00, 59.38it/s]


Epoch: 31 Train loss: 94.6073


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.8270


100%|██████████| 469/469 [00:07<00:00, 58.96it/s]

Epoch: 32 Train loss: 94.5383



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.0149


100%|██████████| 469/469 [00:08<00:00, 55.62it/s]


Epoch: 33 Train loss: 94.3535


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.8534


100%|██████████| 469/469 [00:08<00:00, 55.86it/s]


Epoch: 34 Train loss: 94.2371


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 97.0321


100%|██████████| 469/469 [00:08<00:00, 55.28it/s]


Epoch: 35 Train loss: 94.2173


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.8754


100%|██████████| 469/469 [00:08<00:00, 57.55it/s]


Epoch: 36 Train loss: 94.0628


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.4432


100%|██████████| 469/469 [00:07<00:00, 58.69it/s]


Epoch: 37 Train loss: 94.0245


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.7406


100%|██████████| 469/469 [00:08<00:00, 54.36it/s]


Epoch: 38 Train loss: 93.9131


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.6141


100%|██████████| 469/469 [00:08<00:00, 53.66it/s]


Epoch: 39 Train loss: 93.8186


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.5718


100%|██████████| 469/469 [00:09<00:00, 50.01it/s]


Epoch: 40 Train loss: 93.7387


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.3879


100%|██████████| 469/469 [00:09<00:00, 51.67it/s]


Epoch: 41 Train loss: 93.6015


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.3639


100%|██████████| 469/469 [00:09<00:00, 50.43it/s]

Epoch: 42 Train loss: 93.5925



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.2859


100%|██████████| 469/469 [00:09<00:00, 52.00it/s]

Epoch: 43 Train loss: 93.4854



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.4018


100%|██████████| 469/469 [00:09<00:00, 51.30it/s]


Epoch: 44 Train loss: 93.4195


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.4056


100%|██████████| 469/469 [00:08<00:00, 53.23it/s]

Epoch: 45 Train loss: 93.3707



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.3856


100%|██████████| 469/469 [00:08<00:00, 52.27it/s]


Epoch: 46 Train loss: 93.3202


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.4188


100%|██████████| 469/469 [00:09<00:00, 50.31it/s]


Epoch: 47 Train loss: 93.1943


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.2185


100%|██████████| 469/469 [00:08<00:00, 52.31it/s]

Epoch: 48 Train loss: 93.1096



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.4103


100%|██████████| 469/469 [00:08<00:00, 53.61it/s]


Epoch: 49 Train loss: 93.0473


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.0691


100%|██████████| 469/469 [00:08<00:00, 52.66it/s]

Epoch: 50 Train loss: 93.0392





Test loss: 96.1682
