# FactorVAE

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, Bernoulli, Deterministic
from pixyz.losses import KullbackLeibler, CrossEntropy, AdversarialKullbackLeibler
from pixyz.models import Model
from pixyz.utils import print_latex

In [4]:
x_dim = 784
z_dim = 8


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# generative model p(x|z)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
    
    
class InferenceShuffleDim(Deterministic):
    def __init__(self):
        super(InferenceShuffleDim, self).__init__(cond_var=["x_shf"], var=["z"], name="q_shf")

    def forward(self, x_shf):
        z = q.sample({"x": x_shf}, return_all=False)["z"]
        return {"z": z[:,torch.randperm(z.shape[1])]}
    

class Discriminator(Deterministic):
    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["z"], var=["t"], name="d")

        self.model = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        t = self.model(z)
        return {"t": t}

In [5]:
p = Generator().to(device)
q = Inference().to(device)
d = Discriminator().to(device)
q_shuffle = InferenceShuffleDim().to(device)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"],
               features_shape=[z_dim], name="p_prior").to(device)

In [6]:
print(p)
print(q)
print(q_shuffle)
print(d)
print(prior)

Distribution:
  p(x|z)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=8, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )
Distribution:
  q(z|x)
Network architecture:
  Inference(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=8, bias=True)
    (fc32): Linear(in_features=512, out_features=8, bias=True)
  )
Distribution:
  q_{shf}(z|x_{shf})
Network architecture:
  InferenceShuffleDim(
    name=q_{shf}, distribution_name=Deterministic,
    var=['z'], cond_var=['x_shf'], input_var=['x_shf'], features_s

In [7]:
reconst = -p.log_prob().expectation(q)
kl = KullbackLeibler(q, prior)
tc = AdversarialKullbackLeibler(q, q_shuffle, discriminator=d, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
loss_cls = reconst.mean() + kl.mean() + 10*tc
loss_cls.to(device)
print(loss_cls)
print_latex(loss_cls)

mean \left(- \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] \right) + 10 mean(D_{KL}^{Adv} \left[q(z|x)||q_{shf}(z|x_{shf}) \right])


<IPython.core.display.Math object>

In [8]:
model = Model(loss_cls, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  p(x|z), q(z|x) 
Loss function: 
  mean \left(- \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] \right) + 10 mean(D_{KL}^{Adv} \left[q(z|x)||q_{shf}(z|x_{shf}) \right]) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [9]:
def train(epoch):
    train_loss = 0
    train_d_loss = 0    
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        len_x = x.shape[0]//2
        loss = model.train({"x": x[:len_x], "x_shf": x[len_x:]})
        d_loss = tc.loss_train({"x": x[:len_x], "x_shf": x[len_x:]})
        train_loss += loss
        train_d_loss += d_loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    train_d_loss = train_d_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [10]:
def test(epoch):
    test_loss = 0
    test_d_loss = 0    
    for x, _ in tqdm(test_loader):
        x = x.to(device)
        len_x = x.shape[0]//2
        loss = model.test({"x": x[:len_x], "x_shf": x[len_x:]})
        d_loss = tc.loss_test({"x": x[:len_x], "x_shf": x[len_x:]}) 
        test_loss += loss
        test_d_loss += d_loss
 
    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    test_d_loss = test_d_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss.item(), test_d_loss.item()))
    return test_loss

In [11]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = q.sample({"x": x}, return_all=False)
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [None]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [12]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("../runs/" + v + ".factorvae" + exp_time)

plot_dim = 8

z_sample = []
for i in range(plot_dim):
    z_batch = torch.zeros(plot_dim, z_dim)
    z_batch[:, i] = (torch.arange(plot_dim,dtype=torch.float32)*2.)/(plot_dim-1.)-1
    z_sample.append(z_batch)
z_sample = torch.cat(z_sample, dim=0).to(device)
#z_sample = 0.5 * torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)

elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()
writer.close()

