# FactorVAE

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, Bernoulli, Deterministic
from pixyz.losses import KullbackLeibler, CrossEntropy, AdversarialKullbackLeibler
from pixyz.models import Model

In [4]:
x_dim = 784
z_dim = 8


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# generative model p(x|z)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

p = Generator()
q = Inference()

In [5]:
class InferenceShuffleDim(Deterministic):
    def __init__(self):
        super(InferenceShuffleDim, self).__init__(cond_var=["x_"], var=["z"], name="q_shuffle")

    def forward(self, x_):
        z = q.sample({"x": x_}, return_all=False)["z"]
        return {"z": z[:,torch.randperm(z.shape[1])]}
    
q_shuffle = InferenceShuffleDim()

In [6]:
class Discriminator(Deterministic):
    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["z"], var=["t"], name="d")

        self.model = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        t = self.model(z)
        return {"t": t}

In [7]:
p = Generator()
q = Inference()
d = Discriminator()

p.to(device)
q.to(device)
d.to(device)

print(p)
print(q)
print(q_shuffle)
print(d)

Distribution:
  p(x|z) (Bernoulli)
Network architecture:
  Generator(
    (fc1): Linear(in_features=8, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )
Distribution:
  q(z|x) (Normal)
Network architecture:
  Inference(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=8, bias=True)
    (fc32): Linear(in_features=512, out_features=8, bias=True)
  )
Distribution:
  q_shuffle(z|x_) (Deterministic)
Network architecture:
  InferenceShuffleDim()
Distribution:
  d(t|z) (Deterministic)
