In [1]:
from __future__ import print_function

import torch
import torch.utils.data
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
from tqdm import tqdm

# Training hyper-parameters used by the data loaders and the epoch loop.
batch_size = 64
epochs = 100

# Seed PyTorch's RNG so weight initialisation and sampling are repeatable.
seed = 1
torch.manual_seed(seed)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
from Tars.distributions import Deterministic, DataDistribution
from Tars.distributions import Normal
from Tars.models import GAN

In [3]:
# Options shared by both loaders. Pinned (page-locked) host memory only speeds
# up host-to-GPU copies, so it is requested only when running on CUDA; asking
# for it on a CPU-only machine is useless and makes PyTorch emit a warning.
kwargs = {'num_workers': 1, 'pin_memory': device == "cuda"}

# MNIST training split, downloaded into ../data on first use, as [0, 1] tensors.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

# MNIST test split read from the same root directory.
# NOTE(review): shuffling is not needed for evaluation; it is kept so that the
# batch order seen by downstream cells is unchanged.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
# Size of a flattened 28 x 28 image, and size of the latent vector z.
x_dim = 28 * 28
z_dim = 100

# generator model p(x|z)
class Generator(Deterministic):
    """Fully connected generator that maps a latent ``z`` to an output ``x``.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2) stages
    with widths 128, 256, 512 and 1024 (the first stage has no batch norm),
    followed by a Linear layer to ``x_dim`` units and a Sigmoid.
    """

    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        hidden_sizes = [128, 256, 512, 1024]
        layers = []
        in_features = z_dim
        for stage, out_features in enumerate(hidden_sizes):
            layers.append(nn.Linear(in_features, out_features))
            if stage > 0:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, not `momentum`, so this network runs with eps=0.8
                # (the printed architecture shows the same). Spelled out as a
                # keyword here; confirm whether momentum was intended before
                # changing the value.
                layers.append(nn.BatchNorm1d(out_features, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = out_features

        # Output head: project to x_dim units and squash into (0, 1).
        layers.append(nn.Linear(in_features, x_dim))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run ``z`` through the network and return the result under key "x"."""
        return {"x": self.model(z)}
    
    
# prior model p(z): Normal with location 0 and scale 1, declared over "z"
# with dim=z_dim. The parameters are created directly on the target device.
loc = torch.tensor(0., device=device)
scale = torch.tensor(1., device=device)
prior = Normal(
    loc=loc,
    scale=scale,
    var=["z"],
    dim=z_dim,
    name="p_prior",
)

In [5]:
# Generative model: combine the generator p(x|z) with the prior over z, then
# marginalise z out so that the resulting distribution is over x only.
p_g = Generator()
p = (p_g * prior).marginalize_var("z")
p.to(device)

# Data distribution over the observed variable x.
p_data = DataDistribution(["x"])
p_data.to(device)

# Show both distributions, one after the other.
print(p, p_data, sep="\n")

Distribution:
  p(x) = ∫p(x|z)p_prior(z)dz
Network architecture:
  p_prior(z) (Normal): Normal()
  p(x|z) (Deterministic): Generator(
    (model): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace)
      (5): Linear(in_features=256, out_features=512, bias=True)
      (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace)
      (8): Linear(in_features=512, out_features=1024, bias=True)
      (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace)
      (11): Linear(in_features=1024, out_features=784, bias=True)
      (12): Sigmoid()
    )
  )
Distribution:


In [6]:
# discriminator model d(t|x)
class Discriminator(Deterministic):
    """Fully connected discriminator that maps an input ``x`` to a score ``t``.

    Two Linear -> LeakyReLU(0.2) stages (512 and 256 units) are followed by a
    single-unit Linear layer and a Sigmoid, so ``t`` lies in (0, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")

        widths = [x_dim, 512, 256]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Output head: one unit squashed into (0, 1).
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network and return the result under key "t"."""
        return {"t": self.model(x)}
    
# Instantiate the discriminator and move it to the selected device.
d = Discriminator()
d.to(device)

# Show the distribution summary and its network architecture.
print(d)

Distribution:
  d(t|x) (Deterministic)
Network architecture:
  Discriminator(
    (model): Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace)
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Sigmoid()
    )
  )


In [7]:
# Build the GAN from the data distribution, the generative model and the
# discriminator. Both optimisers are Adam with the same learning rate; each
# one gets its own parameter dict.
model = GAN(
    p_data, p, d,
    optimizer=optim.Adam,
    optimizer_params={"lr": 0.0002},
    d_optimizer=optim.Adam,
    d_optimizer_params={"lr": 0.0002},
)
print(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(mean(AdversarialJSLoss[p_data(x)||p(x)]))


In [8]:
def train(epoch):
    """Train ``model`` for one pass over ``train_loader`` and report mean losses.

    Each batch is moved to ``device``, flattened to ``(batch, features)`` and
    passed to ``model.train`` under the key "x". The two values it returns are
    averaged per sample over the whole epoch.

    Args:
        epoch: Epoch number; only used in the printed summary line.

    Returns:
        The epoch mean of the first value returned by ``model.train``, as a
        detached 0-dim tensor (callers read it with ``.item()``).
    """
    train_loss = 0
    train_d_loss = 0
    n_seen = 0
    for data, _ in tqdm(train_loader):
        data = data.to(device)
        n_samples = data.size(0)
        loss, d_loss = model.train({"x": data.view(n_samples, -1)})
        # Detach before accumulating so the running sums do not keep every
        # batch's autograd history alive for the whole epoch.
        # Weight by the real batch size: the last batch may be smaller than
        # `batch_size`, and the losses are assumed to be per-batch means (the
        # printed loss function is a mean).
        train_loss += loss.detach() * n_samples
        train_d_loss += d_loss.detach() * n_samples
        n_seen += n_samples

    train_loss = train_loss / n_seen
    train_d_loss = train_d_loss / n_seen
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [9]:
def test(epoch):
    """Evaluate ``model`` on one pass over ``test_loader`` and report mean losses.

    Each batch is moved to ``device``, flattened to ``(batch, features)`` and
    passed to ``model.test`` under the key "x". The two values it returns are
    averaged per sample over the whole pass.

    Args:
        epoch: Epoch number; accepted for symmetry with ``train`` and not used.

    Returns:
        The mean of the first value returned by ``model.test``, as a detached
        0-dim tensor (callers read it with ``.item()``).
    """
    test_loss = 0
    test_d_loss = 0
    n_seen = 0
    for data, _ in test_loader:
        data = data.to(device)
        n_samples = data.size(0)
        loss, d_loss = model.test({"x": data.view(n_samples, -1)})
        # Weight by the real batch size so a short final batch is not
        # over-counted; losses are assumed to be per-batch means.
        test_loss += loss.detach() * n_samples
        test_d_loss += d_loss.detach() * n_samples
        n_seen += n_samples

    test_loss = test_loss / n_seen
    test_d_loss = test_d_loss / n_seen

    # Both values are read with .item(), matching the train() summary line.
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss.item(), test_d_loss.item()))
    return test_loss

In [10]:
def plot_image_from_latent(z_sample):
    """Generate images from latent vectors with the generator ``p_g``.

    Despite the name, nothing is drawn here: the generated batch is returned
    for the caller to log or display.

    Args:
        z_sample: Latent batch passed to ``p_g.sample`` under the key "z".

    Returns:
        The generated "x" reshaped to ``(-1, 1, 28, 28)`` and moved to the CPU.
    """
    # Sampling is inference only, so no autograd graph is recorded.
    with torch.no_grad():
        generated = p_g.sample({"z": z_sample})
        images = generated["x"].view(-1, 1, 28, 28).cpu()
    return images

In [11]:
# TensorBoard event writer for the scalar curves and generated-image grids.
writer = SummaryWriter()

# One fixed latent batch, reused every epoch so the logged samples are
# comparable over time. Drawn on the CPU and then moved, so the values do not
# depend on which device is in use.
z_sample = torch.randn(64, z_dim).to(device)

# Take one batch of real test images and labels. The builtin next() works on
# any iterator; the `.next()` method is a Python 2 idiom that current
# DataLoader iterators do not provide.
x_original, y_original = next(iter(test_loader))
x_original = x_original.to(device)
y_original = y_original.to(device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)

        sample = plot_image_from_latent(z_sample)

        writer.add_scalar('train_loss', train_loss.item(), epoch)
        writer.add_scalar('test_loss', test_loss.item(), epoch)

        # add_image expects a single image, so tile the (N, 1, 28, 28) batch
        # into one grid image first.
        writer.add_image('Image_from_latent', make_grid(sample), epoch)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

100%|██████████| 938/938 [00:10<00:00, 93.53it/s]

Epoch: 1 Train loss: 5.6943, 0.1294



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.6846, 0.0571


100%|██████████| 938/938 [00:10<00:00, 91.67it/s]

Epoch: 2 Train loss: 6.0602, 0.0316



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.3000, 0.0386


100%|██████████| 938/938 [00:10<00:00, 86.29it/s]

Epoch: 3 Train loss: 7.2537, 0.0333



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.7635, 0.0610


100%|██████████| 938/938 [00:10<00:00, 85.82it/s]


Epoch: 4 Train loss: 7.5229, 0.0285


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 6.4531, 0.0342


100%|██████████| 938/938 [00:10<00:00, 92.45it/s]


Epoch: 5 Train loss: 7.8498, 0.0209


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 6.2918, 0.0255


100%|██████████| 938/938 [00:10<00:00, 91.48it/s]

Epoch: 6 Train loss: 8.5414, 0.0378



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 6.1425, 0.0450


100%|██████████| 938/938 [00:10<00:00, 88.47it/s]


Epoch: 7 Train loss: 7.7785, 0.0391


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 6.9806, 0.0137


100%|██████████| 938/938 [00:10<00:00, 87.27it/s]


Epoch: 8 Train loss: 7.6469, 0.0428


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.6198, 0.0412


100%|██████████| 938/938 [00:10<00:00, 88.21it/s]

Epoch: 9 Train loss: 7.6303, 0.0555



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.5796, 0.0512


100%|██████████| 938/938 [00:10<00:00, 88.47it/s]


Epoch: 10 Train loss: 7.3573, 0.0567


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 7.4477, 0.0497


100%|██████████| 938/938 [00:10<00:00, 90.10it/s]


Epoch: 11 Train loss: 7.5713, 0.0576


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 7.0544, 0.0443


100%|██████████| 938/938 [00:10<00:00, 87.11it/s]

Epoch: 12 Train loss: 6.9749, 0.0674



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.8631, 0.0625


100%|██████████| 938/938 [00:10<00:00, 87.45it/s]


Epoch: 13 Train loss: 7.5476, 0.0529


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 9.5732, 0.0745


100%|██████████| 938/938 [00:10<00:00, 89.27it/s]

Epoch: 14 Train loss: 7.7243, 0.0576



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.4562, 0.0513


100%|██████████| 938/938 [00:10<00:00, 88.37it/s]

Epoch: 15 Train loss: 7.8234, 0.0518



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 7.9974, 0.0508


100%|██████████| 938/938 [00:10<00:00, 90.15it/s]


Epoch: 16 Train loss: 7.3237, 0.0707


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 6.1275, 0.0575


100%|██████████| 938/938 [00:10<00:00, 89.59it/s]

Epoch: 17 Train loss: 6.6886, 0.0879



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.2862, 0.1004


100%|██████████| 938/938 [00:10<00:00, 85.71it/s]

Epoch: 18 Train loss: 5.9146, 0.1275



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 5.3511, 0.0968


100%|██████████| 938/938 [00:10<00:00, 92.17it/s]


Epoch: 19 Train loss: 5.2539, 0.1594


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.9115, 0.3250


100%|██████████| 938/938 [00:10<00:00, 89.48it/s]


Epoch: 20 Train loss: 4.9183, 0.1693


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.6865, 0.2504


100%|██████████| 938/938 [00:10<00:00, 86.14it/s]


Epoch: 21 Train loss: 4.4346, 0.2160


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.4838, 0.2124


100%|██████████| 938/938 [00:10<00:00, 85.94it/s]


Epoch: 22 Train loss: 4.2050, 0.2373


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.7069, 0.2281


100%|██████████| 938/938 [00:10<00:00, 88.51it/s]


Epoch: 23 Train loss: 4.2958, 0.2295


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 4.3092, 0.2727


100%|██████████| 938/938 [00:10<00:00, 86.40it/s]

Epoch: 24 Train loss: 3.8623, 0.2606



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.3282, 0.2800


100%|██████████| 938/938 [00:10<00:00, 87.60it/s]


Epoch: 25 Train loss: 3.6059, 0.2790


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.4032, 0.3181


100%|██████████| 938/938 [00:10<00:00, 86.81it/s]


Epoch: 26 Train loss: 3.5086, 0.2769


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.6333, 0.2683


100%|██████████| 938/938 [00:10<00:00, 87.71it/s]

Epoch: 27 Train loss: 3.6758, 0.2701



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.9564, 0.3142


100%|██████████| 938/938 [00:10<00:00, 91.37it/s]


Epoch: 28 Train loss: 3.3009, 0.3023


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.0251, 0.3240


100%|██████████| 938/938 [00:10<00:00, 88.00it/s]

Epoch: 29 Train loss: 3.1707, 0.3073



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.9076, 0.2718


100%|██████████| 938/938 [00:10<00:00, 88.01it/s]

Epoch: 30 Train loss: 3.2150, 0.3007



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.1819, 0.2905


100%|██████████| 938/938 [00:10<00:00, 90.12it/s]


Epoch: 31 Train loss: 3.2548, 0.3209


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.6954, 0.3291


100%|██████████| 938/938 [00:10<00:00, 87.56it/s]

Epoch: 32 Train loss: 2.9361, 0.3439



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.6709, 0.3479


100%|██████████| 938/938 [00:10<00:00, 86.34it/s]

Epoch: 33 Train loss: 2.8907, 0.3425



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3013, 0.4087


100%|██████████| 938/938 [00:10<00:00, 88.36it/s]


Epoch: 34 Train loss: 2.9162, 0.3428


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1703, 0.3844


100%|██████████| 938/938 [00:10<00:00, 88.78it/s]


Epoch: 35 Train loss: 2.8683, 0.3406


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0000, 0.4227


100%|██████████| 938/938 [00:10<00:00, 87.87it/s]


Epoch: 36 Train loss: 2.8067, 0.3509


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2323, 0.4666


100%|██████████| 938/938 [00:10<00:00, 87.37it/s]

Epoch: 37 Train loss: 2.8387, 0.3583



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 3.0491, 0.3574


100%|██████████| 938/938 [00:10<00:00, 91.44it/s]


Epoch: 38 Train loss: 2.6555, 0.3734


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1620, 0.4291


100%|██████████| 938/938 [00:10<00:00, 86.38it/s]


Epoch: 39 Train loss: 2.6562, 0.3844


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.4045, 0.4328


100%|██████████| 938/938 [00:10<00:00, 89.97it/s]


Epoch: 40 Train loss: 2.6021, 0.3966


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3080, 0.4034


100%|██████████| 938/938 [00:10<00:00, 87.85it/s]

Epoch: 41 Train loss: 2.6139, 0.3865



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2875, 0.3774


100%|██████████| 938/938 [00:10<00:00, 88.63it/s]


Epoch: 42 Train loss: 2.6012, 0.3991


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.4205, 0.3954


100%|██████████| 938/938 [00:10<00:00, 89.15it/s]

Epoch: 43 Train loss: 2.4621, 0.4176



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.7444, 0.4935


100%|██████████| 938/938 [00:09<00:00, 94.57it/s]


Epoch: 44 Train loss: 2.4326, 0.4302


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.6585, 0.4561


100%|██████████| 938/938 [00:10<00:00, 91.15it/s]


Epoch: 45 Train loss: 2.3984, 0.4312


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.5216, 0.4368


100%|██████████| 938/938 [00:10<00:00, 92.81it/s]

Epoch: 46 Train loss: 2.3928, 0.4389



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2335, 0.4906


100%|██████████| 938/938 [00:10<00:00, 91.23it/s]


Epoch: 47 Train loss: 2.3974, 0.4429


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2445, 0.4235


100%|██████████| 938/938 [00:09<00:00, 95.47it/s]


Epoch: 48 Train loss: 2.3720, 0.4415


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1912, 0.4581


100%|██████████| 938/938 [00:10<00:00, 88.70it/s]

Epoch: 49 Train loss: 2.3634, 0.4483



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.4202, 0.4438


100%|██████████| 938/938 [00:10<00:00, 86.30it/s]

Epoch: 50 Train loss: 2.3715, 0.4514



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3270, 0.4406


100%|██████████| 938/938 [00:10<00:00, 90.10it/s]


Epoch: 51 Train loss: 2.3210, 0.4659


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9439, 0.5236


100%|██████████| 938/938 [00:10<00:00, 88.74it/s]

Epoch: 52 Train loss: 2.3219, 0.4662



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3312, 0.5203


100%|██████████| 938/938 [00:10<00:00, 89.66it/s]


Epoch: 53 Train loss: 2.2695, 0.4799


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3660, 0.4859


100%|██████████| 938/938 [00:10<00:00, 90.73it/s]


Epoch: 54 Train loss: 2.2974, 0.4824


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8367, 0.5243


100%|██████████| 938/938 [00:10<00:00, 91.34it/s]


Epoch: 55 Train loss: 2.2826, 0.4808


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1331, 0.4906


100%|██████████| 938/938 [00:10<00:00, 85.69it/s]

Epoch: 56 Train loss: 2.2162, 0.5049



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8869, 0.5964


100%|██████████| 938/938 [00:10<00:00, 89.55it/s]

Epoch: 57 Train loss: 2.2211, 0.5073



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1162, 0.4999


100%|██████████| 938/938 [00:10<00:00, 87.50it/s]

Epoch: 58 Train loss: 2.2112, 0.5149



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1307, 0.5385


100%|██████████| 938/938 [00:10<00:00, 88.85it/s]


Epoch: 59 Train loss: 2.1798, 0.5151


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2865, 0.5748


100%|██████████| 938/938 [00:10<00:00, 87.84it/s]

Epoch: 60 Train loss: 2.1652, 0.5364



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.3050, 0.6133


100%|██████████| 938/938 [00:10<00:00, 90.30it/s]

Epoch: 61 Train loss: 2.1624, 0.5297



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9108, 0.5677


100%|██████████| 938/938 [00:10<00:00, 89.36it/s]

Epoch: 62 Train loss: 2.1581, 0.5340



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1714, 0.5994


100%|██████████| 938/938 [00:10<00:00, 85.44it/s]

Epoch: 63 Train loss: 2.1512, 0.5354



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.4440, 0.6085


100%|██████████| 938/938 [00:10<00:00, 88.79it/s]

Epoch: 64 Train loss: 2.1282, 0.5462



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1064, 0.5828


100%|██████████| 938/938 [00:10<00:00, 86.50it/s]


Epoch: 65 Train loss: 2.1500, 0.5395


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0011, 0.5878


100%|██████████| 938/938 [00:10<00:00, 86.88it/s]

Epoch: 66 Train loss: 2.1457, 0.5379



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9861, 0.6681


100%|██████████| 938/938 [00:10<00:00, 91.50it/s]

Epoch: 67 Train loss: 2.1073, 0.5587



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2502, 0.5706


100%|██████████| 938/938 [00:10<00:00, 90.19it/s]

Epoch: 68 Train loss: 2.1201, 0.5519



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0225, 0.5677


100%|██████████| 938/938 [00:10<00:00, 91.01it/s]

Epoch: 69 Train loss: 2.1116, 0.5608



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.6723, 0.6772


100%|██████████| 938/938 [00:10<00:00, 86.69it/s]


Epoch: 70 Train loss: 2.0683, 0.5755


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8415, 0.6212


100%|██████████| 938/938 [00:10<00:00, 90.48it/s]


Epoch: 71 Train loss: 2.0717, 0.5641


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9805, 0.6528


100%|██████████| 938/938 [00:10<00:00, 85.75it/s]


Epoch: 72 Train loss: 2.0472, 0.5766


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0743, 0.6666


100%|██████████| 938/938 [00:10<00:00, 85.77it/s]


Epoch: 73 Train loss: 2.1001, 0.5714


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2202, 0.6155


100%|██████████| 938/938 [00:10<00:00, 86.97it/s]


Epoch: 74 Train loss: 2.0938, 0.5671


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.7363, 0.6332


100%|██████████| 938/938 [00:10<00:00, 89.18it/s]


Epoch: 75 Train loss: 2.0594, 0.5698


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8646, 0.6503


100%|██████████| 938/938 [00:10<00:00, 87.69it/s]


Epoch: 76 Train loss: 2.0372, 0.5892


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8371, 0.6538


100%|██████████| 938/938 [00:10<00:00, 87.88it/s]

Epoch: 77 Train loss: 2.0673, 0.5889



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9349, 0.6439


100%|██████████| 938/938 [00:10<00:00, 90.40it/s]

Epoch: 78 Train loss: 2.0542, 0.5863



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1714, 0.6218


100%|██████████| 938/938 [00:10<00:00, 89.09it/s]

Epoch: 79 Train loss: 2.0313, 0.5876



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9326, 0.6266


100%|██████████| 938/938 [00:10<00:00, 87.43it/s]

Epoch: 80 Train loss: 2.0471, 0.5882



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.6722, 0.7075


100%|██████████| 938/938 [00:10<00:00, 87.36it/s]


Epoch: 81 Train loss: 2.0324, 0.5930


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9686, 0.6852


100%|██████████| 938/938 [00:10<00:00, 89.47it/s]

Epoch: 82 Train loss: 2.0280, 0.5983



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.2242, 0.7205


100%|██████████| 938/938 [00:10<00:00, 88.88it/s]

Epoch: 83 Train loss: 2.0303, 0.5973



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8775, 0.6895


100%|██████████| 938/938 [00:10<00:00, 90.47it/s]


Epoch: 84 Train loss: 2.0591, 0.5828


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0708, 0.6816


100%|██████████| 938/938 [00:10<00:00, 88.06it/s]


Epoch: 85 Train loss: 2.0186, 0.5947


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.1972, 0.6842


100%|██████████| 938/938 [00:10<00:00, 89.34it/s]


Epoch: 86 Train loss: 2.0278, 0.5944


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0698, 0.7015


100%|██████████| 938/938 [00:09<00:00, 95.88it/s]


Epoch: 87 Train loss: 2.0408, 0.5907


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9480, 0.6382


100%|██████████| 938/938 [00:10<00:00, 92.33it/s]


Epoch: 88 Train loss: 2.0349, 0.5968


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0149, 0.6805


100%|██████████| 938/938 [00:10<00:00, 87.36it/s]

Epoch: 89 Train loss: 2.0156, 0.6021



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0140, 0.6864


100%|██████████| 938/938 [00:10<00:00, 90.72it/s]

Epoch: 90 Train loss: 2.0216, 0.6023



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0403, 0.6706


100%|██████████| 938/938 [00:10<00:00, 87.86it/s]

Epoch: 91 Train loss: 2.0494, 0.5872



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.7591, 0.6531


100%|██████████| 938/938 [00:10<00:00, 86.00it/s]


Epoch: 92 Train loss: 2.0337, 0.5912


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9150, 0.6305


100%|██████████| 938/938 [00:09<00:00, 95.24it/s]


Epoch: 93 Train loss: 2.0724, 0.5885


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0522, 0.6751


100%|██████████| 938/938 [00:10<00:00, 87.18it/s]


Epoch: 94 Train loss: 2.0123, 0.5939


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 2.0578, 0.6974


100%|██████████| 938/938 [00:10<00:00, 87.14it/s]

Epoch: 95 Train loss: 2.0565, 0.5956



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9611, 0.6719


100%|██████████| 938/938 [00:11<00:00, 84.14it/s]


Epoch: 96 Train loss: 2.0594, 0.5853


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9946, 0.6553


100%|██████████| 938/938 [00:10<00:00, 88.15it/s]

Epoch: 97 Train loss: 2.0322, 0.5938



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.9234, 0.6635


100%|██████████| 938/938 [00:10<00:00, 91.54it/s]


Epoch: 98 Train loss: 2.0384, 0.5950


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.8343, 0.7663


100%|██████████| 938/938 [00:10<00:00, 85.33it/s]


Epoch: 99 Train loss: 2.0217, 0.6005


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 1.6456, 0.7277


100%|██████████| 938/938 [00:10<00:00, 90.78it/s]


Epoch: 100 Train loss: 2.0238, 0.6004
Test loss: 1.9871, 0.7018
