# Conditional Real NVP

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, InverseTransformedDistribution
from pixyz.flows import AffineCouplingLayer, FlowList, BatchNorm1d, ShuffleLayer, PreprocessLayer, ReverseLayer
from pixyz.models import ML

In [4]:
x_dim = 28*28
y_dim = 10
z_dim = x_dim

In [5]:
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [6]:
class ScaleTranslateNet(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features+y_dim, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3_s = nn.Linear(hidden_features, in_features)
        self.fc3_t = nn.Linear(hidden_features, in_features)

    def forward(self, x, y):
        hidden = F.relu(self.fc2(F.relu(self.fc1(torch.cat([x, y], 1)))))
        log_s = torch.tanh(self.fc3_s(hidden))
        t = self.fc3_t(hidden)
        return log_s, t

In [7]:
# flow
flow_list = []
num_block = 5

flow_list.append(PreprocessLayer())

for i in range(num_block):
        flow_list.append(AffineCouplingLayer(in_channels=x_dim,
                                             scale_translate_net=ScaleTranslateNet(x_dim, 1028),
                                             inverse_mask=(i%2!=0)))
        
        flow_list.append(BatchNorm1d(x_dim))
        
f = FlowList(flow_list)

In [8]:
# inverse transformed distribution (z -> f^-1 -> x)
p = InverseTransformedDistribution(prior=prior, flow=f, var=["x"], cond_var=["y"])
p.to(device)

InverseTransformedDistribution(
  (prior): Normal()
  (flow): FlowList(
    (0): PreprocessLayer()
    (1): AffineCouplingLayer(
      in_features=784, mask_type=channel_wise, inverse_mask=False
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=794, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
    (2): BatchNorm1d()
    (3): AffineCouplingLayer(
      in_features=784, mask_type=channel_wise, inverse_mask=True
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=794, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
 

In [9]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  p(x|y) 
Loss function: 
  mean(-(log p(x|y))) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [10]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.train({"x": x, "y": y})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [12]:
def plot_reconstrunction(x, y):
    with torch.no_grad():
        z = p.forward(x, y, compute_jacobian=False)
        recon_batch = p.inverse(z, y).view(-1, 1, 28, 28)
    
        recon = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return recon
    
def plot_image_from_latent(z, y):
    with torch.no_grad():
        sample = p.inverse(z, y).view(-1, 1, 28, 28).cpu()
        return sample
    
def plot_reconstrunction_changing_y(x, y):
    y_change = torch.eye(10)[range(7)].to(device)
    batch_dummy = torch.ones(x.size(0))[:, None].to(device)    
    recon_all = []
    
    with torch.no_grad():
        for _y in y_change:
            z = p.forward(x, y, compute_jacobian=False)
            recon_batch = p.inverse(z, batch_dummy * _y[None,:]).view(-1, 1, 28, 28)
            recon_all.append(recon_batch)
    
        recon_changing_y = torch.cat(recon_all)
        recon_changing_y = torch.cat([x.view(-1, 1, 28, 28), recon_changing_y]).cpu()
        return recon_changing_y

In [13]:
writer = SummaryWriter()

plot_number = 5

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
y_sample = torch.eye(10)[[plot_number]*64].to(device)

_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8], _y[:8])
    sample = plot_image_from_latent(z_sample, y_sample)
    recon_changing_y = plot_reconstrunction_changing_y(_x[:8], _y[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    writer.add_images('Image_reconstrunction_change_y', recon_changing_y, epoch)
    
writer.close()

100%|██████████| 469/469 [00:12<00:00, 37.94it/s]

Epoch: 1 Train loss: 1440.0148



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1833.4519


100%|██████████| 469/469 [00:12<00:00, 37.55it/s]


Epoch: 2 Train loss: 1248.8079


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1548.9858


100%|██████████| 469/469 [00:12<00:00, 38.63it/s]

Epoch: 3 Train loss: 1191.3368



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1391.1505


100%|██████████| 469/469 [00:12<00:00, 38.90it/s]


Epoch: 4 Train loss: 1156.2378


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1410.8007


100%|██████████| 469/469 [00:12<00:00, 39.53it/s]


Epoch: 5 Train loss: 1133.8752


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1391.0446


100%|██████████| 469/469 [00:12<00:00, 38.70it/s]


Epoch: 6 Train loss: 1116.7030


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1383.0996


100%|██████████| 469/469 [00:12<00:00, 38.57it/s]


Epoch: 7 Train loss: 1104.0450


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1545.8665


100%|██████████| 469/469 [00:12<00:00, 38.35it/s]

Epoch: 8 Train loss: 1091.9137



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1247.1461


100%|██████████| 469/469 [00:12<00:00, 38.38it/s]

Epoch: 9 Train loss: 1082.5421



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1342.3170


100%|██████████| 469/469 [00:11<00:00, 39.53it/s]


Epoch: 10 Train loss: 1074.2792


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1298.7635


100%|██████████| 469/469 [00:12<00:00, 38.81it/s]


Epoch: 11 Train loss: 1067.4585


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1285.8353


100%|██████████| 469/469 [00:12<00:00, 39.31it/s]


Epoch: 12 Train loss: 1060.3219


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1250.7710


100%|██████████| 469/469 [00:12<00:00, 38.58it/s]

Epoch: 13 Train loss: 1055.1023



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1268.0629


100%|██████████| 469/469 [00:12<00:00, 37.50it/s]


Epoch: 14 Train loss: 1049.9446


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1395.5054


100%|██████████| 469/469 [00:11<00:00, 39.43it/s]


Epoch: 15 Train loss: 1045.6042


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1303.2709


100%|██████████| 469/469 [00:05<00:00, 78.71it/s]


Epoch: 16 Train loss: 1041.2090


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1764.7655


100%|██████████| 469/469 [00:05<00:00, 79.43it/s]


Epoch: 17 Train loss: 1036.2295


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1210.0487


100%|██████████| 469/469 [00:09<00:00, 39.35it/s]


Epoch: 18 Train loss: 1033.3136


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1319.8950


100%|██████████| 469/469 [00:12<00:00, 38.76it/s]

Epoch: 19 Train loss: 1029.2192



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1283.9238


100%|██████████| 469/469 [00:11<00:00, 39.74it/s]


Epoch: 20 Train loss: 1025.9790


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1279.2837


100%|██████████| 469/469 [00:12<00:00, 38.82it/s]

Epoch: 21 Train loss: 1023.0109



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1322.1229


100%|██████████| 469/469 [00:12<00:00, 39.58it/s]


Epoch: 22 Train loss: 1019.4136


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1394.4746


100%|██████████| 469/469 [00:12<00:00, 38.51it/s]

Epoch: 23 Train loss: 1018.3341



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1253.1842


100%|██████████| 469/469 [00:12<00:00, 38.91it/s]

Epoch: 24 Train loss: 1015.1594



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1348.2720


100%|██████████| 469/469 [00:11<00:00, 39.17it/s]


Epoch: 25 Train loss: 1012.3538


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1232.3754


100%|██████████| 469/469 [00:12<00:00, 38.66it/s]

Epoch: 26 Train loss: 1010.0950



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1243.5073


100%|██████████| 469/469 [00:12<00:00, 39.59it/s]


Epoch: 27 Train loss: 1007.6366


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1273.8549


100%|██████████| 469/469 [00:12<00:00, 38.87it/s]

Epoch: 28 Train loss: 1004.8156



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1367.6558


100%|██████████| 469/469 [00:12<00:00, 38.73it/s]

Epoch: 29 Train loss: 1003.4529



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1482.9918


100%|██████████| 469/469 [00:12<00:00, 38.81it/s]


Epoch: 30 Train loss: 1001.4311


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1279.9388


100%|██████████| 469/469 [00:12<00:00, 38.86it/s]


Epoch: 31 Train loss: 999.9359


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1251.2716


100%|██████████| 469/469 [00:12<00:00, 39.12it/s]


Epoch: 32 Train loss: 997.7358


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1235.7163


100%|██████████| 469/469 [00:12<00:00, 38.33it/s]

Epoch: 33 Train loss: 996.8733



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1326.5490


100%|██████████| 469/469 [00:12<00:00, 38.32it/s]

Epoch: 34 Train loss: 995.1768



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1238.3529


100%|██████████| 469/469 [00:12<00:00, 38.69it/s]

Epoch: 35 Train loss: 993.3559



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1180.2397


100%|██████████| 469/469 [00:12<00:00, 39.14it/s]


Epoch: 36 Train loss: 990.8202


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1233.0723


100%|██████████| 469/469 [00:12<00:00, 38.26it/s]

Epoch: 37 Train loss: 989.6333



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1195.0106


100%|██████████| 469/469 [00:12<00:00, 39.48it/s]


Epoch: 38 Train loss: 988.6293


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1198.7180


100%|██████████| 469/469 [00:12<00:00, 37.99it/s]


Epoch: 39 Train loss: 986.2447


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1306.3074


100%|██████████| 469/469 [00:12<00:00, 38.38it/s]


Epoch: 40 Train loss: 985.7638


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1187.5837


100%|██████████| 469/469 [00:10<00:00, 44.15it/s]

Epoch: 41 Train loss: 984.2593



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1272.6436


100%|██████████| 469/469 [00:07<00:00, 64.23it/s]

Epoch: 42 Train loss: 983.3307



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1216.0111


100%|██████████| 469/469 [00:11<00:00, 39.11it/s]


Epoch: 43 Train loss: 981.5037


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1227.2412


100%|██████████| 469/469 [00:06<00:00, 78.14it/s]

Epoch: 44 Train loss: 980.9578



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1279.8129


100%|██████████| 469/469 [00:12<00:00, 39.55it/s]


Epoch: 45 Train loss: 980.3894


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1388.0955


100%|██████████| 469/469 [00:12<00:00, 39.29it/s]


Epoch: 46 Train loss: 978.4866


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1454.9409


100%|██████████| 469/469 [00:12<00:00, 38.14it/s]


Epoch: 47 Train loss: 977.7287


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1276.6620


100%|██████████| 469/469 [00:12<00:00, 38.57it/s]


Epoch: 48 Train loss: 976.1412


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1497.9027


100%|██████████| 469/469 [00:12<00:00, 38.54it/s]

Epoch: 49 Train loss: 976.7797



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1379.7672


100%|██████████| 469/469 [00:12<00:00, 38.68it/s]


Epoch: 50 Train loss: 974.3002


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1273.1782


100%|██████████| 469/469 [00:12<00:00, 38.38it/s]


Epoch: 51 Train loss: 973.2398


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1660.4736


100%|██████████| 469/469 [00:11<00:00, 39.56it/s]


Epoch: 52 Train loss: 972.3604


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1304.1445


100%|██████████| 469/469 [00:12<00:00, 39.21it/s]


Epoch: 53 Train loss: 971.5209


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1278.1775


100%|██████████| 469/469 [00:12<00:00, 38.51it/s]

Epoch: 54 Train loss: 970.6526



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1254.9154


100%|██████████| 469/469 [00:12<00:00, 39.55it/s]


Epoch: 55 Train loss: 970.8807


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1316.0061


100%|██████████| 469/469 [00:12<00:00, 38.49it/s]

Epoch: 56 Train loss: 969.0471



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1420.2142


100%|██████████| 469/469 [00:12<00:00, 39.08it/s]


Epoch: 57 Train loss: 968.4023


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1218.2583


100%|██████████| 469/469 [00:12<00:00, 37.96it/s]


Epoch: 58 Train loss: 967.7509


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1327.7638


100%|██████████| 469/469 [00:12<00:00, 38.72it/s]

Epoch: 59 Train loss: 967.0488



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1331.9484


100%|██████████| 469/469 [00:12<00:00, 38.44it/s]

Epoch: 60 Train loss: 965.5513



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1411.9840


100%|██████████| 469/469 [00:12<00:00, 39.54it/s]


Epoch: 61 Train loss: 965.4755


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1272.6141


100%|██████████| 469/469 [00:12<00:00, 38.55it/s]


Epoch: 62 Train loss: 964.4851


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1247.8740


100%|██████████| 469/469 [00:12<00:00, 38.37it/s]

Epoch: 63 Train loss: 964.5583



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2129.6172


100%|██████████| 469/469 [00:12<00:00, 37.98it/s]


Epoch: 64 Train loss: 963.1701


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1333.3923


100%|██████████| 469/469 [00:12<00:00, 39.42it/s]


Epoch: 65 Train loss: 963.1232


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1312.5775


100%|██████████| 469/469 [00:12<00:00, 37.41it/s]


Epoch: 66 Train loss: 961.7045


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1491.6874


100%|██████████| 469/469 [00:11<00:00, 38.97it/s]


Epoch: 67 Train loss: 961.3983


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1565.2987


100%|██████████| 469/469 [00:12<00:00, 38.50it/s]


Epoch: 68 Train loss: 960.8768


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1406.0004


100%|██████████| 469/469 [00:12<00:00, 38.42it/s]


Epoch: 69 Train loss: 959.6415


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1426.2333


100%|██████████| 469/469 [00:11<00:00, 39.26it/s]

Epoch: 70 Train loss: 959.4510



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1410.1158


100%|██████████| 469/469 [00:12<00:00, 38.60it/s]

Epoch: 71 Train loss: 958.5823



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1516.9108


100%|██████████| 469/469 [00:10<00:00, 38.76it/s]


Epoch: 72 Train loss: 959.0033


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1251.5145


100%|██████████| 469/469 [00:12<00:00, 38.18it/s]

Epoch: 73 Train loss: 957.8531



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1206.0925


100%|██████████| 469/469 [00:12<00:00, 38.60it/s]


Epoch: 74 Train loss: 957.2663


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1482.2386


100%|██████████| 469/469 [00:12<00:00, 38.25it/s]


Epoch: 75 Train loss: 957.0587


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1386.5692


100%|██████████| 469/469 [00:12<00:00, 38.39it/s]


Epoch: 76 Train loss: 955.7079


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1492.9717


100%|██████████| 469/469 [00:12<00:00, 38.07it/s]


Epoch: 77 Train loss: 955.5966


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1372.9685


100%|██████████| 469/469 [00:12<00:00, 37.96it/s]

Epoch: 78 Train loss: 955.0965



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1590.2822


100%|██████████| 469/469 [00:12<00:00, 38.95it/s]

Epoch: 79 Train loss: 954.3845



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1557.2008


100%|██████████| 469/469 [00:07<00:00, 39.22it/s]


Epoch: 80 Train loss: 954.0950


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1385.7250


100%|██████████| 469/469 [00:12<00:00, 39.23it/s]


Epoch: 81 Train loss: 952.6962


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1727.8306


100%|██████████| 469/469 [00:11<00:00, 40.14it/s]


Epoch: 82 Train loss: 953.4655


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1313.8521


100%|██████████| 469/469 [00:12<00:00, 38.26it/s]

Epoch: 83 Train loss: 953.2325



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2301.3184


100%|██████████| 469/469 [00:12<00:00, 38.36it/s]

Epoch: 84 Train loss: 952.3165



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1368.8662


100%|██████████| 469/469 [00:12<00:00, 37.58it/s]

Epoch: 85 Train loss: 951.4161



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1479.1102


100%|██████████| 469/469 [00:12<00:00, 38.53it/s]

Epoch: 86 Train loss: 950.4781



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1405.6063


100%|██████████| 469/469 [00:12<00:00, 38.43it/s]

Epoch: 87 Train loss: 950.2593



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1775.1532


100%|██████████| 469/469 [00:12<00:00, 37.85it/s]

Epoch: 88 Train loss: 949.6894



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1595.4094


100%|██████████| 469/469 [00:12<00:00, 40.15it/s]


Epoch: 89 Train loss: 951.2747


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1352.5880


100%|██████████| 469/469 [00:12<00:00, 38.46it/s]


Epoch: 90 Train loss: 949.6974


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1557.4431


100%|██████████| 469/469 [00:12<00:00, 38.49it/s]

Epoch: 91 Train loss: 948.8833



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1360.2599


100%|██████████| 469/469 [00:12<00:00, 38.58it/s]

Epoch: 92 Train loss: 948.6522



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1356.9211


100%|██████████| 469/469 [00:12<00:00, 37.57it/s]

Epoch: 93 Train loss: 948.9141



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1245.1058


100%|██████████| 469/469 [00:12<00:00, 39.66it/s]


Epoch: 94 Train loss: 947.8754


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2274.6111


100%|██████████| 469/469 [00:12<00:00, 36.26it/s]


Epoch: 95 Train loss: 946.5869


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1458.9440


100%|██████████| 469/469 [00:12<00:00, 38.15it/s]

Epoch: 96 Train loss: 946.7527



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1597.2271


100%|██████████| 469/469 [00:12<00:00, 38.52it/s]

Epoch: 97 Train loss: 946.8162



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1487.8665


100%|██████████| 469/469 [00:12<00:00, 38.29it/s]

Epoch: 98 Train loss: 945.7451



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1362.2968


100%|██████████| 469/469 [00:12<00:00, 38.65it/s]

Epoch: 99 Train loss: 945.7551



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1527.3599


100%|██████████| 469/469 [00:12<00:00, 38.51it/s]


Epoch: 100 Train loss: 945.5173
Test loss: 1348.1760