Network architecture:
  Discriminator(
    (model): Sequential(
      (0): Linear(in_features=8, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.

In [8]:
reconst = CrossEntropy(q, p)
kl = KullbackLeibler(q, prior)
tc = AdversarialKullbackLeibler(q, q_shuffle, discriminator=d, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
loss_cls = reconst.mean() + kl.mean() + 10*tc
print(loss_cls)

mean(-E_q(z|x)[log p(x|z)]) + mean(KL[q(z|x)||p_prior(z)]) + mean(AdversarialKL[q(z|x)||q_shuffle(z|x_)]) * 10


In [9]:
model = Model(loss_cls, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  p(x|z), q(z|x) 
Loss function: 
  mean(-E_q(z|x)[log p(x|z)]) + mean(KL[q(z|x)||p_prior(z)]) + mean(AdversarialKL[q(z|x)||q_shuffle(z|x_)]) * 10 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [10]:
def train(epoch):
    train_loss = 0
    train_d_loss = 0    
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        len_x = x.shape[0]//2
        loss = model.train({"x": x[:len_x], "x_": x[len_x:]})
        d_loss = tc.train({"x": x[:len_x], "x_": x[len_x:]})
        train_loss += loss
        train_d_loss += d_loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    train_d_loss = train_d_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    test_d_loss = 0    
    for x, _ in tqdm(test_loader):
        x = x.to(device)
        len_x = x.shape[0]//2
        loss = model.test({"x": x[:len_x], "x_": x[len_x:]})
        d_loss = tc.test({"x": x[:len_x], "x_": x[len_x:]})
        test_loss += loss
        test_d_loss += d_loss
 
    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    test_d_loss = test_d_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss.item(), test_d_loss.item()))
    return test_loss

In [12]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = q.sample({"x": x}, return_all=False)
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [13]:
writer = SummaryWriter()

plot_dim = 8

z_sample = []
for i in range(plot_dim):
    z_batch = torch.zeros(plot_dim, z_dim)
    z_batch[:, i] = (torch.arange(plot_dim,dtype=torch.float32)*2.)/(plot_dim-1.)-1
    z_sample.append(z_batch)
z_sample = torch.cat(z_sample, dim=0).to(device)
#z_sample = 0.5 * torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:07<00:00, 61.54it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 1 Train loss: 180.5542, 1.0638


100%|██████████| 79/79 [00:00<00:00, 107.91it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 145.8692, 0.9874


100%|██████████| 469/469 [00:06<00:00, 68.13it/s] 
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 2 Train loss: 133.6951, 1.1348


100%|██████████| 79/79 [00:00<00:00, 101.50it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 129.5038, 1.1545


100%|██████████| 469/469 [00:07<00:00, 65.40it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 3 Train loss: 126.1546, 1.1131


100%|██████████| 79/79 [00:00<00:00, 104.26it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 123.1171, 1.0773


100%|██████████| 469/469 [00:06<00:00, 68.03it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 4 Train loss: 123.4547, 1.0814


100%|██████████| 79/79 [00:00<00:00, 102.65it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 120.9322, 1.0498


100%|██████████| 469/469 [00:06<00:00, 75.33it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 5 Train loss: 122.0771, 1.0621


100%|██████████| 79/79 [00:00<00:00, 105.86it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 122.2007, 1.0552


100%|██████████| 469/469 [00:07<00:00, 65.50it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 6 Train loss: 120.7251, 1.0487


100%|██████████| 79/79 [00:00<00:00, 105.89it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.8272, 1.0564


100%|██████████| 469/469 [00:05<00:00, 80.79it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 7 Train loss: 119.7787, 1.0386


100%|██████████| 79/79 [00:00<00:00, 108.34it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.7469, 1.0717


100%|██████████| 469/469 [00:06<00:00, 67.26it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 8 Train loss: 119.0261, 1.0215


100%|██████████| 79/79 [00:00<00:00, 103.14it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.5538, 1.0338


100%|██████████| 469/469 [00:07<00:00, 62.61it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 9 Train loss: 118.7542, 1.0152


100%|██████████| 79/79 [00:00<00:00, 98.54it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.4869, 1.0398


100%|██████████| 469/469 [00:07<00:00, 63.40it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 10 Train loss: 118.2731, 1.0111


100%|██████████| 79/79 [00:00<00:00, 104.34it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.5805, 1.0535


100%|██████████| 469/469 [00:06<00:00, 64.82it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 11 Train loss: 117.7593, 1.0089


100%|██████████| 79/79 [00:00<00:00, 90.83it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.2416, 1.0148


100%|██████████| 469/469 [00:07<00:00, 61.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 12 Train loss: 117.0589, 1.0050


100%|██████████| 79/79 [00:00<00:00, 102.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.1960, 1.0080


100%|██████████| 469/469 [00:07<00:00, 64.48it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 13 Train loss: 116.9724, 1.0075


100%|██████████| 79/79 [00:00<00:00, 98.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.3524, 1.0083


100%|██████████| 469/469 [00:07<00:00, 66.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 14 Train loss: 116.3094, 0.9993


100%|██████████| 79/79 [00:00<00:00, 93.55it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.5599, 1.0041


100%|██████████| 469/469 [00:07<00:00, 63.68it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 15 Train loss: 116.3988, 0.9929


100%|██████████| 79/79 [00:00<00:00, 100.83it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.3677, 0.9974


100%|██████████| 469/469 [00:07<00:00, 64.97it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 16 Train loss: 115.7893, 1.0007


100%|██████████| 79/79 [00:00<00:00, 98.79it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.8238, 0.9697


100%|██████████| 469/469 [00:06<00:00, 67.85it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 17 Train loss: 115.8946, 0.9957


100%|██████████| 79/79 [00:00<00:00, 96.96it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.7156, 1.0032


100%|██████████| 469/469 [00:06<00:00, 68.99it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 18 Train loss: 115.1388, 0.9946


100%|██████████| 79/79 [00:00<00:00, 82.34it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.6603, 0.9691


100%|██████████| 469/469 [00:07<00:00, 57.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 19 Train loss: 115.1911, 0.9912


100%|██████████| 79/79 [00:00<00:00, 108.66it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.1207, 0.9806


100%|██████████| 469/469 [00:06<00:00, 59.96it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 20 Train loss: 114.9707, 0.9893


100%|██████████| 79/79 [00:00<00:00, 98.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.5723, 0.9949


100%|██████████| 469/469 [00:07<00:00, 63.26it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 21 Train loss: 114.9988, 0.9812


100%|██████████| 79/79 [00:00<00:00, 96.77it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.7190, 0.9853


100%|██████████| 469/469 [00:07<00:00, 65.01it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 22 Train loss: 114.4609, 0.9893


100%|██████████| 79/79 [00:00<00:00, 92.22it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.6560, 0.9855


100%|██████████| 469/469 [00:06<00:00, 96.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 23 Train loss: 114.5087, 0.9792


100%|██████████| 79/79 [00:00<00:00, 97.57it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 118.5699, 1.0024


100%|██████████| 469/469 [00:07<00:00, 62.46it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 24 Train loss: 114.3928, 0.9827


100%|██████████| 79/79 [00:00<00:00, 96.16it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.1338, 1.0209


100%|██████████| 469/469 [00:07<00:00, 66.87it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 25 Train loss: 114.1577, 0.9830


100%|██████████| 79/79 [00:00<00:00, 93.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.6730, 0.9760


100%|██████████| 469/469 [00:07<00:00, 66.05it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 26 Train loss: 114.0882, 0.9781


100%|██████████| 79/79 [00:00<00:00, 77.26it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.7667, 0.9672


100%|██████████| 469/469 [00:06<00:00, 68.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 27 Train loss: 114.0684, 0.9829


100%|██████████| 79/79 [00:00<00:00, 75.13it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.5416, 0.9737


100%|██████████| 469/469 [00:06<00:00, 69.35it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 28 Train loss: 113.5792, 0.9865


100%|██████████| 79/79 [00:00<00:00, 94.85it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.6769, 1.0209


100%|██████████| 469/469 [00:06<00:00, 69.07it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 29 Train loss: 113.4424, 0.9795


100%|██████████| 79/79 [00:00<00:00, 96.54it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.4973, 1.0418


100%|██████████| 469/469 [00:07<00:00, 66.49it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 30 Train loss: 113.4541, 0.9729


100%|██████████| 79/79 [00:00<00:00, 98.90it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.7696, 0.9752


100%|██████████| 469/469 [00:07<00:00, 66.65it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 31 Train loss: 113.3450, 0.9826


100%|██████████| 79/79 [00:00<00:00, 95.68it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.3175, 0.9621


100%|██████████| 469/469 [00:06<00:00, 71.76it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 32 Train loss: 113.2344, 0.9731


100%|██████████| 79/79 [00:00<00:00, 111.39it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.6436, 0.9993


100%|██████████| 469/469 [00:07<00:00, 66.42it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 33 Train loss: 112.8494, 0.9841


100%|██████████| 79/79 [00:00<00:00, 101.51it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.7649, 0.9904


100%|██████████| 469/469 [00:05<00:00, 90.47it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 34 Train loss: 113.1741, 0.9781


100%|██████████| 79/79 [00:00<00:00, 107.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.6636, 0.9871


100%|██████████| 469/469 [00:06<00:00, 68.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 35 Train loss: 112.7347, 0.9672


100%|██████████| 79/79 [00:00<00:00, 99.45it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.5543, 0.9814


100%|██████████| 469/469 [00:06<00:00, 67.47it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 36 Train loss: 113.1169, 0.9730


100%|██████████| 79/79 [00:00<00:00, 101.80it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.0064, 0.9793


100%|██████████| 469/469 [00:07<00:00, 60.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 37 Train loss: 112.7333, 0.9697


100%|██████████| 79/79 [00:00<00:00, 96.83it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.2284, 1.0181


100%|██████████| 469/469 [00:06<00:00, 69.96it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 38 Train loss: 112.5935, 0.9745


100%|██████████| 79/79 [00:00<00:00, 100.16it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 112.6215, 0.9842


100%|██████████| 469/469 [00:07<00:00, 66.33it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 39 Train loss: 112.3932, 0.9795


100%|██████████| 79/79 [00:00<00:00, 101.38it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.8184, 0.9727


100%|██████████| 469/469 [00:07<00:00, 65.33it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 40 Train loss: 112.0739, 0.9789


100%|██████████| 79/79 [00:00<00:00, 109.12it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.4581, 0.9433


100%|██████████| 469/469 [00:06<00:00, 62.79it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 41 Train loss: 112.3429, 0.9740


100%|██████████| 79/79 [00:00<00:00, 94.06it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.2240, 1.0076


100%|██████████| 469/469 [00:07<00:00, 60.14it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 42 Train loss: 111.9027, 0.9828


100%|██████████| 79/79 [00:00<00:00, 104.60it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 116.9396, 0.9930


100%|██████████| 469/469 [00:07<00:00, 65.40it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 43 Train loss: 112.2831, 0.9726


100%|██████████| 79/79 [00:00<00:00, 105.14it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.1977, 0.9625


100%|██████████| 469/469 [00:07<00:00, 63.54it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 44 Train loss: 112.0154, 0.9712


100%|██████████| 79/79 [00:00<00:00, 89.13it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.3332, 0.9981


100%|██████████| 469/469 [00:07<00:00, 63.63it/s] 
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 45 Train loss: 112.2970, 0.9684


100%|██████████| 79/79 [00:00<00:00, 107.80it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.2651, 0.9858


100%|██████████| 469/469 [00:06<00:00, 67.23it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 46 Train loss: 112.0269, 0.9798


100%|██████████| 79/79 [00:00<00:00, 93.86it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.3706, 1.0083


100%|██████████| 469/469 [00:06<00:00, 71.33it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 47 Train loss: 111.9729, 0.9672


100%|██████████| 79/79 [00:00<00:00, 101.07it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.5581, 0.9844


100%|██████████| 469/469 [00:07<00:00, 76.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 48 Train loss: 111.9358, 0.9668


100%|██████████| 79/79 [00:00<00:00, 92.45it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 112.8374, 1.0277


100%|██████████| 469/469 [00:07<00:00, 64.12it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 49 Train loss: 111.7134, 0.9731


100%|██████████| 79/79 [00:00<00:00, 99.12it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.9733, 0.9605


100%|██████████| 469/469 [00:06<00:00, 68.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 50 Train loss: 111.5306, 0.9726


100%|██████████| 79/79 [00:00<00:00, 106.55it/s]

Test loss: 115.3599, 0.9732



