In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import Deterministic, DataDistribution
from Tars.distributions import Normal
from Tars.models import GAN

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
x_dim = 784
z_dim = 64

# generator model p(x|z)    
class Generator(Deterministic):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(512)        

    def forward(self, z):
        h = F.relu(self.bn1(self.fc1(z)))
        h = F.relu(self.bn2(self.fc2(h)))
        return {"x": F.sigmoid(self.fc3(h))}
    
    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [5]:
# generative model
p_g = Generator()
p = (p_g*prior).marginalize_var("z")
p.to(device)

# data distribution
p_data = DataDistribution(["x"])
p_data.to(device)

print(p)
print(p_data)

Distribution:
  p(x) = ∫p(x|z)p_prior(z)dz
Network architecture:
  p_prior(z) (Normal): Normal()
  p(x|z) (Deterministic): Generator(
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
    (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
Distribution:
  p_data(x) (Data distribution)
Network architecture:
  DataDistribution()


In [6]:
# discriminator model p(t|x)
class Discriminator(Deterministic):
    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        return {"t": F.sigmoid(self.fc3(h))}
    
d = Discriminator()
d.to(device)

print(d)

Distribution:
  d(t|x) (Deterministic)
Network architecture:
  Discriminator(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=1, bias=True)
  )


In [7]:
model = GAN(p_data, p, d,
            optimizer=optim.Adam, optimizer_params={"lr":1e-4},
            d_optimizer=optim.Adam, d_optimizer_params={"lr":1e-4})
print(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(GANLoss[p_data(x)||p(x)])


In [8]:
def train(epoch):
    train_loss = 0
    train_d_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        loss, d_loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
        train_d_loss += d_loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    train_d_loss = train_d_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [9]:
def test(epoch):
    test_loss = 0
    test_d_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        loss, d_loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss
        test_d_loss += d_loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    test_d_loss = test_d_loss * test_loader.batch_size / len(test_loader.dataset)
    
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss, test_d_loss.item()))
    return test_loss

