# MVAE
* Original paper: Multimodal Generative Models for Scalable Weakly-Supervised Learning (https://papers.nips.cc/paper/7801-multimodal-generative-models-for-scalable-weakly-supervised-learning.pdf)
* Original code: https://github.com/mhw32/multimodal-vae-public


### MVAE summary  
The multimodal variational autoencoder (MVAE) uses a product-of-experts inference network and a sub-sampled training paradigm to solve the multimodal inference problem.  
- Product-of-experts  
In the multimodal setting we assume that the N modalities, $x_{1}, x_{2}, ..., x_{N}$, are conditionally independent given the common latent variable z. That is, we assume a generative model of the form $p_{\theta}(x_{1}, x_{2}, ..., x_{N}, z) = p(z)p_{\theta}(x_{1}|z)p_{\theta}(x_{2}|z) \cdots p_{\theta}(x_{N}|z)$. The conditional independence assumptions in the generative model imply a relation between the joint and single-modality posteriors: the joint posterior is a product of the individual posteriors, with an additional quotient by the prior.  

- Sub-sampled training  
At every gradient step, MVAE sub-samples which ELBO terms to optimize, so that it captures the relationships between modalities while also training the individual inference networks.  
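
To make the product-of-experts step concrete: when every expert is a diagonal Gaussian, the product is again Gaussian, with precision equal to the sum of the experts' precisions and a precision-weighted mean. The following is a minimal pure-Python sketch for a single latent dimension, assuming a standard normal prior expert. The function name `product_of_gaussian_experts` is illustrative and is not part of Pixyz or the original code.

```python
def product_of_gaussian_experts(means, variances, prior_mean=0.0, prior_var=1.0):
    """Combine 1-D Gaussian experts with a Gaussian prior expert.

    Returns the mean and variance of the (normalized) product density.
    """
    if len(means) != len(variances):
        raise ValueError("means and variances must have the same length")

    # Precision is the inverse variance; the prior counts as one more expert.
    precisions = [1.0 / prior_var] + [1.0 / v for v in variances]
    all_means = [prior_mean] + list(means)

    total_precision = sum(precisions)
    # Mean of the product: precision-weighted average of the expert means.
    joint_mean = sum(m * t for m, t in zip(all_means, precisions)) / total_precision
    joint_var = 1.0 / total_precision
    return joint_mean, joint_var


# No observed modality: only the prior expert remains.
prior_only = product_of_gaussian_experts([], [])

# One expert N(2, 1) combined with the prior N(0, 1).
one_expert = product_of_gaussian_experts([2.0], [1.0])
```

Dropping a modality simply removes its expert from the lists, which is why the same inference networks can be reused when an input is missing.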

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# MNIST
# treat labels as a second modality
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.utils import print_latex

## Define probability distributions
### In the original paper
Modalities: $x_{1}, x_{2}, ..., x_{N}$  
Generative model:  

$p_{\theta}\left(x_{1}, x_{2}, \ldots, x_{N}, z\right)=p(z) p_{\theta}\left(x_{1} | z\right) p_{\theta}\left(x_{2} | z\right) \cdots p_{\theta}\left(x_{N} | z\right)$  

Inference:  

$p\left(z | x_{1}, \ldots, x_{N}\right) \propto \frac{\prod_{i=1}^{N} p\left(z | x_{i}\right)}{\prod_{i=1}^{N-1} p(z)} \approx \frac{\prod_{i=1}^{N}\left[\tilde{q}\left(z | x_{i}\right) p(z)\right]}{\prod_{i=1}^{N-1} p(z)}=p(z) \prod_{i=1}^{N} \tilde{q}\left(z | x_{i}\right)$  
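
The proportionality above follows from Bayes' rule, applied once to the joint posterior and once to each single-modality posterior, together with the conditional independence assumption. Written out as a sketch, the intermediate steps are:

```latex
\begin{aligned}
p(z | x_{1}, \ldots, x_{N})
&= \frac{p(x_{1}, \ldots, x_{N} | z)\, p(z)}{p(x_{1}, \ldots, x_{N})}
 = \frac{p(z)}{p(x_{1}, \ldots, x_{N})} \prod_{i=1}^{N} p(x_{i} | z) \\
&= \frac{p(z)}{p(x_{1}, \ldots, x_{N})} \prod_{i=1}^{N} \frac{p(z | x_{i})\, p(x_{i})}{p(z)} \\
&= \frac{\prod_{i=1}^{N} p(z | x_{i})}{\prod_{i=1}^{N-1} p(z)} \cdot \frac{\prod_{i=1}^{N} p(x_{i})}{p(x_{1}, \ldots, x_{N})}
 \propto \frac{\prod_{i=1}^{N} p(z | x_{i})}{\prod_{i=1}^{N-1} p(z)}
\end{aligned}
```

The last factor on the third line does not depend on z, so it is absorbed into the proportionality constant.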

### MNIST settings
Modalities:
- x for image modality
- y for label modality

Prior: $p(z) = \mathcal{N}(z; \mu=0, \sigma^2=1)$  
Generators:  
$p_{\theta}(x|z) = \mathcal{B}(x; \lambda = g_x(z))$ for image modality  
$p_{\theta}(y|z) = \mathrm{Cat}(y; \lambda = g_y(z))$ for label modality  
$p_{\theta}\left(x, y, z\right)=p(z) p_{\theta}(x| z) p_{\theta}(y | z)$

Inferences:  
$q_{\phi}(z|x) = \mathcal{N}(z; \mu=fx_\mu(x), \sigma^2=fx_{\sigma^2}(x))$ for image modality  
$q_{\phi}(z|y) = \mathcal{N}(z; \mu=fy_\mu(y), \sigma^2=fy_{\sigma^2}(y))$ for label modality  
$q_{\phi}(z|x, y) \propto p(z)q_{\phi}(z|x)q_{\phi}(z|y)$ for the joint posterior (product of experts)


In [4]:
from pixyz.distributions import Normal, Bernoulli, Categorical, ProductOfNormal

x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x) for image modality
class InferenceX(Normal):
    def __init__(self):
        super(InferenceX, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}


# inference model q(z|y) for label modality
class InferenceY(Normal):
    def __init__(self):
        super(InferenceY, self).__init__(cond_var=["y"], var=["z"], name="q")

        self.fc1 = nn.Linear(y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, y):
        h = F.relu(self.fc1(y))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z) 
class GeneratorX(Bernoulli):
    def __init__(self):
        super(GeneratorX, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}


# generative model p(y|z)    
class GeneratorY(Categorical):
    def __init__(self):
        super(GeneratorY, self).__init__(cond_var=["z"], var=["y"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, y_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.softmax(self.fc3(h), dim=1)}

    
# prior model p(z)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

p_x = GeneratorX().to(device)
p_y = GeneratorY().to(device)
p = p_x * p_y

q_x = InferenceX().to(device)
q_y = InferenceY().to(device)

# equation (4) in the paper
# "we can use a product of experts (PoE), including a “prior expert”, as the approximating distribution for the joint-posterior"
# Pixyz docs: https://docs.pixyz.io/en/latest/distributions.html#pixyz.distributions.ProductOfNormal
q = ProductOfNormal([q_x, q_y], name="q").to(device)

In [5]:
print(q)
print_latex(q)

Distribution:
  q(z|x,y) \propto p(z)q(z|x)q(z|y)
Network architecture:
  ProductOfNormal(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (p): ModuleList(
      (0): InferenceX(
        name=q, distribution_name=Normal,
        var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
        (fc1): Linear(in_features=784, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=512, bias=True)
        (fc31): Linear(in_features=512, out_features=64, bias=True)
        (fc32): Linear(in_features=512, out_features=64, bias=True)
      )
      (1): InferenceY(
        name=q, distribution_name=Normal,
        var=['z'], cond_var=['y'], input_var=['y'], features_shape=torch.Size([])
        (fc1): Linear(in_features=10, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=512, bias=True)
        (fc31): Linear(in_features=512, out_features=

<IPython.core.display.Math object>

In [6]:
print(p)
print_latex(p)

Distribution:
  p(x,y|z) = p(x|z)p(y|z)
Network architecture:
  p(y|z):
  GeneratorY(
    name=p, distribution_name=Categorical,
    var=['y'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=10, bias=True)
  )
  p(x|z):
  GeneratorX(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

## Define Loss function
$\mathcal{L} = \mathrm{ELBO}\left(x_{1}, \ldots, x_{N}\right)+\sum_{i=1}^{N} \mathrm{ELBO}\left(x_{i}\right)+\sum_{j=1}^{k} \mathrm{ELBO}\left(X_{j}\right)$
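
As a rough illustration of which ELBO terms this objective contains, the sketch below enumerates the modality subsets for one gradient step: the full set, every single modality, and k randomly drawn subsets $X_j$. The helper `sample_elbo_subsets` and the choice to draw the random subsets from those containing between 2 and N-1 modalities are assumptions made for illustration; they are not taken from the original code.

```python
import itertools
import random


def sample_elbo_subsets(n_modalities, k, rng=None):
    """Return the modality index subsets whose ELBO terms enter one step."""
    if rng is None:
        rng = random.Random()
    indices = tuple(range(n_modalities))

    subsets = [indices]                      # joint ELBO over all modalities
    subsets += [(i,) for i in indices]       # one ELBO per single modality

    # Candidate "extra" subsets: more than one modality, but not all of them.
    candidates = [
        combo
        for size in range(2, n_modalities)
        for combo in itertools.combinations(indices, size)
    ]
    # Draw at most k of them without replacement.
    subsets += rng.sample(candidates, min(k, len(candidates)))
    return subsets


# Two modalities (image x and label y): no intermediate subsets exist.
mnist_terms = sample_elbo_subsets(2, k=3)

# Four modalities with k = 2 extra random subsets.
four_terms = sample_elbo_subsets(4, k=2, rng=random.Random(0))
```

With the two MNIST modalities the list reduces to the joint term plus the two single-modality terms, which matches the three-term loss built in this notebook.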

In [7]:
from pixyz.losses import KullbackLeibler
from pixyz.losses import LogProb
from pixyz.losses import Expectation as E

In [8]:
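# Each term is a negative ELBO (negative expected log-likelihood plus KL),
# so minimizing `loss` maximizes the sum of ELBOs defined above.
ELBO = -E(q, LogProb(p)) + KullbackLeibler(q, prior)
ELBO_x = -E(q_x, LogProb(p_x)) + KullbackLeibler(q_x, prior)
ELBO_y = -E(q_y, LogProb(p_y)) + KullbackLeibler(q_y, prior)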

loss = ELBO.mean() + ELBO_x.mean() + ELBO_y.mean()
print_latex(loss) # Note: Terms in the printed loss may be reordered

<IPython.core.display.Math object>
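
For intuition about the `KullbackLeibler(q_x, prior)` terms: when the posterior is a diagonal Gaussian and the prior is a standard normal, the KL divergence has the closed form $\frac{1}{2}\sum_j(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2)$. The pure-Python sketch below evaluates this expression. The helper name `kl_to_standard_normal` is hypothetical, and whether Pixyz evaluates the term in exactly this way is not shown here.

```python
import math


def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(
        s * s + m * m - 1.0 - math.log(s * s)
        for m, s in zip(mu, sigma)
    )


# A posterior identical to the prior costs nothing.
kl_same = kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])

# Shifting the mean of one dimension by 1 adds 0.5 nats.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [1.0, 1.0])
```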

## Define MVAE model using Model Class

In [9]:
from pixyz.models import Model
model = Model(loss=loss, distributions=[p_x, p_y, q_x, q_y],
              optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x|z), p(y|z), q(z|x), q(z|y) 
Loss function: 
  mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] - \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x,y)||p_{prior}(z) \right] - \mathbb{E}_{q(z|x,y)} \left[\log p(x,y|z) \right] \right) + mean \left(D_{KL} \left[q(z|y)||p_{prior}(z) \right] - \mathbb{E}_{q(z|y)} \left[\log p(y|z) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

## Define Train and Test loop using model

In [10]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.train({"x": x, "y": y})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss
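
One detail of the loops above: assuming `model.train` and `model.test` return the batch-mean loss (the loss is built from `.mean()` terms), summing over batches and scaling by `batch_size / len(dataset)` treats every batch as full-sized. The last batch is usually smaller, so the reported value is a close approximation of the per-example mean rather than the exact value. The arithmetic for the MNIST split sizes (60,000 training and 10,000 test images) can be sketched as follows; `scale_bias` is an illustrative helper.

```python
import math


def scale_bias(dataset_size, batch_size):
    """Factor by which sum(batch_means) * batch_size / dataset_size
    overstates the loss when every batch has the same mean loss."""
    n_batches = math.ceil(dataset_size / batch_size)
    return n_batches * batch_size / dataset_size


train_batches = math.ceil(60000 / 128)   # number of training batches
train_bias = scale_bias(60000, 128)      # 469 * 128 / 60000
test_bias = scale_bias(10000, 128)       # 79 * 128 / 10000
```

The batch count of 469 agrees with the progress bars in the training log. Under these assumptions the scaling overstates the training loss by well under 0.1% and the test loss by roughly 1%.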

## Reconstruction and generation

In [12]:
def plot_reconstrunction_missing_label_modality(x):
    with torch.no_grad():
        # infer from x (image modality) only
        z = q_x.sample({"x": x}, return_all=False)
        # generate image from latent variable
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_label(x, y):
    with torch.no_grad():
        x_all = [x.view(-1, 1, 28, 28)]
        for i in range(7):
            # infer from y (label modality) only
            z = q_y.sample({"y": y}, return_all=False)
            
            # generate image from latent variable
            recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
            x_all.append(recon_batch)
    
        comparison = torch.cat(x_all).cpu()
        return comparison

def plot_reconstrunction(x, y):
    with torch.no_grad():
        # infer from x and y
        z = q.sample({"x": x, "y": y}, return_all=False)
        # generate image from latent variable
        recon_batch = p_x.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison

In [13]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [14]:
# for visualising in TensorBoard
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".mvae"  + exp_time)


plot_number = 1

# set-aside observation for watching generative model improvement 
_x, _y = next(iter(test_loader))
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8], _y[:8])
    sample = plot_image_from_label(_x[:8], _y[:8])
    recon_missing = plot_reconstrunction_missing_label_modality(_x[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      

    writer.add_images('Image_from_label', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)    
    writer.add_images('Image_reconstrunction_missing_label', recon_missing, epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time) 
writer.close()

100%|██████████| 469/469 [00:12<00:00, 36.22it/s]

Epoch: 1 Train loss: 376.3922



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 294.6321


100%|██████████| 469/469 [00:13<00:00, 35.25it/s]

Epoch: 2 Train loss: 274.7651



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 263.8580


100%|██████████| 469/469 [00:12<00:00, 37.28it/s]


Epoch: 3 Train loss: 256.1177


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 251.5709


100%|██████████| 469/469 [00:13<00:00, 35.42it/s]


Epoch: 4 Train loss: 247.8665


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 246.4499


100%|██████████| 469/469 [00:13<00:00, 34.75it/s]


Epoch: 5 Train loss: 242.3597


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 240.8372


100%|██████████| 469/469 [00:13<00:00, 35.33it/s]

Epoch: 6 Train loss: 237.6370



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 237.6951


100%|██████████| 469/469 [00:13<00:00, 34.68it/s]

Epoch: 7 Train loss: 234.2372



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 234.4837


100%|██████████| 469/469 [00:12<00:00, 36.61it/s]


Epoch: 8 Train loss: 231.9510


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 232.8126


100%|██████████| 469/469 [00:13<00:00, 34.84it/s]


Epoch: 9 Train loss: 230.0487


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 231.4923


100%|██████████| 469/469 [00:12<00:00, 36.50it/s]

Epoch: 10 Train loss: 228.5014



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 230.8287


100%|██████████| 469/469 [00:12<00:00, 36.53it/s]

Epoch: 11 Train loss: 227.2024



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 229.4085


100%|██████████| 469/469 [00:13<00:00, 36.01it/s]


Epoch: 12 Train loss: 225.8663


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 228.5348


100%|██████████| 469/469 [00:12<00:00, 36.46it/s]

Epoch: 13 Train loss: 224.5877



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 227.2602


100%|██████████| 469/469 [00:13<00:00, 34.72it/s]


Epoch: 14 Train loss: 223.7152


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 226.6718


100%|██████████| 469/469 [00:12<00:00, 37.69it/s]

Epoch: 15 Train loss: 222.9457



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 226.4919


100%|██████████| 469/469 [00:12<00:00, 37.60it/s]

Epoch: 16 Train loss: 222.2629



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 225.7478


100%|██████████| 469/469 [00:12<00:00, 36.57it/s]


Epoch: 17 Train loss: 221.6597


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 225.1287


100%|██████████| 469/469 [00:14<00:00, 33.03it/s]


Epoch: 18 Train loss: 221.1593


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 224.9242


100%|██████████| 469/469 [00:12<00:00, 36.97it/s]


Epoch: 19 Train loss: 220.6716


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 224.6173


100%|██████████| 469/469 [00:13<00:00, 34.55it/s]


Epoch: 20 Train loss: 220.2057


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 224.1306


100%|██████████| 469/469 [00:13<00:00, 35.20it/s]

Epoch: 21 Train loss: 219.7928



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 224.3231


100%|██████████| 469/469 [00:13<00:00, 35.82it/s]

Epoch: 22 Train loss: 219.4239



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.6848


100%|██████████| 469/469 [00:17<00:00, 27.11it/s]


Epoch: 23 Train loss: 219.0340


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.7242


100%|██████████| 469/469 [00:20<00:00, 22.99it/s]

Epoch: 24 Train loss: 218.7304



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.6446


100%|██████████| 469/469 [00:20<00:00, 23.34it/s]


Epoch: 25 Train loss: 218.4557


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 224.1697


100%|██████████| 469/469 [00:19<00:00, 24.66it/s]


Epoch: 26 Train loss: 218.1474


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.4451


100%|██████████| 469/469 [00:18<00:00, 25.19it/s]

Epoch: 27 Train loss: 217.8321



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.2045


100%|██████████| 469/469 [00:19<00:00, 24.57it/s]

Epoch: 28 Train loss: 217.6219



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 223.0786


100%|██████████| 469/469 [00:19<00:00, 24.34it/s]


Epoch: 29 Train loss: 217.3562


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.6729


100%|██████████| 469/469 [00:19<00:00, 24.47it/s]


Epoch: 30 Train loss: 217.2269


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.4899


100%|██████████| 469/469 [00:19<00:00, 24.52it/s]

Epoch: 31 Train loss: 216.8729



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.6323


100%|██████████| 469/469 [00:19<00:00, 24.54it/s]


Epoch: 32 Train loss: 216.6449


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.3912


100%|██████████| 469/469 [00:19<00:00, 24.25it/s]


Epoch: 33 Train loss: 216.5599


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.3707


100%|██████████| 469/469 [00:14<00:00, 32.39it/s]


Epoch: 34 Train loss: 216.3555


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.3637


100%|██████████| 469/469 [00:12<00:00, 38.20it/s]

Epoch: 35 Train loss: 216.1228



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.2030


100%|██████████| 469/469 [00:12<00:00, 36.36it/s]


Epoch: 36 Train loss: 215.9985


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.7048


100%|██████████| 469/469 [00:14<00:00, 32.14it/s]

Epoch: 37 Train loss: 215.8160



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.2223


100%|██████████| 469/469 [00:14<00:00, 33.05it/s]


Epoch: 38 Train loss: 215.6465


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.3870


100%|██████████| 469/469 [00:13<00:00, 34.88it/s]


Epoch: 39 Train loss: 215.5218


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.7112


100%|██████████| 469/469 [00:14<00:00, 32.66it/s]


Epoch: 40 Train loss: 215.3924


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.2654


100%|██████████| 469/469 [00:12<00:00, 37.42it/s]

Epoch: 41 Train loss: 215.1894



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 222.2196


100%|██████████| 469/469 [00:14<00:00, 32.75it/s]

Epoch: 42 Train loss: 215.0450



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.6285


100%|██████████| 469/469 [00:14<00:00, 32.61it/s]


Epoch: 43 Train loss: 214.9800


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.9753


100%|██████████| 469/469 [00:14<00:00, 32.56it/s]


Epoch: 44 Train loss: 214.8128


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.6960


100%|██████████| 469/469 [00:14<00:00, 32.22it/s]


Epoch: 45 Train loss: 214.7407


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.7371


100%|██████████| 469/469 [00:12<00:00, 37.62it/s]

Epoch: 46 Train loss: 214.5490



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.7999


100%|██████████| 469/469 [00:14<00:00, 31.67it/s]

Epoch: 47 Train loss: 214.4453



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.8517


100%|██████████| 469/469 [00:14<00:00, 32.09it/s]


Epoch: 48 Train loss: 214.4403


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.7722


100%|██████████| 469/469 [00:12<00:00, 37.13it/s]

Epoch: 49 Train loss: 214.2940



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.4170


100%|██████████| 469/469 [00:12<00:00, 36.59it/s]


Epoch: 50 Train loss: 214.0828


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.9401


100%|██████████| 469/469 [00:12<00:00, 37.44it/s]

Epoch: 51 Train loss: 214.0985



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1478


100%|██████████| 469/469 [00:13<00:00, 35.27it/s]


Epoch: 52 Train loss: 213.9334


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2523


100%|██████████| 469/469 [00:12<00:00, 37.26it/s]


Epoch: 53 Train loss: 213.8483


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2841


100%|██████████| 469/469 [00:11<00:00, 39.23it/s]


Epoch: 54 Train loss: 213.8177


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.3116


100%|██████████| 469/469 [00:13<00:00, 33.66it/s]


Epoch: 55 Train loss: 213.6946


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1464


100%|██████████| 469/469 [00:12<00:00, 37.54it/s]


Epoch: 56 Train loss: 213.5374


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.6107


100%|██████████| 469/469 [00:13<00:00, 35.82it/s]


Epoch: 57 Train loss: 213.5830


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9189


100%|██████████| 469/469 [00:12<00:00, 38.87it/s]


Epoch: 58 Train loss: 213.3468


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9887


100%|██████████| 469/469 [00:14<00:00, 32.45it/s]


Epoch: 59 Train loss: 213.3456


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2518


100%|██████████| 469/469 [00:13<00:00, 34.85it/s]


Epoch: 60 Train loss: 213.3012


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2192


100%|██████████| 469/469 [00:14<00:00, 32.54it/s]


Epoch: 61 Train loss: 213.1609


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0220


100%|██████████| 469/469 [00:14<00:00, 33.11it/s]

Epoch: 62 Train loss: 213.0940



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1166


100%|██████████| 469/469 [00:13<00:00, 34.98it/s]

Epoch: 63 Train loss: 213.1042



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.7964


100%|██████████| 469/469 [00:12<00:00, 36.43it/s]

Epoch: 64 Train loss: 212.9412



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9865


100%|██████████| 469/469 [00:13<00:00, 35.56it/s]


Epoch: 65 Train loss: 212.8790


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1034


100%|██████████| 469/469 [00:12<00:00, 36.47it/s]


Epoch: 66 Train loss: 212.7782


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.8977


100%|██████████| 469/469 [00:13<00:00, 35.11it/s]


Epoch: 67 Train loss: 212.7028


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9088


100%|██████████| 469/469 [00:13<00:00, 36.05it/s]


Epoch: 68 Train loss: 212.7070


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9410


100%|██████████| 469/469 [00:14<00:00, 32.37it/s]


Epoch: 69 Train loss: 212.5858


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0347


100%|██████████| 469/469 [00:12<00:00, 36.27it/s]


Epoch: 70 Train loss: 212.5615


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0489


100%|██████████| 469/469 [00:13<00:00, 35.67it/s]


Epoch: 71 Train loss: 212.4930


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.7736


100%|██████████| 469/469 [00:12<00:00, 37.42it/s]


Epoch: 72 Train loss: 212.3740


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0609


100%|██████████| 469/469 [00:13<00:00, 34.97it/s]


Epoch: 73 Train loss: 212.3673


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2186


100%|██████████| 469/469 [00:14<00:00, 32.17it/s]

Epoch: 74 Train loss: 212.3291



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1695


100%|██████████| 469/469 [00:12<00:00, 36.33it/s]


Epoch: 75 Train loss: 212.2482


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.8649


100%|██████████| 469/469 [00:12<00:00, 37.01it/s]


Epoch: 76 Train loss: 212.2115


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9969


100%|██████████| 469/469 [00:12<00:00, 36.89it/s]

Epoch: 77 Train loss: 212.0289



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.8098


100%|██████████| 469/469 [00:12<00:00, 37.74it/s]


Epoch: 78 Train loss: 212.0733


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9762


100%|██████████| 469/469 [00:12<00:00, 36.71it/s]

Epoch: 79 Train loss: 211.9995



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2490


100%|██████████| 469/469 [00:13<00:00, 34.11it/s]


Epoch: 80 Train loss: 211.9210


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1049


100%|██████████| 469/469 [00:13<00:00, 35.84it/s]


Epoch: 81 Train loss: 211.8805


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.5913


100%|██████████| 469/469 [00:13<00:00, 34.62it/s]


Epoch: 82 Train loss: 211.8168


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1392


100%|██████████| 469/469 [00:13<00:00, 35.75it/s]


Epoch: 83 Train loss: 211.7886


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9708


100%|██████████| 469/469 [00:13<00:00, 35.72it/s]

Epoch: 84 Train loss: 211.7169



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9415


100%|██████████| 469/469 [00:13<00:00, 36.05it/s]


Epoch: 85 Train loss: 211.7905


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0208


100%|██████████| 469/469 [00:13<00:00, 35.24it/s]


Epoch: 86 Train loss: 211.6630


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1261


100%|██████████| 469/469 [00:12<00:00, 37.33it/s]


Epoch: 87 Train loss: 211.6960


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.6828


100%|██████████| 469/469 [00:15<00:00, 30.68it/s]

Epoch: 88 Train loss: 211.5540



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0250


100%|██████████| 469/469 [00:13<00:00, 33.79it/s]


Epoch: 89 Train loss: 211.4605


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.7443


100%|██████████| 469/469 [00:14<00:00, 31.77it/s]

Epoch: 90 Train loss: 211.4642



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.8418


100%|██████████| 469/469 [00:12<00:00, 36.34it/s]


Epoch: 91 Train loss: 211.3731


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.7709


100%|██████████| 469/469 [00:13<00:00, 35.62it/s]


Epoch: 92 Train loss: 211.2876


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0692


100%|██████████| 469/469 [00:13<00:00, 34.80it/s]

Epoch: 93 Train loss: 211.3600



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.3120


100%|██████████| 469/469 [00:12<00:00, 37.41it/s]


Epoch: 94 Train loss: 211.3147


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 220.9525


100%|██████████| 469/469 [00:14<00:00, 32.36it/s]


Epoch: 95 Train loss: 211.2393


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0957


100%|██████████| 469/469 [00:13<00:00, 36.04it/s]


Epoch: 96 Train loss: 211.2623


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.2357


100%|██████████| 469/469 [00:14<00:00, 32.06it/s]


Epoch: 97 Train loss: 211.1985


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1565


100%|██████████| 469/469 [00:12<00:00, 37.57it/s]


Epoch: 98 Train loss: 211.2340


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.1601


100%|██████████| 469/469 [00:13<00:00, 35.57it/s]


Epoch: 99 Train loss: 211.0508


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 221.0042


100%|██████████| 469/469 [00:13<00:00, 34.87it/s]


Epoch: 100 Train loss: 211.0653
Test loss: 220.9757
