# Joint multimodal variational autoencoder (JMVAE, using the VAE class)

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, Bernoulli, Categorical
from pixyz.losses import KullbackLeibler, StochasticReconstructionLoss
from pixyz.models import VAE

In [4]:
x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x,y)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x","y"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], 1)))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# inference model q(z|x)
class InferenceX(Normal):
    def __init__(self):
        super(InferenceX, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# inference model q(z|y)
class InferenceY(Normal):
    def __init__(self):
        super(InferenceY, self).__init__(cond_var=["y"], var=["z"], name="q")

        self.fc1 = nn.Linear(y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, y):
        h = F.relu(self.fc1(y))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z)    
class GeneratorX(Bernoulli):
    def __init__(self):
        super(GeneratorX, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
    
# generative model p(y|z)    
class GeneratorY(Categorical):
    def __init__(self):
        super(GeneratorY, self).__init__(cond_var=["z"], var=["y"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.softmax(self.fc3(h), dim=1)}

    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [5]:
p_x = GeneratorX().to(device)
p_y = GeneratorY().to(device)

q = Inference().to(device)
q_x = InferenceX().to(device)
q_y = InferenceY().to(device)

p = p_x * p_y
print(p)

Distribution:
  p(x,y|z) = p(x|z)p(y|z)
Network architecture:
  p(y|z) (Categorical): GeneratorY(
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=10, bias=True)
  )
  p(x|z) (Bernoulli): GeneratorX(
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


In [6]:
kl = KullbackLeibler(q, prior)
kl_x = KullbackLeibler(q, q_x)
kl_y = KullbackLeibler(q, q_y)

regularizer = kl + kl_x + kl_y
print(regularizer)

KL[q(z|x,y)||p_prior(z)] + KL[q(z|x,y)||q(z|x)] + KL[q(z|x,y)||q(z|y)]


In [7]:
model = VAE(q, p, other_distributions=[q_x, q_y],
            regularizer=regularizer, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  q(z|x,y), p(x,y|z), q(z|x), q(z|y) 
Loss function: 
  mean(-E_q(z|x,y)[log p(x,y|z)] + KL[q(z|x,y)||p_prior(z)] + KL[q(z|x,y)||q(z|x)] + KL[q(z|x,y)||q(z|y)]) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [8]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.train({"x": x, "y": y})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [9]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [10]:
def plot_reconstrunction_missing(x):
    with torch.no_grad():
        z = q_x.sample({"x": x}, return_all=False)
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_label(x, y):
    with torch.no_grad():
        x_all = [x.view(-1, 1, 28, 28)]
        for i in range(7):
            z = q_y.sample({"y": y}, return_all=False)
            recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
            x_all.append(recon_batch)
    
        comparison = torch.cat(x_all).cpu()
        return comparison

def plot_reconstrunction(x, y):
    with torch.no_grad():
        z = q.sample({"x": x, "y": y}, return_all=False)
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison

In [11]:
writer = SummaryWriter()

plot_number = 1

_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8], _y[:8])
    sample = plot_image_from_label(_x[:8], _y[:8])
    recon_missing = plot_reconstrunction_missing(_x[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      

    writer.add_image('Image_from_label', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)    
    writer.add_image('Image_reconstrunction_missing', recon_missing, epoch)
    
writer.close()

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f9d94191940>>
Traceback (most recent call last):
  File "/home/masa/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/masa/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/masa/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/masa/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/masa/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) 

Epoch: 1 Train loss: 191.1362



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 162.3838


100%|██████████| 469/469 [00:09<00:00, 47.17it/s]


Epoch: 2 Train loss: 150.1102


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 142.2189


100%|██████████| 469/469 [00:10<00:00, 46.48it/s]


Epoch: 3 Train loss: 136.7874


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 134.0676


100%|██████████| 469/469 [00:10<00:00, 46.66it/s]


Epoch: 4 Train loss: 131.5395


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 129.9764


100%|██████████| 469/469 [00:10<00:00, 46.55it/s]


Epoch: 5 Train loss: 128.0573


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 127.1851


100%|██████████| 469/469 [00:10<00:00, 44.65it/s]


Epoch: 6 Train loss: 126.1412


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 126.1046


100%|██████████| 469/469 [00:11<00:00, 42.25it/s]


Epoch: 7 Train loss: 124.7065


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 124.9047


100%|██████████| 469/469 [00:09<00:00, 48.69it/s]

Epoch: 8 Train loss: 123.6200



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 124.3587


100%|██████████| 469/469 [00:09<00:00, 47.98it/s]


Epoch: 9 Train loss: 122.7673


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 123.4335


100%|██████████| 469/469 [00:10<00:00, 46.43it/s]


Epoch: 10 Train loss: 121.9841
Test loss: 122.5366