In [10]:
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p_g.sample({"z": z_sample})["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [11]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = y_original.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_image('Image_from_latent', sample, epoch)
    
writer.close()

100%|██████████| 469/469 [00:05<00:00, 85.21it/s]

Epoch: 1 Train loss: 4.8727, 0.1210



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.7470, 5.7492


100%|██████████| 469/469 [00:05<00:00, 84.07it/s]


Epoch: 2 Train loss: 7.6702, 0.0255


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.5574, 7.5432


100%|██████████| 469/469 [00:05<00:00, 79.86it/s]

Epoch: 3 Train loss: 6.3541, 0.0138



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.1354, 7.1500


100%|██████████| 469/469 [00:05<00:00, 80.23it/s]

Epoch: 4 Train loss: 7.3728, 0.0089



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.5708, 7.5634


100%|██████████| 469/469 [00:05<00:00, 81.75it/s]

Epoch: 5 Train loss: 7.7578, 0.0084



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.5364, 7.5415


100%|██████████| 469/469 [00:05<00:00, 83.45it/s]


Epoch: 6 Train loss: 7.8456, 0.0054


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 9.1074, 9.1359


100%|██████████| 469/469 [00:05<00:00, 85.06it/s]


Epoch: 7 Train loss: 8.4428, 0.0057


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.2680, 7.2729


100%|██████████| 469/469 [00:05<00:00, 83.32it/s]


Epoch: 8 Train loss: 8.2219, 0.0042


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.5503, 8.5811


100%|██████████| 469/469 [00:06<00:00, 76.01it/s]

Epoch: 9 Train loss: 8.3394, 0.0073



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.3759, 8.3424


100%|██████████| 469/469 [00:06<00:00, 77.06it/s]

Epoch: 10 Train loss: 8.4487, 0.0061



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 10.2424, 10.2360


100%|██████████| 469/469 [00:06<00:00, 77.66it/s]

Epoch: 11 Train loss: 8.8904, 0.0058



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 9.8673, 9.8885


100%|██████████| 469/469 [00:05<00:00, 79.26it/s]

Epoch: 12 Train loss: 8.8049, 0.0089



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.5957, 7.5931


100%|██████████| 469/469 [00:05<00:00, 86.00it/s]

Epoch: 13 Train loss: 9.0651, 0.0112



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 11.1487, 11.1022


100%|██████████| 469/469 [00:05<00:00, 81.52it/s]

Epoch: 14 Train loss: 9.2535, 0.0137



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 10.4025, 10.3933


100%|██████████| 469/469 [00:06<00:00, 74.65it/s]


Epoch: 15 Train loss: 9.5846, 0.0140


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 9.1895, 9.2996


100%|██████████| 469/469 [00:05<00:00, 80.06it/s]

Epoch: 16 Train loss: 8.5857, 0.0177



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.3978, 7.4027


100%|██████████| 469/469 [00:05<00:00, 82.14it/s]

Epoch: 17 Train loss: 9.0401, 0.0180



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.6889, 8.7161


100%|██████████| 469/469 [00:05<00:00, 83.12it/s]


Epoch: 18 Train loss: 8.4241, 0.0320


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.3754, 7.4062


100%|██████████| 469/469 [00:05<00:00, 79.32it/s]

Epoch: 19 Train loss: 7.9453, 0.0226



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 9.4182, 9.4508


100%|██████████| 469/469 [00:05<00:00, 81.77it/s]

Epoch: 20 Train loss: 8.7687, 0.0277



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.8975, 8.9380


100%|██████████| 469/469 [00:05<00:00, 80.17it/s]

Epoch: 21 Train loss: 8.5050, 0.0291



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 9.5009, 9.4528


100%|██████████| 469/469 [00:05<00:00, 83.47it/s]

Epoch: 22 Train loss: 8.5142, 0.0321



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.9490, 8.8932


100%|██████████| 469/469 [00:05<00:00, 82.01it/s]

Epoch: 23 Train loss: 7.7413, 0.0385



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.9812, 8.0319


100%|██████████| 469/469 [00:05<00:00, 82.33it/s]

Epoch: 24 Train loss: 7.5998, 0.0414



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.3966, 7.3751


100%|██████████| 469/469 [00:05<00:00, 84.47it/s]

Epoch: 25 Train loss: 7.5826, 0.0529



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 8.7781, 8.7937


100%|██████████| 469/469 [00:05<00:00, 81.36it/s]

Epoch: 26 Train loss: 7.3201, 0.0447



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.2429, 7.2936


100%|██████████| 469/469 [00:06<00:00, 77.72it/s]


Epoch: 27 Train loss: 7.0561, 0.0481


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 6.0800, 6.0625


100%|██████████| 469/469 [00:05<00:00, 79.45it/s]

Epoch: 28 Train loss: 6.6558, 0.0527



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.3997, 7.4027


100%|██████████| 469/469 [00:05<00:00, 80.64it/s]

Epoch: 29 Train loss: 6.6968, 0.0605



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 7.1107, 7.0877


100%|██████████| 469/469 [00:05<00:00, 85.40it/s]

Epoch: 30 Train loss: 6.5725, 0.0732



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.5990, 5.5915


100%|██████████| 469/469 [00:05<00:00, 81.03it/s]


Epoch: 31 Train loss: 6.1547, 0.0855


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 6.0022, 5.9962


100%|██████████| 469/469 [00:05<00:00, 78.28it/s]

Epoch: 32 Train loss: 5.7931, 0.0907



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 6.4214, 6.3804


100%|██████████| 469/469 [00:05<00:00, 80.94it/s]

Epoch: 33 Train loss: 5.7018, 0.0964



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 6.2498, 6.2595


100%|██████████| 469/469 [00:05<00:00, 80.99it/s]


Epoch: 34 Train loss: 5.3713, 0.1065


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.5155, 4.5047


100%|██████████| 469/469 [00:05<00:00, 80.25it/s]

Epoch: 35 Train loss: 5.4398, 0.1208



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.9631, 5.9303


100%|██████████| 469/469 [00:05<00:00, 82.73it/s]

Epoch: 36 Train loss: 5.3510, 0.1228



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.7821, 5.8141


100%|██████████| 469/469 [00:05<00:00, 81.51it/s]

Epoch: 37 Train loss: 5.5150, 0.1332



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.2103, 5.2072


100%|██████████| 469/469 [00:05<00:00, 79.78it/s]

Epoch: 38 Train loss: 5.5081, 0.1027



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 5.0082, 5.0110


100%|██████████| 469/469 [00:05<00:00, 82.40it/s]

Epoch: 39 Train loss: 4.8625, 0.1639



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.2034, 4.2026


100%|██████████| 469/469 [00:05<00:00, 83.94it/s]

Epoch: 40 Train loss: 4.8805, 0.1533



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.4736, 4.4692


100%|██████████| 469/469 [00:06<00:00, 77.19it/s]


Epoch: 41 Train loss: 4.6104, 0.1681


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.8614, 3.8359


100%|██████████| 469/469 [00:06<00:00, 75.51it/s]

Epoch: 42 Train loss: 4.4763, 0.1682



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.4567, 4.4427


100%|██████████| 469/469 [00:05<00:00, 82.77it/s]

Epoch: 43 Train loss: 4.3404, 0.1834



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.8291, 3.8793


100%|██████████| 469/469 [00:05<00:00, 78.25it/s]


Epoch: 44 Train loss: 4.2799, 0.1869


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.0519, 4.0366


100%|██████████| 469/469 [00:05<00:00, 79.00it/s]

Epoch: 45 Train loss: 4.5871, 0.1865



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.2239, 4.2245


100%|██████████| 469/469 [00:05<00:00, 83.70it/s]

Epoch: 46 Train loss: 4.3916, 0.1973



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.1353, 4.1296


100%|██████████| 469/469 [00:05<00:00, 79.29it/s]

Epoch: 47 Train loss: 4.4132, 0.1892



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.0659, 4.0442


100%|██████████| 469/469 [00:05<00:00, 81.57it/s]

Epoch: 48 Train loss: 4.0725, 0.2284



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.9412, 3.9161


100%|██████████| 469/469 [00:05<00:00, 81.96it/s]


Epoch: 49 Train loss: 3.8197, 0.2365


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.7970, 3.7985


100%|██████████| 469/469 [00:05<00:00, 83.43it/s]

Epoch: 50 Train loss: 4.1739, 0.1922



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.0446, 4.0377


100%|██████████| 469/469 [00:05<00:00, 82.38it/s]


Epoch: 51 Train loss: 3.9251, 0.2514


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.8784, 3.8590


100%|██████████| 469/469 [00:05<00:00, 79.74it/s]

Epoch: 52 Train loss: 3.8975, 0.2396



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 4.1256, 4.1125


100%|██████████| 469/469 [00:05<00:00, 83.36it/s]

Epoch: 53 Train loss: 3.6287, 0.2851



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.5459, 3.5282


100%|██████████| 469/469 [00:05<00:00, 80.56it/s]


Epoch: 54 Train loss: 3.6782, 0.2480


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.4938, 3.4962


100%|██████████| 469/469 [00:05<00:00, 82.94it/s]

Epoch: 55 Train loss: 3.2807, 0.3230



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.3619, 3.3548


100%|██████████| 469/469 [00:05<00:00, 80.34it/s]

Epoch: 56 Train loss: 3.4424, 0.2733



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.1165, 3.1228


100%|██████████| 469/469 [00:05<00:00, 79.63it/s]

Epoch: 57 Train loss: 3.4298, 0.2972



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.5061, 3.4939


100%|██████████| 469/469 [00:05<00:00, 83.36it/s]

Epoch: 58 Train loss: 3.3762, 0.3012



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.8890, 3.8804


100%|██████████| 469/469 [00:06<00:00, 77.07it/s]

Epoch: 59 Train loss: 3.3281, 0.3252



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.4062, 3.4065


100%|██████████| 469/469 [00:05<00:00, 82.16it/s]


Epoch: 60 Train loss: 3.2812, 0.3254


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.4806, 3.4734


100%|██████████| 469/469 [00:05<00:00, 81.18it/s]


Epoch: 61 Train loss: 3.1411, 0.3514


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.3226, 3.3020


100%|██████████| 469/469 [00:05<00:00, 81.40it/s]

Epoch: 62 Train loss: 3.1582, 0.3370



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.7652, 3.7407


100%|██████████| 469/469 [00:05<00:00, 82.25it/s]

Epoch: 63 Train loss: 3.1711, 0.3609



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7193, 2.7086


100%|██████████| 469/469 [00:05<00:00, 83.47it/s]

Epoch: 64 Train loss: 3.3703, 0.3320



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0791, 3.0693


100%|██████████| 469/469 [00:05<00:00, 79.83it/s]

Epoch: 65 Train loss: 3.0861, 0.3613



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.1449, 3.1537


100%|██████████| 469/469 [00:05<00:00, 85.63it/s]


Epoch: 66 Train loss: 2.9566, 0.3736


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0112, 3.0125


100%|██████████| 469/469 [00:05<00:00, 81.95it/s]

Epoch: 67 Train loss: 2.8062, 0.3948



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0503, 3.0476


100%|██████████| 469/469 [00:05<00:00, 79.69it/s]


Epoch: 68 Train loss: 3.6941, 0.3547


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.3671, 3.3787


100%|██████████| 469/469 [00:05<00:00, 81.61it/s]

Epoch: 69 Train loss: 3.2291, 0.3733



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8920, 2.8987


100%|██████████| 469/469 [00:05<00:00, 82.43it/s]

Epoch: 70 Train loss: 2.8076, 0.3946



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0768, 3.0718


100%|██████████| 469/469 [00:05<00:00, 80.96it/s]

Epoch: 71 Train loss: 2.7206, 0.3937



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.5593, 2.5479


100%|██████████| 469/469 [00:06<00:00, 77.42it/s]

Epoch: 72 Train loss: 2.7531, 0.4003



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7951, 2.8060


100%|██████████| 469/469 [00:06<00:00, 76.93it/s]


Epoch: 73 Train loss: 2.7854, 0.4072


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0523, 3.0264


100%|██████████| 469/469 [00:05<00:00, 80.09it/s]


Epoch: 74 Train loss: 2.7343, 0.4131


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.5013, 2.4979


100%|██████████| 469/469 [00:05<00:00, 78.67it/s]


Epoch: 75 Train loss: 2.7037, 0.4067


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8189, 2.8348


100%|██████████| 469/469 [00:05<00:00, 81.60it/s]

Epoch: 76 Train loss: 2.9212, 0.4185



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8285, 2.8287


100%|██████████| 469/469 [00:05<00:00, 82.25it/s]


Epoch: 77 Train loss: 2.7423, 0.4220


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0378, 3.0346


100%|██████████| 469/469 [00:05<00:00, 79.37it/s]


Epoch: 78 Train loss: 2.6433, 0.4374


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.5484, 2.5570


100%|██████████| 469/469 [00:06<00:00, 77.70it/s]

Epoch: 79 Train loss: 2.6196, 0.4364



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7582, 2.7362


100%|██████████| 469/469 [00:05<00:00, 81.12it/s]

Epoch: 80 Train loss: 2.6641, 0.4323



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.3738, 2.3829


100%|██████████| 469/469 [00:05<00:00, 83.21it/s]

Epoch: 81 Train loss: 2.5907, 0.4381



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7372, 2.7501


100%|██████████| 469/469 [00:05<00:00, 85.00it/s]


Epoch: 82 Train loss: 2.8264, 0.4236


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.2640, 3.2271


100%|██████████| 469/469 [00:05<00:00, 83.19it/s]


Epoch: 83 Train loss: 2.6514, 0.5015


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.1584, 3.1951


100%|██████████| 469/469 [00:05<00:00, 85.10it/s]

Epoch: 84 Train loss: 2.9931, 0.4157



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7349, 2.7566


100%|██████████| 469/469 [00:05<00:00, 82.50it/s]

Epoch: 85 Train loss: 2.9665, 0.4457



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.6848, 2.6879


100%|██████████| 469/469 [00:05<00:00, 86.14it/s]

Epoch: 86 Train loss: 2.5732, 0.4703



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7085, 2.7147


100%|██████████| 469/469 [00:05<00:00, 82.71it/s]

Epoch: 87 Train loss: 2.4866, 0.4741



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.4415, 2.4277


100%|██████████| 469/469 [00:05<00:00, 79.01it/s]

Epoch: 88 Train loss: 2.5943, 0.4768



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.3626, 2.3830


100%|██████████| 469/469 [00:05<00:00, 82.74it/s]


Epoch: 89 Train loss: 2.9044, 0.4397


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.9345, 2.9281


100%|██████████| 469/469 [00:05<00:00, 84.27it/s]


Epoch: 90 Train loss: 3.2576, 0.3771


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.5431, 3.5718


100%|██████████| 469/469 [00:05<00:00, 82.10it/s]

Epoch: 91 Train loss: 3.0331, 0.4623



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.5654, 2.5480


100%|██████████| 469/469 [00:05<00:00, 81.44it/s]

Epoch: 92 Train loss: 2.5573, 0.4520



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8149, 2.8500


100%|██████████| 469/469 [00:05<00:00, 78.46it/s]


Epoch: 93 Train loss: 2.6586, 0.4842


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8076, 2.7991


100%|██████████| 469/469 [00:05<00:00, 81.07it/s]

Epoch: 94 Train loss: 2.6907, 0.4890



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0261, 3.0424


100%|██████████| 469/469 [00:05<00:00, 82.72it/s]

Epoch: 95 Train loss: 2.6289, 0.4767



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7313, 2.7030


100%|██████████| 469/469 [00:05<00:00, 83.11it/s]

Epoch: 96 Train loss: 2.2909, 0.6307



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.1476, 2.1375


100%|██████████| 469/469 [00:05<00:00, 84.12it/s]


Epoch: 97 Train loss: 2.8134, 0.4544


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.4441, 2.4536


100%|██████████| 469/469 [00:05<00:00, 81.97it/s]


Epoch: 98 Train loss: 2.5418, 0.4890


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.6820, 2.6780


100%|██████████| 469/469 [00:05<00:00, 84.58it/s]

Epoch: 99 Train loss: 2.8476, 0.5017



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8283, 2.8097


100%|██████████| 469/469 [00:05<00:00, 83.44it/s]

Epoch: 100 Train loss: 2.5926, 0.5962





Test loss: 2.9826, 2.9595
