# Real NVP

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, InverseTransformedDistribution
from pixyz.flows import AffineCouplingLayer, FlowList, BatchNorm1d, ShuffleLayer, PreprocessLayer, ReverseLayer
from pixyz.models import ML

In [4]:
x_dim = 28*28
z_dim = x_dim

In [5]:
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [6]:
class ScaleTranslateNet(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3_s = nn.Linear(hidden_features, in_features)
        self.fc3_t = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        log_s = torch.tanh(self.fc3_s(hidden))
        t = self.fc3_t(hidden)
        return log_s, t

In [7]:
# flow
flow_list = []
num_block = 5

flow_list.append(PreprocessLayer())

for i in range(num_block):
        flow_list.append(AffineCouplingLayer(in_channels=x_dim,
                                             scale_translate_net=ScaleTranslateNet(x_dim, 1028),
                                             inverse_mask=(i%2!=0)))
        
        flow_list.append(BatchNorm1d(x_dim))
        
f = FlowList(flow_list)

In [8]:
# inverse transformed distribution (z -> f^-1 -> x)
p = InverseTransformedDistribution(prior=prior, flow=f, var=["x"])
p.to(device)

InverseTransformedDistribution(
  (prior): Normal()
  (flow): FlowList(
    (0): PreprocessLayer()
    (1): AffineCouplingLayer(
      in_features=784, mask_type=channel_wise, inverse_mask=False
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=784, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
    (2): BatchNorm1d()
    (3): AffineCouplingLayer(
      in_features=784, mask_type=channel_wise, inverse_mask=True
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=784, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
 

In [9]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(-(log p(x))) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [10]:
def train(epoch):
    train_loss = 0
    
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss

    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [12]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = p.forward(x, compute_jacobian=False)
        recon_batch = p.inverse(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.inverse(z_sample).view(-1, 1, 28, 28).cpu()
        return sample

In [13]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:12<00:00, 36.38it/s]

Epoch: 1 Train loss: 1445.3479



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1878.3894


100%|██████████| 469/469 [00:11<00:00, 40.10it/s]


Epoch: 2 Train loss: 1254.6447


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1600.0464


100%|██████████| 469/469 [00:13<00:00, 34.87it/s]

Epoch: 3 Train loss: 1196.9751



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1533.6735


100%|██████████| 469/469 [00:13<00:00, 34.18it/s]

Epoch: 4 Train loss: 1163.2500



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1448.3883


100%|██████████| 469/469 [00:11<00:00, 39.21it/s]


Epoch: 5 Train loss: 1139.7772


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1336.7216


100%|██████████| 469/469 [00:13<00:00, 34.00it/s]

Epoch: 6 Train loss: 1124.7460



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1412.4645


100%|██████████| 469/469 [00:13<00:00, 34.20it/s]

Epoch: 7 Train loss: 1110.5399



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1574.0319


100%|██████████| 469/469 [00:13<00:00, 33.91it/s]

Epoch: 8 Train loss: 1099.4238



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1327.1321


100%|██████████| 469/469 [00:13<00:00, 33.95it/s]


Epoch: 9 Train loss: 1089.2649


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1366.3943


100%|██████████| 469/469 [00:12<00:00, 36.73it/s]

Epoch: 10 Train loss: 1081.8748



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1291.3032


100%|██████████| 469/469 [00:13<00:00, 34.27it/s]

Epoch: 11 Train loss: 1074.2906



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1313.3351


100%|██████████| 469/469 [00:13<00:00, 35.15it/s]

Epoch: 12 Train loss: 1068.7515



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1308.8180


100%|██████████| 469/469 [00:13<00:00, 36.05it/s]

Epoch: 13 Train loss: 1063.7410



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1257.8849


100%|██████████| 469/469 [00:13<00:00, 34.70it/s]


Epoch: 14 Train loss: 1058.1184


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1243.9095


100%|██████████| 469/469 [00:11<00:00, 39.81it/s]

Epoch: 15 Train loss: 1054.0980



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1261.7003


100%|██████████| 469/469 [00:12<00:00, 36.13it/s]


Epoch: 16 Train loss: 1049.1976


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1417.4336


100%|██████████| 469/469 [00:13<00:00, 34.90it/s]

Epoch: 17 Train loss: 1045.9222



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1344.7759


100%|██████████| 469/469 [00:12<00:00, 37.11it/s]

Epoch: 18 Train loss: 1041.7664



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1279.8588


100%|██████████| 469/469 [00:12<00:00, 37.30it/s]

Epoch: 19 Train loss: 1038.5897



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1287.4314


100%|██████████| 469/469 [00:13<00:00, 34.94it/s]

Epoch: 20 Train loss: 1035.6672



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1297.0714


100%|██████████| 469/469 [00:12<00:00, 36.27it/s]


Epoch: 21 Train loss: 1033.2083


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1262.0453


100%|██████████| 469/469 [00:12<00:00, 36.50it/s]

Epoch: 22 Train loss: 1030.2413



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1213.4402


100%|██████████| 469/469 [00:12<00:00, 36.61it/s]

Epoch: 23 Train loss: 1026.8342



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1263.3401


100%|██████████| 469/469 [00:14<00:00, 33.00it/s]

Epoch: 24 Train loss: 1024.5668



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1282.0740


100%|██████████| 469/469 [00:12<00:00, 38.97it/s]

Epoch: 25 Train loss: 1023.0452



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1229.5448


100%|██████████| 469/469 [00:12<00:00, 37.41it/s]


Epoch: 26 Train loss: 1019.0698


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1259.1876


100%|██████████| 469/469 [00:13<00:00, 35.31it/s]

Epoch: 27 Train loss: 1017.9515



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1350.6390


100%|██████████| 469/469 [00:12<00:00, 36.38it/s]

Epoch: 28 Train loss: 1014.5969



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1294.6136


100%|██████████| 469/469 [00:12<00:00, 39.20it/s]


Epoch: 29 Train loss: 1013.5076


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1319.9265


100%|██████████| 469/469 [00:13<00:00, 35.19it/s]


Epoch: 30 Train loss: 1011.3450


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1503.6522


100%|██████████| 469/469 [00:12<00:00, 37.16it/s]


Epoch: 31 Train loss: 1008.6239


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1944.2272


100%|██████████| 469/469 [00:14<00:00, 33.36it/s]

Epoch: 32 Train loss: 1007.8955



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1508.7953


100%|██████████| 469/469 [00:13<00:00, 34.79it/s]

Epoch: 33 Train loss: 1006.2440



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1305.7007


100%|██████████| 469/469 [00:13<00:00, 36.27it/s]


Epoch: 34 Train loss: 1003.8425


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1307.3944


100%|██████████| 469/469 [00:12<00:00, 38.81it/s]


Epoch: 35 Train loss: 1003.3040


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1534.4933


100%|██████████| 469/469 [00:14<00:00, 35.56it/s]


Epoch: 36 Train loss: 1001.8641


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1191.2786


100%|██████████| 469/469 [00:13<00:00, 34.92it/s]

Epoch: 37 Train loss: 1000.1600



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1266.1372


100%|██████████| 469/469 [00:12<00:00, 38.59it/s]


Epoch: 38 Train loss: 998.6024


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1239.5865


100%|██████████| 469/469 [00:13<00:00, 35.93it/s]

Epoch: 39 Train loss: 996.3425



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1699.5035


100%|██████████| 469/469 [00:14<00:00, 35.00it/s]


Epoch: 40 Train loss: 995.4576


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1305.8623


100%|██████████| 469/469 [00:12<00:00, 36.30it/s]


Epoch: 41 Train loss: 994.3260


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1358.6189


100%|██████████| 469/469 [00:12<00:00, 39.73it/s]


Epoch: 42 Train loss: 994.2109


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1258.1831


100%|██████████| 469/469 [00:14<00:00, 33.41it/s]

Epoch: 43 Train loss: 992.1832



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1343.3674


100%|██████████| 469/469 [00:12<00:00, 38.63it/s]

Epoch: 44 Train loss: 991.1606



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1225.7939


100%|██████████| 469/469 [00:13<00:00, 35.58it/s]

Epoch: 45 Train loss: 990.1309



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1383.2706


100%|██████████| 469/469 [00:12<00:00, 36.81it/s]


Epoch: 46 Train loss: 988.4579


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1417.9615


100%|██████████| 469/469 [00:14<00:00, 35.40it/s]


Epoch: 47 Train loss: 987.3669


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1221.9420


100%|██████████| 469/469 [00:13<00:00, 40.65it/s]


Epoch: 48 Train loss: 986.8479


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2125.4482


100%|██████████| 469/469 [00:13<00:00, 34.78it/s]

Epoch: 49 Train loss: 985.4681



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1403.4011


100%|██████████| 469/469 [00:13<00:00, 35.11it/s]


Epoch: 50 Train loss: 984.2176


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1273.0194


100%|██████████| 469/469 [00:13<00:00, 34.98it/s]


Epoch: 51 Train loss: 982.6700


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1465.1897


100%|██████████| 469/469 [00:11<00:00, 39.71it/s]

Epoch: 52 Train loss: 981.8023



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1406.3025


100%|██████████| 469/469 [00:13<00:00, 33.81it/s]


Epoch: 53 Train loss: 981.2679


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1537.4674


100%|██████████| 469/469 [00:13<00:00, 33.61it/s]


Epoch: 54 Train loss: 979.1572


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1373.0255


100%|██████████| 469/469 [00:13<00:00, 35.48it/s]


Epoch: 55 Train loss: 978.8691


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1464.3260


100%|██████████| 469/469 [00:14<00:00, 35.48it/s]


Epoch: 56 Train loss: 978.2333


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1362.4904


100%|██████████| 469/469 [00:13<00:00, 34.88it/s]

Epoch: 57 Train loss: 977.2674



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1478.7488


100%|██████████| 469/469 [00:13<00:00, 35.68it/s]

Epoch: 58 Train loss: 976.7759



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1569.5299


100%|██████████| 469/469 [00:05<00:00, 78.68it/s]

Epoch: 59 Train loss: 976.7393



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1627.7656


100%|██████████| 469/469 [00:13<00:00, 35.75it/s]


Epoch: 60 Train loss: 975.3002


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1338.4237


100%|██████████| 469/469 [00:14<00:00, 35.36it/s]


Epoch: 61 Train loss: 974.0260


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1398.0234


100%|██████████| 469/469 [00:14<00:00, 32.66it/s]


Epoch: 62 Train loss: 972.7854


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1466.3756


100%|██████████| 469/469 [00:13<00:00, 33.98it/s]

Epoch: 63 Train loss: 972.4552



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1293.2844


100%|██████████| 469/469 [00:12<00:00, 36.46it/s]

Epoch: 64 Train loss: 972.1740



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1456.6141


100%|██████████| 469/469 [00:12<00:00, 37.09it/s]

Epoch: 65 Train loss: 971.0432



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1331.9482


100%|██████████| 469/469 [00:12<00:00, 36.84it/s]

Epoch: 66 Train loss: 970.2996



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1636.5682


100%|██████████| 469/469 [00:13<00:00, 35.59it/s]

Epoch: 67 Train loss: 970.1964



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1806.7886


100%|██████████| 469/469 [00:12<00:00, 38.93it/s]

Epoch: 68 Train loss: 968.6438



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1326.6925


100%|██████████| 469/469 [00:13<00:00, 34.94it/s]


Epoch: 69 Train loss: 967.9928


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1420.9402


100%|██████████| 469/469 [00:13<00:00, 34.42it/s]

Epoch: 70 Train loss: 968.4400



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1294.0026


100%|██████████| 469/469 [00:14<00:00, 33.30it/s]


Epoch: 71 Train loss: 967.1929


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1427.5278


100%|██████████| 469/469 [00:13<00:00, 38.58it/s]


Epoch: 72 Train loss: 966.6406


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1628.1786


100%|██████████| 469/469 [00:13<00:00, 33.62it/s]

Epoch: 73 Train loss: 965.4756



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1627.8982


100%|██████████| 469/469 [00:14<00:00, 33.34it/s]


Epoch: 74 Train loss: 966.1814


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1371.2775


100%|██████████| 469/469 [00:13<00:00, 34.96it/s]

Epoch: 75 Train loss: 964.9201



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1394.0204


100%|██████████| 469/469 [00:12<00:00, 39.07it/s]

Epoch: 76 Train loss: 963.9418



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1721.0089


100%|██████████| 469/469 [00:13<00:00, 35.20it/s]


Epoch: 77 Train loss: 964.1438


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1460.1500


100%|██████████| 469/469 [00:13<00:00, 35.04it/s]

Epoch: 78 Train loss: 961.8812



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1519.6156


100%|██████████| 469/469 [00:12<00:00, 36.24it/s]

Epoch: 79 Train loss: 962.1376



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1511.6824


100%|██████████| 469/469 [00:14<00:00, 32.93it/s]

Epoch: 80 Train loss: 962.3834



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1617.4530


100%|██████████| 469/469 [00:14<00:00, 35.34it/s]


Epoch: 81 Train loss: 961.6858


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1459.0834


100%|██████████| 469/469 [00:12<00:00, 37.41it/s]

Epoch: 82 Train loss: 960.9820



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1511.8270


100%|██████████| 469/469 [00:14<00:00, 34.91it/s]


Epoch: 83 Train loss: 960.1096


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1728.6542


100%|██████████| 469/469 [00:11<00:00, 39.81it/s]

Epoch: 84 Train loss: 960.5382



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1483.6656


100%|██████████| 469/469 [00:12<00:00, 36.12it/s]


Epoch: 85 Train loss: 959.5554


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2444.0093


100%|██████████| 469/469 [00:12<00:00, 36.35it/s]

Epoch: 86 Train loss: 958.8950



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1428.3088


100%|██████████| 469/469 [00:13<00:00, 35.61it/s]

Epoch: 87 Train loss: 958.7274



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1519.0203


100%|██████████| 469/469 [00:05<00:00, 82.41it/s]

Epoch: 88 Train loss: 958.2646



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1484.6084


100%|██████████| 469/469 [00:13<00:00, 35.23it/s]


Epoch: 89 Train loss: 957.5470


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1515.8263


100%|██████████| 469/469 [00:14<00:00, 35.48it/s]


Epoch: 90 Train loss: 956.7946


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1754.6790


100%|██████████| 469/469 [00:12<00:00, 37.49it/s]


Epoch: 91 Train loss: 957.0994


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1473.7362


100%|██████████| 469/469 [00:12<00:00, 37.03it/s]

Epoch: 92 Train loss: 955.6685



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1533.4181


100%|██████████| 469/469 [00:14<00:00, 35.25it/s]


Epoch: 93 Train loss: 955.2888


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1583.8872


100%|██████████| 469/469 [00:12<00:00, 37.16it/s]

Epoch: 94 Train loss: 955.5795



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1645.6635


100%|██████████| 469/469 [00:13<00:00, 40.53it/s]


Epoch: 95 Train loss: 955.1779


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1379.9139


100%|██████████| 469/469 [00:13<00:00, 35.31it/s]

Epoch: 96 Train loss: 954.1517



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1785.6316


100%|██████████| 469/469 [00:13<00:00, 34.59it/s]

Epoch: 97 Train loss: 953.8425



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1306.2423


100%|██████████| 469/469 [00:10<00:00, 46.39it/s]

Epoch: 98 Train loss: 953.6864



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1614.8352


100%|██████████| 469/469 [00:13<00:00, 33.72it/s]


Epoch: 99 Train loss: 953.0572


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1712.0837


100%|██████████| 469/469 [00:13<00:00, 34.26it/s]


Epoch: 100 Train loss: 952.6847
Test loss: 1455.8204