100%|██████████| 469/469 [00:07<00:00, 66.45it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 1 Train loss: 181.1512, 1.0842


100%|██████████| 79/79 [00:01<00:00, 68.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 144.9058, 1.0091


100%|██████████| 469/469 [00:07<00:00, 71.57it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 2 Train loss: 133.4819, 1.1190


100%|██████████| 79/79 [00:01<00:00, 63.14it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 129.2302, 1.1242


100%|██████████| 469/469 [00:06<00:00, 68.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 3 Train loss: 126.8036, 1.0979


100%|██████████| 79/79 [00:01<00:00, 63.61it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 125.3009, 1.0565


100%|██████████| 469/469 [00:06<00:00, 67.65it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 4 Train loss: 124.2640, 1.0669


100%|██████████| 79/79 [00:01<00:00, 63.55it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 122.7177, 1.0479


100%|██████████| 469/469 [00:07<00:00, 65.65it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 5 Train loss: 122.4489, 1.0570


100%|██████████| 79/79 [00:01<00:00, 63.38it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 121.7393, 1.0558


100%|██████████| 469/469 [00:07<00:00, 67.23it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 6 Train loss: 121.5052, 1.0481


100%|██████████| 79/79 [00:01<00:00, 67.35it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 121.8174, 1.0260


100%|██████████| 469/469 [00:07<00:00, 66.57it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 7 Train loss: 120.4570, 1.0309


100%|██████████| 79/79 [00:01<00:00, 63.87it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.3109, 1.0448


100%|██████████| 469/469 [00:06<00:00, 67.43it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 8 Train loss: 119.6462, 1.0258


100%|██████████| 79/79 [00:01<00:00, 65.24it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 122.4285, 1.0536


100%|██████████| 469/469 [00:07<00:00, 66.64it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 9 Train loss: 119.2147, 1.0170


100%|██████████| 79/79 [00:01<00:00, 62.76it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.3293, 1.0641


100%|██████████| 469/469 [00:07<00:00, 66.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 10 Train loss: 118.3234, 1.0166


100%|██████████| 79/79 [00:01<00:00, 63.95it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.5090, 1.0205


100%|██████████| 469/469 [00:07<00:00, 64.84it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 11 Train loss: 117.7782, 1.0226


100%|██████████| 79/79 [00:01<00:00, 65.91it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 120.4615, 1.0190


100%|██████████| 469/469 [00:07<00:00, 66.54it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 12 Train loss: 117.9163, 1.0085


100%|██████████| 79/79 [00:01<00:00, 63.47it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.9913, 0.9945


100%|██████████| 469/469 [00:07<00:00, 66.01it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 13 Train loss: 117.1413, 0.9976


100%|██████████| 79/79 [00:01<00:00, 63.14it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.1915, 0.9868


100%|██████████| 469/469 [00:07<00:00, 65.06it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 14 Train loss: 116.8571, 1.0071


100%|██████████| 79/79 [00:01<00:00, 63.83it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.5589, 1.0282


100%|██████████| 469/469 [00:06<00:00, 67.87it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 15 Train loss: 116.7105, 0.9921


100%|██████████| 79/79 [00:01<00:00, 64.50it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.9871, 1.0114


100%|██████████| 469/469 [00:07<00:00, 64.85it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 16 Train loss: 116.4930, 0.9997


100%|██████████| 79/79 [00:01<00:00, 65.48it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.0580, 0.9765


100%|██████████| 469/469 [00:07<00:00, 66.25it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 17 Train loss: 115.9139, 0.9912


100%|██████████| 79/79 [00:01<00:00, 65.07it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.5378, 1.0178


100%|██████████| 469/469 [00:07<00:00, 66.87it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 18 Train loss: 115.6452, 0.9950


100%|██████████| 79/79 [00:01<00:00, 65.54it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.2999, 1.0503


100%|██████████| 469/469 [00:07<00:00, 64.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 19 Train loss: 115.6373, 0.9937


100%|██████████| 79/79 [00:01<00:00, 63.57it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.3356, 0.9843


100%|██████████| 469/469 [00:07<00:00, 65.66it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 20 Train loss: 115.2141, 0.9940


100%|██████████| 79/79 [00:01<00:00, 63.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.2319, 0.9889


100%|██████████| 469/469 [00:07<00:00, 66.36it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 21 Train loss: 115.0489, 0.9866


100%|██████████| 79/79 [00:01<00:00, 65.97it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.8251, 1.0027


100%|██████████| 469/469 [00:07<00:00, 66.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 22 Train loss: 114.7876, 0.9916


100%|██████████| 79/79 [00:01<00:00, 62.59it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.8664, 0.9923


100%|██████████| 469/469 [00:07<00:00, 69.34it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 23 Train loss: 114.8030, 0.9858


100%|██████████| 79/79 [00:01<00:00, 66.30it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.7814, 1.0171


100%|██████████| 469/469 [00:07<00:00, 65.82it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 24 Train loss: 114.5695, 0.9847


100%|██████████| 79/79 [00:01<00:00, 61.25it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.0533, 0.9890


100%|██████████| 469/469 [00:06<00:00, 67.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 25 Train loss: 114.2196, 0.9889


100%|██████████| 79/79 [00:01<00:00, 66.40it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.7628, 0.9677


100%|██████████| 469/469 [00:07<00:00, 66.31it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 26 Train loss: 114.1764, 0.9946


100%|██████████| 79/79 [00:01<00:00, 63.78it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.9489, 1.0248


100%|██████████| 469/469 [00:07<00:00, 66.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 27 Train loss: 113.9051, 0.9948


100%|██████████| 79/79 [00:01<00:00, 64.27it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.5961, 0.9989


100%|██████████| 469/469 [00:06<00:00, 68.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 28 Train loss: 113.8867, 0.9919


100%|██████████| 79/79 [00:01<00:00, 68.29it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.6823, 0.9906


100%|██████████| 469/469 [00:07<00:00, 66.69it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 29 Train loss: 113.2701, 0.9978


100%|██████████| 79/79 [00:01<00:00, 64.72it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.9360, 1.0189


100%|██████████| 469/469 [00:06<00:00, 67.23it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 30 Train loss: 113.5799, 0.9913


100%|██████████| 79/79 [00:01<00:00, 63.04it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.4369, 1.0093


100%|██████████| 469/469 [00:06<00:00, 67.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 31 Train loss: 113.4727, 0.9943


100%|██████████| 79/79 [00:01<00:00, 60.92it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.3559, 1.0232


100%|██████████| 469/469 [00:07<00:00, 66.66it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 32 Train loss: 113.1698, 0.9911


100%|██████████| 79/79 [00:01<00:00, 60.82it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.2681, 1.0141


100%|██████████| 469/469 [00:07<00:00, 65.64it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 33 Train loss: 113.0866, 0.9813


100%|██████████| 79/79 [00:01<00:00, 65.62it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.5074, 0.9977


100%|██████████| 469/469 [00:07<00:00, 65.61it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 34 Train loss: 112.9974, 0.9880


100%|██████████| 79/79 [00:01<00:00, 63.67it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.8114, 0.9754


100%|██████████| 469/469 [00:06<00:00, 67.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 35 Train loss: 112.8175, 0.9899


100%|██████████| 79/79 [00:01<00:00, 65.41it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.2341, 1.0051


100%|██████████| 469/469 [00:07<00:00, 66.48it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 36 Train loss: 112.8242, 0.9884


100%|██████████| 79/79 [00:01<00:00, 66.03it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.2924, 0.9872


100%|██████████| 469/469 [00:07<00:00, 66.64it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 37 Train loss: 112.7190, 0.9820


100%|██████████| 79/79 [00:01<00:00, 60.47it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.5212, 1.0167


100%|██████████| 469/469 [00:06<00:00, 67.64it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 38 Train loss: 112.5846, 0.9934


100%|██████████| 79/79 [00:01<00:00, 64.76it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.6000, 1.0236


100%|██████████| 469/469 [00:07<00:00, 66.58it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 39 Train loss: 112.2584, 0.9834


100%|██████████| 79/79 [00:01<00:00, 64.68it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.4450, 1.0096


100%|██████████| 469/469 [00:07<00:00, 66.60it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 40 Train loss: 112.6857, 0.9895


100%|██████████| 79/79 [00:01<00:00, 63.22it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.8809, 1.0296


100%|██████████| 469/469 [00:06<00:00, 67.05it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 41 Train loss: 112.1644, 0.9921


100%|██████████| 79/79 [00:01<00:00, 64.05it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.3685, 1.0487


100%|██████████| 469/469 [00:07<00:00, 65.44it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 42 Train loss: 111.9324, 0.9912


100%|██████████| 79/79 [00:01<00:00, 62.85it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.5233, 0.9952


100%|██████████| 469/469 [00:07<00:00, 66.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 43 Train loss: 112.1094, 0.9864


100%|██████████| 79/79 [00:01<00:00, 60.85it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.6147, 0.9870


100%|██████████| 469/469 [00:07<00:00, 66.39it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 44 Train loss: 111.8831, 0.9817


100%|██████████| 79/79 [00:01<00:00, 66.04it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.7438, 1.0124


100%|██████████| 469/469 [00:06<00:00, 68.06it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 45 Train loss: 111.9333, 0.9844


100%|██████████| 79/79 [00:01<00:00, 61.56it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.9404, 0.9940


100%|██████████| 469/469 [00:07<00:00, 66.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 46 Train loss: 111.8815, 0.9826


100%|██████████| 79/79 [00:01<00:00, 64.66it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.2740, 1.0239


100%|██████████| 469/469 [00:07<00:00, 65.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 47 Train loss: 111.9902, 0.9886


100%|██████████| 79/79 [00:01<00:00, 63.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.7048, 0.9989


100%|██████████| 469/469 [00:07<00:00, 64.88it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 48 Train loss: 111.6585, 0.9884


100%|██████████| 79/79 [00:01<00:00, 64.78it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.5869, 1.0134


100%|██████████| 469/469 [00:07<00:00, 65.63it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 49 Train loss: 111.7986, 0.9805


100%|██████████| 79/79 [00:01<00:00, 62.61it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.4822, 1.0070


100%|██████████| 469/469 [00:07<00:00, 64.85it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 50 Train loss: 111.7762, 0.9810


100%|██████████| 79/79 [00:01<00:00, 65.33it/s]

Test loss: 113.4714, 0.9993



