# MVAE
* Original paper: Multimodal Generative Models for Scalable Weakly-Supervised Learning (https://papers.nips.cc/paper/7801-multimodal-generative-models-for-scalable-weakly-supervised-learning.pdf)
* Original code: https://github.com/mhw32/multimodal-vae-public


### MVAE summary  
Multimodal variational autoencoder (MVAE) uses a product-of-experts inference network and a sub-sampled training paradigm to solve the multimodal inference problem.  
- Product-of-experts  
In the multimodal setting, we assume that the N modalities, $x_{1}, x_{2}, ..., x_{N}$, are conditionally independent given the common latent variable, z. That is, we assume a generative model of the form $p_{\theta}(x_{1}, x_{2}, ..., x_{N}, z) = p(z)p_{\theta}(x_{1}|z)p_{\theta}(x_{2}|z) \cdots p_{\theta}(x_{N}|z)$. The conditional independence assumptions in the generative model imply a relation between the joint and single-modality posteriors: the joint posterior is proportional to the product of the individual posteriors, divided by $N-1$ copies of the prior.  

- Sub-sampled training  
MVAE sub-samples which ELBO terms to optimize at every gradient step, so that training both captures the relationships between modalities and trains the individual inference networks.  

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# MNIST
# treat labels as a second modality
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.utils import print_latex

## Define probability distributions
### In the original paper
Modalities: $x_{1}, x_{2}, ..., x_{N}$  
Generative model:  

$p_{\theta}\left(x_{1}, x_{2}, \ldots, x_{N}, z\right)=p(z) p_{\theta}\left(x_{1} | z\right) p_{\theta}\left(x_{2} | z\right) \cdots p_{\theta}\left(x_{N} | z\right)$  

Inference:  

$p\left(z | x_{1}, \ldots, x_{N}\right) \propto \frac{\prod_{i=1}^{N} p\left(z | x_{i}\right)}{\prod_{i=1}^{N-1} p(z)} \approx \frac{\prod_{i=1}^{N}\left[\tilde{q}\left(z | x_{i}\right) p(z)\right]}{\prod_{i=1}^{N-1} p(z)}=p(z) \prod_{i=1}^{N} \tilde{q}\left(z | x_{i}\right)$  
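When the experts are diagonal Gaussians (as in the MNIST setting below), a product of Gaussians is again Gaussian, so the right-hand side has a closed form: the precisions add, and the mean is the precision-weighted average of the expert means. The following is a minimal plain-PyTorch sketch of that computation, not the pixyz `ProductOfNormal` implementation; the function name `product_of_gaussians` and the `eps` stabilizer are assumptions of this sketch.

```python
import torch

def product_of_gaussians(locs, scales, eps=1e-8):
    """Combine diagonal Gaussian experts into a single Gaussian.

    locs, scales: tensors of shape (num_experts, batch, z_dim).
    A standard-normal "prior expert" N(0, 1) is prepended, mirroring the
    p(z) * prod_i q(z|x_i) form of the equation above.
    """
    prior_loc = torch.zeros_like(locs[:1])
    prior_scale = torch.ones_like(scales[:1])
    locs = torch.cat([prior_loc, locs], dim=0)
    scales = torch.cat([prior_scale, scales], dim=0)

    # Precision (inverse variance) of each expert; eps guards against division by zero.
    precision = 1.0 / (scales ** 2 + eps)
    # A product of Gaussians is Gaussian: the precisions add, and the mean is
    # the precision-weighted average of the expert means.
    joint_var = 1.0 / precision.sum(dim=0)
    joint_loc = (locs * precision).sum(dim=0) * joint_var
    return joint_loc, joint_var.sqrt()


# Usage: two experts for a batch of 4 with a 64-dimensional latent.
locs = torch.randn(2, 4, 64)
scales = torch.rand(2, 4, 64) + 0.1
loc, scale = product_of_gaussians(locs, scales)
```

Leaving an expert out of `locs` and `scales` gives the posterior for the remaining modalities, which is how a product of experts accommodates a missing modality.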

### MNIST settings
Modalities:
- x for image modality
- y for label modality

Prior: $p(z) = \cal N(z; \mu=0, \sigma^2=1)$  
Generators:  
$p_{\theta}(x|z) = \cal B(x; \lambda = g_x(z))$ for image modality  
$p_{\theta}(y|z) = \cal Cat(y; \lambda = g_y(z))$ for label modality  
$p_{\theta}\left(x, y, z\right)=p(z) p_{\theta}(x| z) p_{\theta}(y | z)$

Inferences:  
$q_{\phi}(z|x) = \cal N(z; \mu=fx_\mu(x), \sigma^2=fx_{\sigma^2}(x))$ for image modality  
$q_{\phi}(z|y) = \cal N(z; \mu=fy_\mu(y), \sigma^2=fy_{\sigma^2}(y))$ for label modality  
$q_{\phi}(z|x,y) \propto p(z)q_{\phi}(z|x)q_{\phi}(z|y)$ for the joint posterior (product of experts)
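The label modality y enters the networks as a one-hot vector of length 10 (the training loop builds it with `torch.eye(10)[y]`). With a one-hot y, the categorical log-likelihood $\log \mathrm{Cat}(y; \lambda)$ simply picks out the log-probability of the true class. A small sketch with made-up labels and uniform class probabilities:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 0, 9])

# Indexing rows of the 10x10 identity matrix, as the training loop does.
y_eye = torch.eye(10)[labels]
# The built-in helper gives the same encoding (as int64, so cast to float).
y_onehot = F.one_hot(labels, num_classes=10).float()
assert torch.equal(y_eye, y_onehot)

# With one-hot y, log Cat(y; lambda) reduces to log lambda[true class].
probs = torch.full((3, 10), 0.1)
log_lik = (y_onehot * probs.log()).sum(dim=1)
```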


In [4]:
from pixyz.distributions import Normal, Bernoulli, Categorical, ProductOfNormal

x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x) for image modality
class InferenceX(Normal):
    def __init__(self):
        super(InferenceX, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}


# inference model q(z|y) for label modality
class InferenceY(Normal):
    def __init__(self):
        super(InferenceY, self).__init__(cond_var=["y"], var=["z"], name="q")

        self.fc1 = nn.Linear(y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, y):
        h = F.relu(self.fc1(y))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z) 
class GeneratorX(Bernoulli):
    def __init__(self):
        super(GeneratorX, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}


# generative model p(y|z)    
class GeneratorY(Categorical):
    def __init__(self):
        super(GeneratorY, self).__init__(cond_var=["z"], var=["y"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.softmax(self.fc3(h), dim=1)}

    
# prior model p(z)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

p_x = GeneratorX().to(device)
p_y = GeneratorY().to(device)
p = p_x * p_y

q_x = InferenceX().to(device)
q_y = InferenceY().to(device)

# equation (4) in the paper
# "we can use a product of experts (PoE), including a “prior expert”, as the approximating distribution for the joint-posterior"
# Pixyz docs: https://docs.pixyz.io/en/latest/distributions.html#pixyz.distributions.ProductOfNormal
q = ProductOfNormal([q_x, q_y], name="q").to(device)

In [5]:
print(q)
print_latex(q)

Distribution:
  q(z|x,y) \propto p(z)q(z|x)q(z|y)
Network architecture:
  ProductOfNormal(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (p): ModuleList(
      (0): InferenceX(
        name=q, distribution_name=Normal,
        var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
        (fc1): Linear(in_features=784, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=512, bias=True)
        (fc31): Linear(in_features=512, out_features=64, bias=True)
        (fc32): Linear(in_features=512, out_features=64, bias=True)
      )
      (1): InferenceY(
        name=q, distribution_name=Normal,
        var=['z'], cond_var=['y'], input_var=['y'], features_shape=torch.Size([])
        (fc1): Linear(in_features=10, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=512, bias=True)
        (fc31): Linear(in_features=512, out_features=


In [6]:
print(p)
print_latex(p)

Distribution:
  p(x,y|z) = p(x|z)p(y|z)
Network architecture:
  GeneratorY(
    name=p, distribution_name=Categorical,
    var=['y'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=10, bias=True)
  )
  GeneratorX(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )



## Define Loss function
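$\cal L = \mathrm{ELBO}\left(x_{1}, \ldots, x_{N}\right)+\sum_{i=1}^{N} \mathrm{ELBO}\left(x_{i}\right)+\sum_{j=1}^{k} \mathrm{ELBO}\left(X_{j}\right)$  
where $X_{j}$ is the $j$-th of $k$ randomly sampled subsets of the modalities. With only two modalities (image x and label y), every non-empty subset is either the full set or a single modality, so the loss below consists of the joint ELBO and the two single-modality ELBOs.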
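A minimal sketch of how the $k$ subsets $X_{j}$ might be drawn at each gradient step, using only the Python standard library. The function name `sample_subsets` and the sampling scheme (excluding the full set and the singletons, sampling with replacement) are assumptions of this sketch, not taken from the paper's code.

```python
import itertools
import random

def sample_subsets(num_modalities, k, rng=random):
    """Sample k modality subsets X_j for the third term of the MVAE loss.

    Assumptions of this sketch: the full set and the singletons are excluded
    (the first two terms already cover them) and sampling is with replacement.
    """
    candidates = [
        subset
        for size in range(2, num_modalities)
        for subset in itertools.combinations(range(num_modalities), size)
    ]
    if not candidates:
        return []
    return [rng.choice(candidates) for _ in range(k)]


# With three modalities the candidates are (0, 1), (0, 2) and (1, 2).
subsets = sample_subsets(3, k=2, rng=random.Random(0))
```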

In [7]:
from pixyz.losses import KullbackLeibler
from pixyz.losses import LogProb
from pixyz.losses import Expectation as E

In [8]:
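# Each term below is the *negative* ELBO (expected negative log-likelihood plus KL),
# so minimizing `loss` maximizes the three ELBOs.
ELBO = -E(q, LogProb(p)) + KullbackLeibler(q, prior)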
ELBO_x = -E(q_x, LogProb(p_x)) + KullbackLeibler(q_x, prior)
ELBO_y = -E(q_y, LogProb(p_y)) + KullbackLeibler(q_y, prior)

loss = ELBO.mean() + ELBO_x.mean() + ELBO_y.mean()
print_latex(loss) # Note: Terms in the printed loss may be reordered

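For intuition about what each term computes, here is a plain-PyTorch sketch of the image-only term: a reconstruction loss plus a KL divergence to the standard-normal prior, i.e. the negative ELBO. The helper names and the dummy decoder are hypothetical, and the sketch uses a single reparameterized sample with the closed-form Gaussian KL, which is a common choice but not necessarily how pixyz evaluates `E` and `KullbackLeibler` internally.

```python
import torch
import torch.nn.functional as F

def kl_to_standard_normal(loc, scale):
    # Closed-form KL[N(loc, scale^2) || N(0, 1)], summed over latent dimensions.
    var = scale ** 2
    return 0.5 * (loc ** 2 + var - 1.0 - torch.log(var)).sum(dim=-1)

def negative_elbo_x(x, loc, scale, decoder):
    """One-sample estimate of -ELBO(x) for a Bernoulli image decoder.

    `decoder` is a hypothetical callable mapping z to pixel probabilities,
    standing in for GeneratorX.
    """
    # Reparameterized sample z = loc + scale * eps, eps ~ N(0, I).
    z = loc + scale * torch.randn_like(scale)
    probs = decoder(z)
    # -log p(x|z) under independent Bernoullis, summed over pixels.
    nll = F.binary_cross_entropy(probs, x, reduction="none").sum(dim=-1)
    return nll + kl_to_standard_normal(loc, scale)


def dummy_decoder(z):
    # Predicts probability 0.5 for each of 4 "pixels", whatever z is.
    return torch.full((z.shape[0], 4), 0.5)


x = torch.rand(2, 4)
loss_x = negative_elbo_x(x, torch.zeros(2, 3), torch.ones(2, 3), dummy_decoder)
```

With probabilities of 0.5 every pixel costs $\log 2$ regardless of its value, and a posterior equal to the prior has zero KL, so each entry of `loss_x` is $4 \log 2$.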

## Define MVAE model using Model Class

In [9]:
from pixyz.models import Model
model = Model(loss=loss, distributions=[p_x, p_y, q_x, q_y],
              optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x|z), p(y|z), q(z|x), q(z|y) 
Loss function: 
  mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] - \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x,y)||p_{prior}(z) \right] - \mathbb{E}_{q(z|x,y)} \left[\log p(x,y|z) \right] \right) + mean \left(D_{KL} \left[q(z|y)||p_{prior}(z) \right] - \mathbb{E}_{q(z|y)} \left[\log p(y|z) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )



## Define train and test loops using the model

In [10]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.train({"x": x, "y": y})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss
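Since `loss` averages over the batch, `train_loss` and `test_loss` accumulate batch means, and multiplying by `batch_size / len(dataset)` approximates the per-example mean. The approximation counts the final partial batch as full: the 60,000 MNIST training images split into 468 batches of 128 plus one of 96, which is why the progress bar in the training output shows 469 iterations. A sketch of the exact size-weighted average, with hypothetical helper names:

```python
def exact_mean_loss(batch_means, batch_sizes):
    """Per-example mean, weighting each batch-mean loss by its batch size."""
    total = sum(loss * n for loss, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

def approx_mean_loss(batch_means, batch_size, dataset_size):
    # The scaling used in train() and test(): every batch counted as full-size.
    return sum(batch_means) * batch_size / dataset_size


# MNIST training split: 60,000 images in batches of 128.
n_full, last = divmod(60000, 128)  # 468 full batches, last batch of 96
sizes = [128] * n_full + [last]
```

With a constant per-batch loss of 2.0, the exact mean is 2.0 while the scaling in `train()` gives about 2.001, so the logged values are a close approximation.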

## Reconstruction and generation

In [12]:
def plot_reconstrunction_missing_label_modality(x):
    with torch.no_grad():
        # infer from x (image modality) only
        z = q_x.sample({"x": x}, return_all=False)
        # generate image from latent variable
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_label(x, y):
    with torch.no_grad():
        x_all = [x.view(-1, 1, 28, 28)]
        for i in range(7):
            # infer from y (label modality) only
            z = q_y.sample({"y": y}, return_all=False)
            
            # generate image from latent variable
            recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
            x_all.append(recon_batch)
    
        comparison = torch.cat(x_all).cpu()
        return comparison

def plot_reconstrunction(x, y):
    with torch.no_grad():
        # infer from x and y
        z = q.sample({"x": x, "y": y}, return_all=False)
        # generate image from latent variable
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison

In [None]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H%M%S')  # no ':' so the run directory name is valid on every OS

In [13]:
# for visualising in TensorBoard
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".mvae"  + exp_time)


plot_number = 1

# set-aside observation for watching generative model improvement 
_x, _y = next(iter(test_loader))
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8], _y[:8])
    sample = plot_image_from_label(_x[:8], _y[:8])
    recon_missing = plot_reconstrunction_missing_label_modality(_x[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      

    writer.add_images('Image_from_label', sample, epoch)
    writer.add_images('Image_reconstruction', recon, epoch)
    writer.add_images('Image_reconstruction_missing_label', recon_missing, epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time) 
writer.close()

100%|██████████| 469/469 [00:29<00:00, 15.98it/s]
Epoch: 1 Train loss: 377.4446
Test loss: 297.8822

100%|██████████| 469/469 [00:29<00:00, 15.87it/s]
Epoch: 2 Train loss: 275.6488
Test loss: 263.0328

100%|██████████| 469/469 [00:29<00:00, 15.85it/s]
Epoch: 3 Train loss: 255.8915
Test loss: 251.6184

100%|██████████| 469/469 [00:28<00:00, 16.47it/s]
Epoch: 4 Train loss: 246.7766
Test loss: 244.7460

100%|██████████| 469/469 [00:28<00:00, 16.52it/s]
Epoch: 5 Train loss: 240.5835
Test loss: 239.4675

100%|██████████| 469/469 [00:29<00:00, 16.03it/s]
Epoch: 6 Train loss: 236.3628
Test loss: 236.4568

100%|██████████| 469/469 [00:29<00:00, 16.15it/s]
Epoch: 7 Train loss: 233.0534
Test loss: 233.7300

100%|██████████| 469/469 [00:30<00:00, 15.43it/s]
Epoch: 8 Train loss: 230.8085
Test loss: 231.8841

100%|██████████| 469/469 [00:29<00:00, 15.84it/s]
Epoch: 9 Train loss: 229.0455
Test loss: 230.5546

100%|██████████| 469/469 [00:29<00:00, 16.07it/s]
Epoch: 10 Train loss: 227.6146
Test loss: 229.4565

100%|██████████| 469/469 [00:29<00:00, 15.85it/s]
Epoch: 11 Train loss: 226.2698
Test loss: 227.8539

100%|██████████| 469/469 [00:28<00:00, 16.43it/s]
Epoch: 12 Train loss: 224.9455
Test loss: 228.2061

100%|██████████| 469/469 [00:29<00:00, 15.69it/s]
Epoch: 13 Train loss: 224.0336
Test loss: 226.7854

100%|██████████| 469/469 [00:30<00:00, 15.48it/s]
Epoch: 14 Train loss: 223.1099
Test loss: 226.1085

100%|██████████| 469/469 [00:28<00:00, 16.32it/s]
Epoch: 15 Train loss: 222.4037
Test loss: 225.8082

100%|██████████| 469/469 [00:30<00:00, 15.55it/s]
Epoch: 16 Train loss: 221.7075
Test loss: 225.0872

100%|██████████| 469/469 [00:28<00:00, 16.24it/s]
Epoch: 17 Train loss: 221.2034
Test loss: 224.8409

100%|██████████| 469/469 [00:29<00:00, 15.80it/s]
Epoch: 18 Train loss: 220.7850
Test loss: 224.4209

100%|██████████| 469/469 [00:30<00:00, 15.56it/s]
Epoch: 19 Train loss: 220.2818
Test loss: 223.8716

100%|██████████| 469/469 [00:30<00:00, 15.23it/s]
Epoch: 20 Train loss: 219.7862
Test loss: 223.8801

100%|██████████| 469/469 [00:30<00:00, 15.21it/s]
Epoch: 21 Train loss: 219.3984
Test loss: 223.7502

100%|██████████| 469/469 [00:28<00:00, 16.39it/s]
Epoch: 22 Train loss: 219.1016
Test loss: 223.2690

100%|██████████| 469/469 [00:30<00:00, 15.49it/s]
Epoch: 23 Train loss: 218.7089
Test loss: 223.5728

100%|██████████| 469/469 [00:28<00:00, 16.34it/s]
Epoch: 24 Train loss: 218.5097
Test loss: 223.0242

100%|██████████| 469/469 [00:29<00:00, 16.03it/s]
Epoch: 25 Train loss: 218.2024
Test loss: 223.3938

100%|██████████| 469/469 [00:29<00:00, 16.17it/s]
Epoch: 26 Train loss: 217.8878
Test loss: 223.2734

100%|██████████| 469/469 [00:29<00:00, 16.17it/s]
Epoch: 27 Train loss: 217.6303
Test loss: 223.0314

100%|██████████| 469/469 [00:29<00:00, 15.76it/s]
Epoch: 28 Train loss: 217.3718
Test loss: 222.6581

100%|██████████| 469/469 [00:30<00:00, 15.63it/s]
Epoch: 29 Train loss: 217.0927
Test loss: 222.6385

100%|██████████| 469/469 [00:29<00:00, 16.15it/s]
Epoch: 30 Train loss: 216.8493
Test loss: 222.1671

100%|██████████| 469/469 [00:29<00:00, 15.70it/s]
Epoch: 31 Train loss: 216.7484
Test loss: 222.3781

100%|██████████| 469/469 [00:28<00:00, 16.47it/s]
Epoch: 32 Train loss: 216.4953
Test loss: 221.8737

  8%|▊         | 36/469 [00:02<00:31, 13.94it/s]