In [13]:
import torch 

from torch import nn
from torch import optim

from torchvision.datasets import MNIST

from torch.utils.data import  TensorDataset, Dataset, DataLoader


import numpy as np
import tqdm

# from torch.autograd import Variable

In [14]:
from tensorboardX import SummaryWriter

In [15]:
beta = 1e-3
batch_size = 100
samples_amount = 10
num_epochs = 10

In [16]:
torch.cuda.set_device(4)

Prepare the data:

In [18]:
train_data = MNIST('mnist', download=True, train=True)
train_dataset = TensorDataset(train_data.train_data.view(-1, 28 * 28).float() / 255, train_data.train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size)

test_data = MNIST('mnist', download=True, train=False)
test_dataset = TensorDataset(test_data.test_data.view(-1, 28 * 28).float() / 255, test_data.test_labels)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [19]:
def KL_between_normals(q_distr, p_distr):
    mu_q, sigma_q = q_distr
    mu_p, sigma_p = p_distr
    k = mu_q.size(1)
    
    mu_diff = mu_p - mu_q
    mu_diff_sq =  torch.mul(mu_diff, mu_diff)
    logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1)
    logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1)
    
    fs = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1)  + torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1)
    two_kl =  fs - k + logdet_sigma_p - logdet_sigma_q
    return two_kl * 0.5

In [20]:
class VIB(nn.Module):
    def __init__(self, X_dim, y_dim, dimZ=256, beta=1e-3, num_samples=10):
        # the dimension of Z 
        super().__init__()
        
        self.beta = beta
        self.dimZ = dimZ 
        self.num_samples = num_samples
        
        self.encoder = nn.Sequential(nn.Linear(in_features=X_dim, out_features=1024),
                                     nn.ReLU(),
                                     nn.Linear(in_features=1024, out_features=1024),
                                     nn.ReLU(),
                                     nn.Linear(in_features=1024, out_features=2 * self.dimZ))
        
        #  TODO: try heads        
        #         self.encoder_sigma_head = nn.Linear()
        #         self.encoder_mu_head = ...
        
        
        
        # decoder a simple logistic regression as in the paper 
        self.decoder_logits = nn.Linear(in_features=self.dimZ, out_features=y_dim)

    def gaussian_noise(self, num_samples, K):
        # works with integers as well as tuples   
        return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).cuda()
           
    def sample_prior_Z(self, num_samples):
        return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)

    def encoder_result(self, batch):
        encoder_output = self.encoder(batch)
        
        mu = encoder_output[:, :self.dimZ]
        sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:])
        
        return mu, sigma
    
    def sample_encoder_Z(self, num_samples, batch): 
        batch_size = batch.size()[0]
        mu, sigma = self.encoder_result(batch)
        
        return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
    
    
    def forward(self, batch_x):
        
        batch_size = batch_x.size()[0]
        
        # sample from encoder
        encoder_Z_distr = self.encoder_result(batch_x)  
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)

        decoder_logits_mean = torch.mean(self.decoder_logits(to_decoder), dim=0)
                
        return decoder_logits_mean
        
    def batch_loss(self, num_samples, batch_x, batch_y):
        batch_size = batch_x.size()[0]
        
        prior_Z_distr = torch.zeros(batch_size, self.dimZ).cuda(), torch.ones(batch_size, self.dimZ).cuda()
        encoder_Z_distr = self.encoder_result(batch_x)
        
        
        I_ZX_bound = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
        
        
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)
        
        
        decoder_logits = self.decoder_logits(to_decoder)
        # batch should go first
        decoder_logits = decoder_logits.permute(1, 2, 0)
    
        
        loss = nn.CrossEntropyLoss(reduce=False)
        cross_entropy_loss = loss(decoder_logits, batch_y[:, None].expand(-1, num_samples))
        
        # estimate E_{eps in N(0, 1)} [log q(y | z)] 
        cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
        
        minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)
                
        return torch.mean(minusI_ZY_bound + self.beta * I_ZX_bound),  -minusI_ZY_bound, I_ZX_bound

In [34]:
beta = 1e-3
batch_size = 100
samples_amount = 30
num_epochs = 200

In [35]:
model = VIB(X_dim=784, y_dim=10, beta = beta, num_samples=samples_amount).cuda()

opt = torch.optim.Adam(model.parameters(), lr=1e-4)

scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.97)

In [36]:
class EMA(nn.Module):
    def __init__(self, mu):
        super(EMA, self).__init__()
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def forward(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average

ema = EMA(0.999)
for name, param in model.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)

In [37]:
import time


seed = time.strftime("%Y-%m-%d %H:%M")

writer = SummaryWriter(log_dir="tensor_logs/" + seed)


for epoch in range(num_epochs):
    loss_by_epoch = []
    accuracy_by_epoch = []
    I_ZX_bound_by_epoch = []
    I_ZY_bound_by_epoch = []
    
    loss_by_epoch_test = []
    accuracy_by_epoch_test = []
    I_ZX_bound_by_epoch_test = []
    I_ZY_bound_by_epoch_test = []
    
    if epoch % 2 == 0 and epoch > 0:
        scheduler.step()
    
    for x_batch, y_batch in tqdm.tqdm(train_loader):   
        x_batch = x_batch.cuda()
        y_batch = y_batch.cuda()

        loss, I_ZY_bound, I_ZX_bound = model.batch_loss(samples_amount, x_batch, y_batch)
        
        logits = model.forward(x_batch)
        prediction = torch.max(logits, dim=1)[1]
        accuracy = torch.mean((prediction == y_batch).float())


        loss.backward()
        opt.step()
        opt.zero_grad()
        
        # compute exponential moving average
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema(name, param.data)
        
        I_ZX_bound_by_epoch.append(I_ZX_bound.item())
        I_ZY_bound_by_epoch.append(I_ZY_bound.item())
        
        loss_by_epoch.append(loss.item())
        accuracy_by_epoch.append(accuracy.item())
        
        
    for x_batch, y_batch in tqdm.tqdm(test_loader):
        x_batch = x_batch.cuda()
        y_batch = y_batch.cuda()

        loss, I_ZY_bound, I_ZX_bound = model.batch_loss(samples_amount, x_batch, y_batch)
        
        logits = model.forward(x_batch)
        prediction = torch.max(logits, dim=1)[1]
        accuracy = torch.mean((prediction == y_batch).float())

        
        I_ZX_bound_by_epoch_test.append(I_ZX_bound.item())
        I_ZY_bound_by_epoch_test.append(I_ZY_bound.item())
        
        loss_by_epoch_test.append(loss.item())
        accuracy_by_epoch_test.append(accuracy.item())
        
    writer.add_scalar("accuracy", np.mean(accuracy_by_epoch_test), global_step=epoch)
    writer.add_scalar("loss", np.mean(loss_by_epoch_test), global_step=epoch)
    writer.add_scalar("I_ZX", np.mean(I_ZX_bound_by_epoch_test), global_step=epoch)
    writer.add_scalar("I_ZY", np.mean(I_ZY_bound_by_epoch_test), global_step=epoch)

   
    print('epoch', epoch, 'loss', np.mean(loss_by_epoch_test), 
          'prediction', np.mean(accuracy_by_epoch_test))
          
    print('I_ZX_bound', np.mean(I_ZX_bound_by_epoch_test), 
          'I_ZY_bound', np.mean(I_ZY_bound_by_epoch_test))

100%|██████████| 600/600 [00:21<00:00, 28.32it/s]
100%|██████████| 100/100 [00:02<00:00, 33.93it/s]
  1%|          | 4/600 [00:00<00:17, 34.96it/s]

epoch 0 loss 0.30287878297269344 prediction 0.9387000066041946
I_ZX_bound 68.91955360412598 I_ZY_bound -0.23395922537893057


100%|██████████| 600/600 [00:22<00:00, 26.90it/s]
100%|██████████| 100/100 [00:02<00:00, 36.42it/s]
  1%|          | 4/600 [00:00<00:17, 34.13it/s]

epoch 1 loss 0.2114759872853756 prediction 0.961500004529953
I_ZX_bound 58.22603141784668 I_ZY_bound -0.15324995339848102


100%|██████████| 600/600 [00:22<00:00, 26.30it/s]
100%|██████████| 100/100 [00:03<00:00, 26.85it/s]
  0%|          | 3/600 [00:00<00:24, 24.44it/s]

epoch 2 loss 0.16913576297461985 prediction 0.9705000048875809
I_ZX_bound 51.38968978881836 I_ZY_bound -0.11774607091676444


100%|██████████| 600/600 [00:22<00:00, 26.15it/s]
100%|██████████| 100/100 [00:03<00:00, 30.54it/s]
  0%|          | 3/600 [00:00<00:20, 29.09it/s]

epoch 3 loss 0.14642569184303283 prediction 0.975900005698204
I_ZX_bound 46.94511024475098 I_ZY_bound -0.09948057903675363


100%|██████████| 600/600 [00:23<00:00, 25.58it/s]
100%|██████████| 100/100 [00:03<00:00, 26.58it/s]
  0%|          | 3/600 [00:00<00:26, 22.34it/s]

epoch 4 loss 0.13180978331714868 prediction 0.9779000055789947
I_ZX_bound 42.714550170898434 I_ZY_bound -0.0890952311269939


100%|██████████| 600/600 [00:25<00:00, 23.49it/s]
100%|██████████| 100/100 [00:03<00:00, 29.33it/s]
  1%|          | 4/600 [00:00<00:20, 28.87it/s]

epoch 5 loss 0.12013365723192691 prediction 0.980100005865097
I_ZX_bound 39.482367782592775 I_ZY_bound -0.08065128753893078


100%|██████████| 600/600 [00:25<00:00, 23.76it/s]
100%|██████████| 100/100 [00:03<00:00, 28.80it/s]
  0%|          | 3/600 [00:00<00:21, 27.34it/s]

epoch 6 loss 0.11203867081552744 prediction 0.9817000073194504
I_ZX_bound 36.46133613586426 I_ZY_bound -0.07557733255671337


100%|██████████| 600/600 [00:26<00:00, 22.70it/s]
100%|██████████| 100/100 [00:03<00:00, 26.40it/s]
  0%|          | 2/600 [00:00<00:40, 14.88it/s]

epoch 7 loss 0.1066500224545598 prediction 0.9821000075340272
I_ZX_bound 33.5832844543457 I_ZY_bound -0.07306673643062822


100%|██████████| 600/600 [00:26<00:00, 22.82it/s]
100%|██████████| 100/100 [00:03<00:00, 30.19it/s]
  0%|          | 3/600 [00:00<00:23, 25.53it/s]

epoch 8 loss 0.10387145284563302 prediction 0.9830000084638596
I_ZX_bound 31.012268295288084 I_ZY_bound -0.07285918250330724


100%|██████████| 600/600 [00:22<00:00, 26.98it/s]
100%|██████████| 100/100 [00:02<00:00, 36.73it/s]
  0%|          | 3/600 [00:00<00:28, 21.30it/s]

epoch 9 loss 0.09949753385037184 prediction 0.9845000076293945
I_ZX_bound 28.904817810058592 I_ZY_bound -0.07059271475067362


100%|██████████| 600/600 [00:25<00:00, 23.57it/s]
100%|██████████| 100/100 [00:03<00:00, 31.40it/s]
  0%|          | 3/600 [00:00<00:24, 24.61it/s]

epoch 10 loss 0.09543151129037142 prediction 0.9842000073194503
I_ZX_bound 26.508571949005127 I_ZY_bound -0.06892293772427366


100%|██████████| 600/600 [00:26<00:00, 22.72it/s]
100%|██████████| 100/100 [00:03<00:00, 28.84it/s]
  0%|          | 3/600 [00:00<00:22, 26.35it/s]

epoch 11 loss 0.09422625206410885 prediction 0.9841000068187714
I_ZX_bound 24.831278076171873 I_ZY_bound -0.06939497277024202


100%|██████████| 600/600 [00:26<00:00, 23.06it/s]
100%|██████████| 100/100 [00:03<00:00, 25.55it/s]
  0%|          | 3/600 [00:00<00:23, 25.71it/s]

epoch 12 loss 0.09237400360405446 prediction 0.9853000062704086
I_ZX_bound 22.79382152557373 I_ZY_bound -0.06958018128760159


100%|██████████| 600/600 [00:25<00:00, 23.11it/s]
100%|██████████| 100/100 [00:03<00:00, 28.01it/s]
  0%|          | 2/600 [00:00<00:32, 18.49it/s]

epoch 13 loss 0.09197237385436892 prediction 0.9861000072956085
I_ZX_bound 21.587479400634766 I_ZY_bound -0.070384893148439


100%|██████████| 600/600 [00:25<00:00, 23.11it/s]
100%|██████████| 100/100 [00:03<00:00, 32.85it/s]
  0%|          | 3/600 [00:00<00:26, 22.54it/s]

epoch 14 loss 0.09433465959504246 prediction 0.985800006389618
I_ZX_bound 20.075782833099364 I_ZY_bound -0.0742588759958744


100%|██████████| 600/600 [00:20<00:00, 29.48it/s]
100%|██████████| 100/100 [00:02<00:00, 41.15it/s]
  0%|          | 3/600 [00:00<00:23, 25.11it/s]

epoch 15 loss 0.0936714955419302 prediction 0.9864000064134598
I_ZX_bound 19.719102458953856 I_ZY_bound -0.07395239227334968


100%|██████████| 600/600 [00:19<00:00, 30.11it/s]
100%|██████████| 100/100 [00:02<00:00, 34.16it/s]
  1%|          | 4/600 [00:00<00:21, 27.73it/s]

epoch 16 loss 0.09404656577855348 prediction 0.986200007200241
I_ZX_bound 18.26131196975708 I_ZY_bound -0.0757852530234959


100%|██████████| 600/600 [00:22<00:00, 26.91it/s]
100%|██████████| 100/100 [00:02<00:00, 37.12it/s]
  0%|          | 3/600 [00:00<00:21, 27.82it/s]

epoch 17 loss 0.09638783864676953 prediction 0.985800006389618
I_ZX_bound 17.80865489959717 I_ZY_bound -0.07857918262947351


100%|██████████| 600/600 [00:19<00:00, 30.05it/s]
100%|██████████| 100/100 [00:02<00:00, 39.11it/s]
  1%|          | 4/600 [00:00<00:15, 39.14it/s]

epoch 18 loss 0.09692411048337818 prediction 0.9864000064134598
I_ZX_bound 17.00822326660156 I_ZY_bound -0.07991588647593745


100%|██████████| 600/600 [00:19<00:00, 30.64it/s]
100%|██████████| 100/100 [00:02<00:00, 35.00it/s]
  0%|          | 3/600 [00:00<00:21, 27.89it/s]

epoch 19 loss 0.11382948169484734 prediction 0.9854000073671341
I_ZX_bound 19.633427219390867 I_ZY_bound -0.09419605278410018


100%|██████████| 600/600 [00:21<00:00, 27.75it/s]
100%|██████████| 100/100 [00:03<00:00, 28.70it/s]
  0%|          | 3/600 [00:00<00:22, 25.96it/s]

epoch 20 loss 0.09809642529115081 prediction 0.9863000065088272
I_ZX_bound 16.536506814956667 I_ZY_bound -0.08155991815961898


100%|██████████| 600/600 [00:25<00:00, 23.66it/s]
100%|██████████| 100/100 [00:02<00:00, 35.76it/s]
  0%|          | 3/600 [00:00<00:25, 23.69it/s]

epoch 21 loss 0.09931245218962431 prediction 0.9862000066041946
I_ZX_bound 16.291065063476562 I_ZY_bound -0.08302138597820885


100%|██████████| 600/600 [00:20<00:00, 28.59it/s]
100%|██████████| 100/100 [00:03<00:00, 30.41it/s]
  0%|          | 2/600 [00:00<00:32, 18.35it/s]

epoch 22 loss 0.10036390479654074 prediction 0.9861000066995621
I_ZX_bound 16.037411880493163 I_ZY_bound -0.08432649301132188


100%|██████████| 600/600 [00:25<00:00, 23.88it/s]
100%|██████████| 100/100 [00:03<00:00, 28.47it/s]
  0%|          | 2/600 [00:00<00:31, 18.78it/s]

epoch 23 loss 0.10182334376499057 prediction 0.9860000067949295
I_ZX_bound 15.640594978332519 I_ZY_bound -0.08618274782202207


100%|██████████| 600/600 [00:25<00:00, 23.61it/s]
100%|██████████| 100/100 [00:03<00:00, 31.92it/s]
  1%|          | 4/600 [00:00<00:19, 29.85it/s]

epoch 24 loss 0.10525768041610718 prediction 0.9857000070810318
I_ZX_bound 15.234410657882691 I_ZY_bound -0.09002326845191419


100%|██████████| 600/600 [00:25<00:00, 23.47it/s]
100%|██████████| 100/100 [00:04<00:00, 23.10it/s]
  0%|          | 3/600 [00:00<00:22, 26.43it/s]

epoch 25 loss 0.1008784993737936 prediction 0.9865000063180923
I_ZX_bound 14.81883213043213 I_ZY_bound -0.0860596662457101


100%|██████████| 600/600 [00:29<00:00, 20.34it/s]
100%|██████████| 100/100 [00:03<00:00, 30.39it/s]
  0%|          | 3/600 [00:00<00:26, 22.53it/s]

epoch 26 loss 0.10388136509805918 prediction 0.9864000076055527
I_ZX_bound 15.143932695388793 I_ZY_bound -0.08873743150150404


100%|██████████| 600/600 [00:26<00:00, 22.25it/s]
100%|██████████| 100/100 [00:02<00:00, 33.81it/s]
  1%|          | 4/600 [00:00<00:19, 30.08it/s]

epoch 27 loss 0.10434772880747914 prediction 0.9865000063180923
I_ZX_bound 14.746550483703613 I_ZY_bound -0.0896011774148792


100%|██████████| 600/600 [00:20<00:00, 28.81it/s]
100%|██████████| 100/100 [00:02<00:00, 37.62it/s]
  0%|          | 3/600 [00:00<00:24, 24.23it/s]

epoch 28 loss 0.10359469054266811 prediction 0.9867000079154968
I_ZX_bound 14.527044086456298 I_ZY_bound -0.08906764553510584


100%|██████████| 600/600 [00:20<00:00, 28.69it/s]
100%|██████████| 100/100 [00:03<00:00, 33.20it/s]
  1%|          | 4/600 [00:00<00:18, 31.83it/s]

epoch 29 loss 0.1060618936829269 prediction 0.9860000073909759
I_ZX_bound 14.484120492935181 I_ZY_bound -0.09157777229556814


100%|██████████| 600/600 [00:20<00:00, 28.79it/s]
100%|██████████| 100/100 [00:02<00:00, 37.48it/s]
  1%|          | 4/600 [00:00<00:17, 33.28it/s]

epoch 30 loss 0.10528963964432478 prediction 0.987000008225441
I_ZX_bound 14.052966871261596 I_ZY_bound -0.09123667198698968


100%|██████████| 600/600 [00:20<00:00, 29.29it/s]
100%|██████████| 100/100 [00:03<00:00, 32.67it/s]
  0%|          | 3/600 [00:00<00:21, 27.42it/s]

epoch 31 loss 0.10916229114867747 prediction 0.9865000075101853
I_ZX_bound 14.484345960617066 I_ZY_bound -0.09467794460244477


100%|██████████| 600/600 [00:20<00:00, 29.13it/s]
100%|██████████| 100/100 [00:02<00:00, 35.77it/s]
  1%|          | 4/600 [00:00<00:18, 31.84it/s]

epoch 32 loss 0.10924249788746238 prediction 0.9857000070810318
I_ZX_bound 14.431790733337403 I_ZY_bound -0.09481070664827712


100%|██████████| 600/600 [00:20<00:00, 29.46it/s]
100%|██████████| 100/100 [00:02<00:00, 38.06it/s]
  0%|          | 3/600 [00:00<00:20, 28.74it/s]

epoch 33 loss 0.10725417457520962 prediction 0.9866000074148178
I_ZX_bound 14.323311061859131 I_ZY_bound -0.09293086224002764


100%|██████████| 600/600 [00:20<00:00, 29.04it/s]
100%|██████████| 100/100 [00:02<00:00, 34.10it/s]
  0%|          | 3/600 [00:00<00:19, 29.99it/s]

epoch 34 loss 0.11032964181154967 prediction 0.9870000064373017
I_ZX_bound 14.056461124420165 I_ZY_bound -0.09627317983831744


100%|██████████| 600/600 [00:19<00:00, 31.43it/s]
100%|██████████| 100/100 [00:02<00:00, 43.19it/s]
  1%|          | 5/600 [00:00<00:14, 40.91it/s]

epoch 35 loss 0.10994737962260842 prediction 0.9870000076293945
I_ZX_bound 13.894462871551514 I_ZY_bound -0.0960529156588018


100%|██████████| 600/600 [00:16<00:00, 36.49it/s]
100%|██████████| 100/100 [00:02<00:00, 45.43it/s]
  1%|          | 5/600 [00:00<00:15, 38.88it/s]

epoch 36 loss 0.10934069219976664 prediction 0.9863000071048736
I_ZX_bound 14.449501457214355 I_ZY_bound -0.0948911900667008


100%|██████████| 600/600 [00:16<00:00, 37.39it/s]
100%|██████████| 100/100 [00:02<00:00, 43.25it/s]
  1%|          | 4/600 [00:00<00:18, 32.67it/s]

epoch 37 loss 0.11151474934071302 prediction 0.986200008392334
I_ZX_bound 13.862007570266723 I_ZY_bound -0.0976527419744525


100%|██████████| 600/600 [00:17<00:00, 35.08it/s]
100%|██████████| 100/100 [00:02<00:00, 38.54it/s]
  1%|          | 4/600 [00:00<00:17, 33.53it/s]

epoch 38 loss 0.1154631488583982 prediction 0.9867000073194504
I_ZX_bound 13.795471324920655 I_ZY_bound -0.10166767724556848


100%|██████████| 600/600 [00:17<00:00, 34.81it/s]
100%|██████████| 100/100 [00:02<00:00, 45.11it/s]
  1%|          | 4/600 [00:00<00:16, 36.18it/s]

epoch 39 loss 0.11129227055236697 prediction 0.9865000069141387
I_ZX_bound 13.516047248840332 I_ZY_bound -0.09777622234192677


100%|██████████| 600/600 [00:16<00:00, 36.96it/s]
100%|██████████| 100/100 [00:02<00:00, 42.01it/s]
  1%|          | 5/600 [00:00<00:14, 40.53it/s]

epoch 40 loss 0.11401027765125037 prediction 0.9868000078201294
I_ZX_bound 13.776112899780273 I_ZY_bound -0.10023416382377036


100%|██████████| 600/600 [00:15<00:00, 37.82it/s]
100%|██████████| 100/100 [00:02<00:00, 43.11it/s]
  1%|          | 5/600 [00:00<00:13, 42.74it/s]

epoch 41 loss 0.11791337056085467 prediction 0.9860000073909759
I_ZX_bound 13.508601303100585 I_ZY_bound -0.10440476823248901


100%|██████████| 600/600 [00:15<00:00, 38.20it/s]
100%|██████████| 100/100 [00:02<00:00, 46.38it/s]
  1%|          | 5/600 [00:00<00:14, 41.53it/s]

epoch 42 loss 0.117327585183084 prediction 0.986700006723404
I_ZX_bound 13.29161364555359 I_ZY_bound -0.1040359714825172


100%|██████████| 600/600 [00:18<00:00, 32.19it/s]
100%|██████████| 100/100 [00:03<00:00, 28.68it/s]
  0%|          | 3/600 [00:00<00:24, 24.14it/s]

epoch 43 loss 0.11731696866452694 prediction 0.9864000070095063
I_ZX_bound 13.446302680969238 I_ZY_bound -0.10387066554976628


100%|██████████| 600/600 [00:20<00:00, 29.50it/s]
100%|██████████| 100/100 [00:03<00:00, 32.80it/s]
  0%|          | 3/600 [00:00<00:20, 29.48it/s]

epoch 44 loss 0.11597768194042146 prediction 0.9859000068902969
I_ZX_bound 13.389795045852662 I_ZY_bound -0.1025878867378924


100%|██████████| 600/600 [00:24<00:00, 24.20it/s]
100%|██████████| 100/100 [00:03<00:00, 28.24it/s]
  0%|          | 3/600 [00:00<00:20, 28.98it/s]

epoch 45 loss 0.1165429737046361 prediction 0.9866000086069107
I_ZX_bound 13.519684190750121 I_ZY_bound -0.10302328947407659


100%|██████████| 600/600 [00:20<00:00, 29.84it/s]
100%|██████████| 100/100 [00:02<00:00, 35.65it/s]
  0%|          | 3/600 [00:00<00:25, 23.75it/s]

epoch 46 loss 0.11776054455898702 prediction 0.9865000063180923
I_ZX_bound 13.66977991104126 I_ZY_bound -0.10409076477517373


100%|██████████| 600/600 [00:19<00:00, 30.43it/s]
100%|██████████| 100/100 [00:02<00:00, 35.51it/s]
  0%|          | 3/600 [00:00<00:26, 22.80it/s]

epoch 47 loss 0.11474024754017592 prediction 0.9863000071048736
I_ZX_bound 13.5531596660614 I_ZY_bound -0.1011870868614642


100%|██████████| 600/600 [00:19<00:00, 30.48it/s]
100%|██████████| 100/100 [00:02<00:00, 33.83it/s]
  1%|          | 4/600 [00:00<00:19, 30.86it/s]

epoch 48 loss 0.11740206055343151 prediction 0.9864000064134598
I_ZX_bound 13.657702169418336 I_ZY_bound -0.1037443576043006


100%|██████████| 600/600 [00:17<00:00, 34.57it/s]
100%|██████████| 100/100 [00:02<00:00, 41.22it/s]
  1%|          | 4/600 [00:00<00:17, 33.66it/s]

epoch 49 loss 0.12009605175815523 prediction 0.9871000069379806
I_ZX_bound 13.410976600646972 I_ZY_bound -0.10668507454684004


100%|██████████| 600/600 [00:16<00:00, 37.28it/s]
100%|██████████| 100/100 [00:02<00:00, 43.91it/s]
  1%|          | 4/600 [00:00<00:15, 37.36it/s]

epoch 50 loss 0.12058581265620888 prediction 0.986700006723404
I_ZX_bound 13.201490726470947 I_ZY_bound -0.10738432123791426


100%|██████████| 600/600 [00:16<00:00, 37.21it/s]
100%|██████████| 100/100 [00:02<00:00, 42.62it/s]
  1%|          | 4/600 [00:00<00:19, 31.15it/s]

epoch 51 loss 0.11943567033857107 prediction 0.9869000071287155
I_ZX_bound 13.52099669456482 I_ZY_bound -0.10591467377496883


100%|██████████| 600/600 [00:16<00:00, 37.22it/s]
100%|██████████| 100/100 [00:02<00:00, 42.67it/s]
  1%|          | 4/600 [00:00<00:16, 36.79it/s]

epoch 52 loss 0.11652369029819966 prediction 0.9869000053405762
I_ZX_bound 13.07238374710083 I_ZY_bound -0.1034513059805613


100%|██████████| 600/600 [00:15<00:00, 37.76it/s]
100%|██████████| 100/100 [00:02<00:00, 43.69it/s]
  1%|          | 4/600 [00:00<00:16, 35.96it/s]

epoch 53 loss 0.11557225733995438 prediction 0.9862000066041946
I_ZX_bound 13.403104763031006 I_ZY_bound -0.1021691518445732


100%|██████████| 600/600 [00:17<00:00, 33.65it/s]
100%|██████████| 100/100 [00:02<00:00, 38.76it/s]
  1%|          | 4/600 [00:00<00:15, 39.21it/s]

epoch 54 loss 0.11893727134913207 prediction 0.9862000066041946
I_ZX_bound 13.027967748641968 I_ZY_bound -0.10590930244303308


100%|██████████| 600/600 [00:19<00:00, 30.15it/s]
100%|██████████| 100/100 [00:03<00:00, 31.97it/s]
  1%|          | 4/600 [00:00<00:20, 29.31it/s]

epoch 55 loss 0.11746092914603651 prediction 0.9863000082969665
I_ZX_bound 13.46893075942993 I_ZY_bound -0.10399199788400437


100%|██████████| 600/600 [00:19<00:00, 30.36it/s]
100%|██████████| 100/100 [00:03<00:00, 31.25it/s]
  0%|          | 3/600 [00:00<00:24, 24.19it/s]

epoch 56 loss 0.11953869767487049 prediction 0.9869000071287155
I_ZX_bound 13.122292156219482 I_ZY_bound -0.10641640511516016


100%|██████████| 600/600 [00:19<00:00, 30.60it/s]
100%|██████████| 100/100 [00:02<00:00, 36.37it/s]
  0%|          | 3/600 [00:00<00:26, 22.42it/s]

epoch 57 loss 0.1213185156043619 prediction 0.9869000077247619
I_ZX_bound 13.284983978271484 I_ZY_bound -0.10803353082272224


100%|██████████| 600/600 [00:20<00:00, 29.85it/s]
100%|██████████| 100/100 [00:02<00:00, 34.25it/s]
  1%|          | 4/600 [00:00<00:16, 35.66it/s]

epoch 58 loss 0.12589839384891094 prediction 0.9870000070333481
I_ZX_bound 13.27673357963562 I_ZY_bound -0.11262165970751084


100%|██████████| 600/600 [00:19<00:00, 31.51it/s]
100%|██████████| 100/100 [00:02<00:00, 36.29it/s]
  1%|          | 4/600 [00:00<00:22, 26.65it/s]

epoch 59 loss 0.1243271841481328 prediction 0.986100007891655
I_ZX_bound 13.306593341827393 I_ZY_bound -0.11102059046330397


100%|██████████| 600/600 [00:18<00:00, 32.25it/s]
100%|██████████| 100/100 [00:02<00:00, 41.11it/s]
  1%|          | 4/600 [00:00<00:18, 32.59it/s]

epoch 60 loss 0.12055131581611932 prediction 0.9869000065326691
I_ZX_bound 13.362929239273072 I_ZY_bound -0.10718838613247499


100%|██████████| 600/600 [00:15<00:00, 37.62it/s]
100%|██████████| 100/100 [00:02<00:00, 45.49it/s]
  1%|          | 4/600 [00:00<00:15, 39.35it/s]

epoch 61 loss 0.12031872222200037 prediction 0.9864000070095063
I_ZX_bound 13.170260362625122 I_ZY_bound -0.10714846111484803


100%|██████████| 600/600 [00:16<00:00, 37.40it/s]
100%|██████████| 100/100 [00:02<00:00, 41.80it/s]
  1%|          | 4/600 [00:00<00:15, 37.64it/s]

epoch 62 loss 0.12334042978473007 prediction 0.9867000073194504
I_ZX_bound 13.11648238182068 I_ZY_bound -0.1102239467727486


100%|██████████| 600/600 [00:15<00:00, 37.91it/s]
100%|██████████| 100/100 [00:02<00:00, 45.47it/s]
  0%|          | 3/600 [00:00<00:25, 23.13it/s]

epoch 63 loss 0.12273948408663272 prediction 0.9871000069379806
I_ZX_bound 13.175607089996339 I_ZY_bound -0.10956387615064159


100%|██████████| 600/600 [00:15<00:00, 37.81it/s]
100%|██████████| 100/100 [00:02<00:00, 41.59it/s]
  1%|          | 4/600 [00:00<00:17, 33.81it/s]

epoch 64 loss 0.11944345598109067 prediction 0.9868000066280365
I_ZX_bound 12.901895761489868 I_ZY_bound -0.10654155950061978


100%|██████████| 600/600 [00:17<00:00, 33.55it/s]
100%|██████████| 100/100 [00:02<00:00, 38.18it/s]
  1%|          | 4/600 [00:00<00:19, 30.81it/s]

epoch 65 loss 0.11852013508789241 prediction 0.9869000077247619
I_ZX_bound 12.907855262756348 I_ZY_bound -0.10561227911151945


100%|██████████| 600/600 [00:19<00:00, 31.52it/s]
100%|██████████| 100/100 [00:03<00:00, 32.97it/s]
  1%|          | 4/600 [00:00<00:17, 33.19it/s]

epoch 66 loss 0.11700570408254862 prediction 0.9874000066518783
I_ZX_bound 13.04040761947632 I_ZY_bound -0.10396529638906941


100%|██████████| 600/600 [00:20<00:00, 29.44it/s]
100%|██████████| 100/100 [00:02<00:00, 33.69it/s]
  0%|          | 3/600 [00:00<00:24, 24.39it/s]

epoch 67 loss 0.12305420718155802 prediction 0.9872000056505204
I_ZX_bound 13.163172416687011 I_ZY_bound -0.10989103352185339


100%|██████████| 600/600 [00:20<00:00, 28.99it/s]
100%|██████████| 100/100 [00:02<00:00, 37.26it/s]
  1%|          | 4/600 [00:00<00:16, 37.02it/s]

epoch 68 loss 0.12118415392935276 prediction 0.9872000074386597
I_ZX_bound 13.0555611038208 I_ZY_bound -0.10812859242665582


100%|██████████| 600/600 [00:18<00:00, 31.63it/s]
100%|██████████| 100/100 [00:03<00:00, 31.95it/s]
  1%|          | 5/600 [00:00<00:14, 40.16it/s]

epoch 69 loss 0.12656011375598608 prediction 0.9874000060558319
I_ZX_bound 13.237693748474122 I_ZY_bound -0.11332241904688999


100%|██████████| 600/600 [00:19<00:00, 30.37it/s]
100%|██████████| 100/100 [00:02<00:00, 38.70it/s]
  1%|          | 4/600 [00:00<00:15, 39.69it/s]

epoch 70 loss 0.12136136066168547 prediction 0.9873000067472458
I_ZX_bound 12.910532274246215 I_ZY_bound -0.10845082776271738


100%|██████████| 600/600 [00:18<00:00, 32.39it/s]
100%|██████████| 100/100 [00:02<00:00, 43.17it/s]
  1%|          | 4/600 [00:00<00:15, 39.44it/s]

epoch 71 loss 0.12022037165239453 prediction 0.9873000073432923
I_ZX_bound 12.886890077590943 I_ZY_bound -0.10733348106965422


100%|██████████| 600/600 [00:15<00:00, 37.78it/s]
100%|██████████| 100/100 [00:02<00:00, 44.83it/s]
  1%|          | 4/600 [00:00<00:17, 33.36it/s]

epoch 72 loss 0.12327627147547901 prediction 0.9866000056266785
I_ZX_bound 13.284911117553712 I_ZY_bound -0.10999136004305911


100%|██████████| 600/600 [00:16<00:00, 36.34it/s]
100%|██████████| 100/100 [00:02<00:00, 41.29it/s]
  1%|          | 4/600 [00:00<00:16, 36.64it/s]

epoch 73 loss 0.12031287874095142 prediction 0.9871000069379806
I_ZX_bound 13.090542488098144 I_ZY_bound -0.10722233598644379


100%|██████████| 600/600 [00:15<00:00, 38.87it/s]
100%|██████████| 100/100 [00:02<00:00, 43.11it/s]
  1%|          | 4/600 [00:00<00:15, 39.57it/s]

epoch 74 loss 0.12100301362574101 prediction 0.9874000066518783
I_ZX_bound 13.089945316314697 I_ZY_bound -0.10791306759114377


100%|██████████| 600/600 [00:15<00:00, 37.51it/s]
100%|██████████| 100/100 [00:02<00:00, 42.49it/s]
  1%|          | 4/600 [00:00<00:17, 33.48it/s]

epoch 75 loss 0.12205177891999483 prediction 0.986700006723404
I_ZX_bound 13.275553531646729 I_ZY_bound -0.10877622442436405


100%|██████████| 600/600 [00:17<00:00, 34.45it/s]
100%|██████████| 100/100 [00:02<00:00, 39.27it/s]
  1%|          | 4/600 [00:00<00:18, 32.18it/s]

epoch 76 loss 0.12280007225461304 prediction 0.9866000074148178
I_ZX_bound 13.207942533493043 I_ZY_bound -0.10959212880698033


100%|██████████| 600/600 [00:19<00:00, 31.40it/s]
100%|██████████| 100/100 [00:02<00:00, 35.49it/s]
  1%|          | 4/600 [00:00<00:16, 37.20it/s]

epoch 77 loss 0.12329933570697904 prediction 0.9875000071525574
I_ZX_bound 13.115045976638793 I_ZY_bound -0.11018428939860314


100%|██████████| 600/600 [00:19<00:00, 31.56it/s]
100%|██████████| 100/100 [00:03<00:00, 32.06it/s]
  0%|          | 2/600 [00:00<00:32, 18.51it/s]

epoch 78 loss 0.1223076005000621 prediction 0.9874000060558319
I_ZX_bound 12.710228395462035 I_ZY_bound -0.10959737110591959


100%|██████████| 600/600 [00:19<00:00, 31.31it/s]
100%|██████████| 100/100 [00:03<00:00, 32.68it/s]
  0%|          | 2/600 [00:00<00:35, 16.68it/s]

epoch 79 loss 0.12200935051776468 prediction 0.9875000065565109
I_ZX_bound 13.123225955963134 I_ZY_bound -0.10888612368435133


100%|██████████| 600/600 [00:19<00:00, 30.30it/s]
100%|██████████| 100/100 [00:02<00:00, 34.69it/s]
  1%|          | 4/600 [00:00<00:16, 35.66it/s]

epoch 80 loss 0.12173838037997484 prediction 0.9877000075578689
I_ZX_bound 12.576407842636108 I_ZY_bound -0.10916197226149961


100%|██████████| 600/600 [00:19<00:00, 30.18it/s]
100%|██████████| 100/100 [00:02<00:00, 38.75it/s]
  1%|          | 4/600 [00:00<00:15, 38.63it/s]

epoch 81 loss 0.12556245320476592 prediction 0.9866000062227249
I_ZX_bound 13.073527765274047 I_ZY_bound -0.11248892539879307


100%|██████████| 600/600 [00:19<00:00, 31.50it/s]
100%|██████████| 100/100 [00:02<00:00, 43.71it/s]
  1%|          | 5/600 [00:00<00:14, 41.11it/s]

epoch 82 loss 0.12452108846977353 prediction 0.9872000074386597
I_ZX_bound 13.036622171401978 I_ZY_bound -0.11148446584993507


100%|██████████| 600/600 [00:16<00:00, 37.06it/s]
100%|██████████| 100/100 [00:02<00:00, 42.46it/s]
  1%|          | 4/600 [00:00<00:16, 36.15it/s]

epoch 83 loss 0.12169769402593374 prediction 0.9872000074386597
I_ZX_bound 12.880034322738647 I_ZY_bound -0.1088176590151852


100%|██████████| 600/600 [00:16<00:00, 35.75it/s]
100%|██████████| 100/100 [00:02<00:00, 42.72it/s]
  1%|          | 5/600 [00:00<00:14, 40.77it/s]

epoch 84 loss 0.12492094788700342 prediction 0.9862000066041946
I_ZX_bound 12.863745927810669 I_ZY_bound -0.11205720194615423


100%|██████████| 600/600 [00:16<00:00, 36.66it/s]
100%|██████████| 100/100 [00:02<00:00, 47.49it/s]
  1%|          | 5/600 [00:00<00:14, 42.37it/s]

epoch 85 loss 0.1250776065327227 prediction 0.9870000070333481
I_ZX_bound 12.721849374771118 I_ZY_bound -0.11235575629165397


100%|██████████| 600/600 [00:15<00:00, 37.51it/s]
100%|██████████| 100/100 [00:02<00:00, 42.57it/s]
  1%|          | 4/600 [00:00<00:15, 37.58it/s]

epoch 86 loss 0.126512562148273 prediction 0.9866000074148178
I_ZX_bound 12.729757432937623 I_ZY_bound -0.11378280441800598


100%|██████████| 600/600 [00:16<00:00, 36.60it/s]
100%|██████████| 100/100 [00:02<00:00, 42.40it/s]
  1%|          | 4/600 [00:00<00:14, 39.94it/s]

epoch 87 loss 0.12228766970336437 prediction 0.9871000069379806
I_ZX_bound 12.74198049545288 I_ZY_bound -0.1095456884300802


100%|██████████| 600/600 [00:16<00:00, 35.98it/s]
100%|██████████| 100/100 [00:02<00:00, 34.08it/s]
  0%|          | 3/600 [00:00<00:21, 27.65it/s]

epoch 88 loss 0.12873406000435353 prediction 0.9871000051498413
I_ZX_bound 13.017879133224488 I_ZY_bound -0.1157161794055719


100%|██████████| 600/600 [00:22<00:00, 26.12it/s]
100%|██████████| 100/100 [00:03<00:00, 27.34it/s]
  0%|          | 3/600 [00:00<00:20, 28.69it/s]

epoch 89 loss 0.12657871479168534 prediction 0.9873000079393387
I_ZX_bound 12.827550592422485 I_ZY_bound -0.11375116332841571


100%|██████████| 600/600 [00:24<00:00, 24.38it/s]
100%|██████████| 100/100 [00:03<00:00, 27.82it/s]
  0%|          | 2/600 [00:00<00:35, 16.70it/s]

epoch 90 loss 0.1272166211437434 prediction 0.9872000080347061
I_ZX_bound 12.768160705566407 I_ZY_bound -0.1144484598300187


100%|██████████| 600/600 [00:24<00:00, 24.02it/s]
100%|██████████| 100/100 [00:03<00:00, 30.44it/s]
  0%|          | 3/600 [00:00<00:25, 23.23it/s]

epoch 91 loss 0.12632377865724265 prediction 0.9870000064373017
I_ZX_bound 12.589463024139404 I_ZY_bound -0.1137343148101354


100%|██████████| 600/600 [00:25<00:00, 23.09it/s]
100%|██████████| 100/100 [00:03<00:00, 28.13it/s]
  0%|          | 2/600 [00:00<00:32, 18.17it/s]

epoch 92 loss 0.12575174106284975 prediction 0.9868000072240829
I_ZX_bound 12.950599021911621 I_ZY_bound -0.11280114190711174


100%|██████████| 600/600 [00:29<00:00, 20.37it/s]
100%|██████████| 100/100 [00:02<00:00, 35.66it/s]
  1%|          | 4/600 [00:00<00:19, 30.19it/s]

epoch 93 loss 0.12866748785600066 prediction 0.9875000071525574
I_ZX_bound 12.660176076889037 I_ZY_bound -0.11600731072365306


100%|██████████| 600/600 [00:20<00:00, 29.74it/s]
100%|██████████| 100/100 [00:02<00:00, 35.77it/s]
  1%|          | 4/600 [00:00<00:19, 31.34it/s]

epoch 94 loss 0.12618200016207992 prediction 0.9872000068426132
I_ZX_bound 12.692986793518067 I_ZY_bound -0.11348901247023604


100%|██████████| 600/600 [00:20<00:00, 29.37it/s]
100%|██████████| 100/100 [00:02<00:00, 38.60it/s]
  0%|          | 3/600 [00:00<00:21, 28.02it/s]

epoch 95 loss 0.12342556579038501 prediction 0.9873000061511994
I_ZX_bound 12.776312294006347 I_ZY_bound -0.11064925259153824


100%|██████████| 600/600 [00:21<00:00, 28.50it/s]
100%|██████████| 100/100 [00:02<00:00, 34.78it/s]
  0%|          | 3/600 [00:00<00:21, 28.40it/s]

epoch 96 loss 0.12337274231947959 prediction 0.9875000059604645
I_ZX_bound 12.691501293182373 I_ZY_bound -0.11068124018609524


100%|██████████| 600/600 [00:20<00:00, 29.57it/s]
100%|██████████| 100/100 [00:02<00:00, 37.34it/s]
  1%|          | 4/600 [00:00<00:18, 32.32it/s]

epoch 97 loss 0.12723039242438972 prediction 0.9870000058412551
I_ZX_bound 12.71415428161621 I_ZY_bound -0.11451623709406704


100%|██████████| 600/600 [00:20<00:00, 29.53it/s]
100%|██████████| 100/100 [00:02<00:00, 38.95it/s]
  0%|          | 3/600 [00:00<00:21, 27.90it/s]

epoch 98 loss 0.12594788867980242 prediction 0.987000008225441
I_ZX_bound 12.833197135925293 I_ZY_bound -0.11311469085747376


100%|██████████| 600/600 [00:20<00:00, 29.08it/s]
100%|██████████| 100/100 [00:02<00:00, 35.50it/s]
  0%|          | 3/600 [00:00<00:21, 28.21it/s]

epoch 99 loss 0.12482848264276981 prediction 0.9871000069379806
I_ZX_bound 12.498340578079224 I_ZY_bound -0.11233014142955654


100%|██████████| 600/600 [00:20<00:00, 29.14it/s]
100%|██████████| 100/100 [00:02<00:00, 34.94it/s]
  1%|          | 4/600 [00:00<00:19, 30.16it/s]

epoch 100 loss 0.1258250700496137 prediction 0.9876000064611435
I_ZX_bound 12.94005319595337 I_ZY_bound -0.11288501606788487


100%|██████████| 600/600 [00:21<00:00, 27.92it/s]
100%|██████████| 100/100 [00:02<00:00, 36.18it/s]
  1%|          | 4/600 [00:00<00:18, 31.95it/s]

epoch 101 loss 0.12683631662279368 prediction 0.9864000052213668
I_ZX_bound 12.488454370498657 I_ZY_bound -0.11434786112222355


100%|██████████| 600/600 [00:20<00:00, 29.63it/s]
100%|██████████| 100/100 [00:02<00:00, 34.92it/s]
  0%|          | 2/600 [00:00<00:30, 19.67it/s]

epoch 102 loss 0.12995397854596377 prediction 0.9868000072240829
I_ZX_bound 12.928567609786988 I_ZY_bound -0.11702540999103803


100%|██████████| 600/600 [00:23<00:00, 25.79it/s]
100%|██████████| 100/100 [00:03<00:00, 32.50it/s]
  0%|          | 3/600 [00:00<00:24, 24.11it/s]

epoch 103 loss 0.12828656078316272 prediction 0.9863000059127808
I_ZX_bound 12.640429744720459 I_ZY_bound -0.1156461304309778


100%|██████████| 600/600 [00:24<00:00, 24.59it/s]
100%|██████████| 100/100 [00:02<00:00, 34.30it/s]
  1%|          | 4/600 [00:00<00:15, 37.62it/s]

epoch 104 loss 0.12575430622324346 prediction 0.9871000069379806
I_ZX_bound 12.776403894424439 I_ZY_bound -0.11297790154116229


100%|██████████| 600/600 [00:19<00:00, 30.59it/s]
100%|██████████| 100/100 [00:03<00:00, 32.94it/s]
  1%|          | 4/600 [00:00<00:16, 37.11it/s]

epoch 105 loss 0.1302864584606141 prediction 0.9864000064134598
I_ZX_bound 12.764199352264404 I_ZY_bound -0.11752225845819339


100%|██████████| 600/600 [00:20<00:00, 29.47it/s]
100%|██████████| 100/100 [00:02<00:00, 34.56it/s]
  1%|          | 4/600 [00:00<00:16, 36.87it/s]

epoch 106 loss 0.1281779138557613 prediction 0.9865000069141387
I_ZX_bound 12.92426033973694 I_ZY_bound -0.11525365299487021


100%|██████████| 600/600 [00:20<00:00, 28.62it/s]
100%|██████████| 100/100 [00:03<00:00, 32.89it/s]
  1%|          | 4/600 [00:00<00:16, 36.45it/s]

epoch 107 loss 0.12538210103288294 prediction 0.9867000073194504
I_ZX_bound 12.688880186080933 I_ZY_bound -0.11269322021573316


100%|██████████| 600/600 [00:22<00:00, 26.72it/s]
100%|██████████| 100/100 [00:03<00:00, 28.14it/s]
  1%|          | 4/600 [00:00<00:20, 28.98it/s]

epoch 108 loss 0.1302713547088206 prediction 0.9870000058412551
I_ZX_bound 12.725225076675414 I_ZY_bound -0.11754612883727532


100%|██████████| 600/600 [00:16<00:00, 36.61it/s]
100%|██████████| 100/100 [00:02<00:00, 39.28it/s]
  0%|          | 3/600 [00:00<00:20, 29.47it/s]

epoch 109 loss 0.13136003512889147 prediction 0.9869000071287155
I_ZX_bound 12.388127975463867 I_ZY_bound -0.11897190602961928


100%|██████████| 600/600 [00:16<00:00, 36.84it/s]
100%|██████████| 100/100 [00:02<00:00, 36.63it/s]
  1%|          | 4/600 [00:00<00:20, 29.52it/s]

epoch 110 loss 0.1285318631771952 prediction 0.9863000065088272
I_ZX_bound 12.612406549453736 I_ZY_bound -0.11591945574909915


100%|██████████| 600/600 [00:21<00:00, 28.50it/s]
100%|██████████| 100/100 [00:02<00:00, 37.37it/s]
  1%|          | 4/600 [00:00<00:18, 31.59it/s]

epoch 111 loss 0.12799977350980044 prediction 0.9871000057458877
I_ZX_bound 12.672404079437255 I_ZY_bound -0.11532736871624366


100%|██████████| 600/600 [00:20<00:00, 29.14it/s]
100%|██████████| 100/100 [00:02<00:00, 36.74it/s]
  0%|          | 3/600 [00:00<00:25, 23.73it/s]

epoch 112 loss 0.12649156274273993 prediction 0.9873000055551528
I_ZX_bound 12.74425539970398 I_ZY_bound -0.1137473063141806


100%|██████████| 600/600 [00:20<00:00, 29.06it/s]
100%|██████████| 100/100 [00:02<00:00, 35.57it/s]
  1%|          | 4/600 [00:00<00:18, 33.10it/s]

epoch 113 loss 0.13093089199624955 prediction 0.986700006723404
I_ZX_bound 12.581930665969848 I_ZY_bound -0.11834896091837437


100%|██████████| 600/600 [00:20<00:00, 28.66it/s]
100%|██████████| 100/100 [00:02<00:00, 37.64it/s]
  0%|          | 3/600 [00:00<00:20, 29.50it/s]

epoch 114 loss 0.12482846767641603 prediction 0.9871000063419342
I_ZX_bound 12.452203903198242 I_ZY_bound -0.11237626336631365


100%|██████████| 600/600 [00:20<00:00, 28.76it/s]
100%|██████████| 100/100 [00:02<00:00, 34.26it/s]
  1%|          | 4/600 [00:00<00:18, 32.93it/s]

epoch 115 loss 0.1297423273138702 prediction 0.9866000062227249
I_ZX_bound 12.721144304275512 I_ZY_bound -0.11702118246466853


100%|██████████| 600/600 [00:20<00:00, 28.64it/s]
100%|██████████| 100/100 [00:02<00:00, 36.37it/s]
  0%|          | 3/600 [00:00<00:22, 26.35it/s]

epoch 116 loss 0.13038179045543075 prediction 0.9870000058412551
I_ZX_bound 12.654858446121215 I_ZY_bound -0.1177269313216675


100%|██████████| 600/600 [00:24<00:00, 25.00it/s]
100%|██████████| 100/100 [00:03<00:00, 26.73it/s]
  0%|          | 3/600 [00:00<00:23, 25.46it/s]

epoch 117 loss 0.1274623217340559 prediction 0.9868000072240829
I_ZX_bound 12.571797409057616 I_ZY_bound -0.11489052313438151


100%|██████████| 600/600 [00:24<00:00, 24.72it/s]
100%|██████████| 100/100 [00:03<00:00, 29.10it/s]
  0%|          | 2/600 [00:00<00:33, 18.01it/s]

epoch 118 loss 0.12620663759298623 prediction 0.9873000073432923
I_ZX_bound 12.56316328048706 I_ZY_bound -0.11364347295835614


100%|██████████| 600/600 [00:25<00:00, 23.47it/s]
100%|██████████| 100/100 [00:03<00:00, 29.19it/s]
  0%|          | 3/600 [00:00<00:25, 23.82it/s]

epoch 119 loss 0.12576770934276282 prediction 0.9873000067472458
I_ZX_bound 12.518410272598267 I_ZY_bound -0.11324929802853148


100%|██████████| 600/600 [00:23<00:00, 25.40it/s]
100%|██████████| 100/100 [00:03<00:00, 28.42it/s]
  0%|          | 3/600 [00:00<00:22, 26.64it/s]

epoch 120 loss 0.13047368448227645 prediction 0.987100007534027
I_ZX_bound 12.669709758758545 I_ZY_bound -0.11780397351481951


100%|██████████| 600/600 [00:26<00:00, 23.03it/s]
100%|██████████| 100/100 [00:03<00:00, 25.14it/s]
  0%|          | 3/600 [00:00<00:24, 24.40it/s]

epoch 121 loss 0.12709210662171244 prediction 0.9870000064373017
I_ZX_bound 12.580193290710449 I_ZY_bound -0.11451191351923626


100%|██████████| 600/600 [00:23<00:00, 25.80it/s]
100%|██████████| 100/100 [00:02<00:00, 33.61it/s]
  0%|          | 3/600 [00:00<00:21, 27.45it/s]

epoch 122 loss 0.12963737009093165 prediction 0.9875000065565109
I_ZX_bound 12.82365704536438 I_ZY_bound -0.11681371187209151


100%|██████████| 600/600 [00:21<00:00, 28.35it/s]
100%|██████████| 100/100 [00:02<00:00, 38.24it/s]
  0%|          | 3/600 [00:00<00:20, 29.38it/s]

epoch 123 loss 0.12894830912351607 prediction 0.9877000069618225
I_ZX_bound 12.629344997406006 I_ZY_bound -0.1163189634931041


100%|██████████| 600/600 [00:21<00:00, 28.30it/s]
100%|██████████| 100/100 [00:02<00:00, 35.14it/s]
  0%|          | 3/600 [00:00<00:20, 29.06it/s]

epoch 124 loss 0.1275044786091894 prediction 0.9879000067710877
I_ZX_bound 12.440431070327758 I_ZY_bound -0.11506404654472135


100%|██████████| 600/600 [00:20<00:00, 28.68it/s]
100%|██████████| 100/100 [00:02<00:00, 36.17it/s]
  0%|          | 3/600 [00:00<00:21, 27.81it/s]

epoch 125 loss 0.12971506169065833 prediction 0.9873000073432923
I_ZX_bound 12.813462247848511 I_ZY_bound -0.11690159918391146


100%|██████████| 600/600 [00:20<00:00, 28.73it/s]
100%|██████████| 100/100 [00:02<00:00, 41.84it/s]
  1%|          | 4/600 [00:00<00:15, 37.77it/s]

epoch 126 loss 0.1281404308695346 prediction 0.9868000066280365
I_ZX_bound 12.583525657653809 I_ZY_bound -0.11555690398381557


100%|██████████| 600/600 [00:16<00:00, 37.21it/s]
100%|██████████| 100/100 [00:02<00:00, 41.19it/s]
  1%|          | 4/600 [00:00<00:16, 36.67it/s]

epoch 127 loss 0.1262954639457166 prediction 0.9871000069379806
I_ZX_bound 12.629524555206299 I_ZY_bound -0.11366593890124932


100%|██████████| 600/600 [00:16<00:00, 35.48it/s]
100%|██████████| 100/100 [00:02<00:00, 41.48it/s]
  1%|          | 4/600 [00:00<00:15, 37.97it/s]

epoch 128 loss 0.12981270999647676 prediction 0.9869000059366226
I_ZX_bound 12.625120534896851 I_ZY_bound -0.11718758902454283


100%|██████████| 600/600 [00:16<00:00, 35.32it/s]
100%|██████████| 100/100 [00:02<00:00, 44.33it/s]
  1%|          | 4/600 [00:00<00:17, 33.36it/s]

epoch 129 loss 0.1290419678017497 prediction 0.9865000069141387
I_ZX_bound 12.600843420028687 I_ZY_bound -0.11644112395413686


100%|██████████| 600/600 [00:16<00:00, 36.57it/s]
100%|██████████| 100/100 [00:02<00:00, 43.61it/s]
  0%|          | 3/600 [00:00<00:21, 28.38it/s]

epoch 130 loss 0.1272141896188259 prediction 0.9875000065565109
I_ZX_bound 12.569514780044555 I_ZY_bound -0.1146446746296715


100%|██████████| 600/600 [00:16<00:00, 35.62it/s]
100%|██████████| 100/100 [00:02<00:00, 44.27it/s]
  1%|          | 4/600 [00:00<00:16, 36.65it/s]

epoch 131 loss 0.12770430533215404 prediction 0.9877000057697296
I_ZX_bound 12.435105209350587 I_ZY_bound -0.1152691999002127


100%|██████████| 600/600 [00:21<00:00, 27.94it/s]
100%|██████████| 100/100 [00:03<00:00, 30.52it/s]
  1%|          | 4/600 [00:00<00:18, 31.37it/s]

epoch 132 loss 0.1294668872654438 prediction 0.98760000705719
I_ZX_bound 12.5672159576416 I_ZY_bound -0.11689967124722898


100%|██████████| 600/600 [00:18<00:00, 31.64it/s]
100%|██████████| 100/100 [00:03<00:00, 31.74it/s]
  1%|          | 4/600 [00:00<00:18, 32.36it/s]

epoch 133 loss 0.12771357340738176 prediction 0.9870000076293945
I_ZX_bound 12.435825510025024 I_ZY_bound -0.11527774772956036


100%|██████████| 600/600 [00:20<00:00, 29.38it/s]
100%|██████████| 100/100 [00:02<00:00, 38.74it/s]
  1%|          | 4/600 [00:00<00:22, 26.49it/s]

epoch 134 loss 0.12863029344938695 prediction 0.9864000064134598
I_ZX_bound 12.409778289794922 I_ZY_bound -0.11622051461134106


100%|██████████| 600/600 [00:26<00:00, 22.97it/s]
100%|██████████| 100/100 [00:03<00:00, 28.91it/s]
  0%|          | 3/600 [00:00<00:22, 26.52it/s]

epoch 135 loss 0.1287550926115364 prediction 0.9871000051498413
I_ZX_bound 12.52868462562561 I_ZY_bound -0.11622640719288028


100%|██████████| 600/600 [00:25<00:00, 23.76it/s]
100%|██████████| 100/100 [00:03<00:00, 28.14it/s]
  0%|          | 3/600 [00:00<00:22, 26.01it/s]

epoch 136 loss 0.12940215296112 prediction 0.986700006723404
I_ZX_bound 12.353473548889161 I_ZY_bound -0.11704867934458889


100%|██████████| 600/600 [00:23<00:00, 25.43it/s]
100%|██████████| 100/100 [00:02<00:00, 37.63it/s]
  0%|          | 3/600 [00:00<00:24, 24.44it/s]

epoch 137 loss 0.12969634173437952 prediction 0.986700006723404
I_ZX_bound 12.30874680519104 I_ZY_bound -0.11738759361382108


100%|██████████| 600/600 [00:20<00:00, 29.41it/s]
100%|██████████| 100/100 [00:02<00:00, 37.12it/s]
  0%|          | 3/600 [00:00<00:24, 24.33it/s]

epoch 138 loss 0.1309008294157684 prediction 0.9870000070333481
I_ZX_bound 12.506996183395385 I_ZY_bound -0.11839383241836914


100%|██████████| 600/600 [00:20<00:00, 28.93it/s]
100%|██████████| 100/100 [00:02<00:00, 35.77it/s]
  0%|          | 3/600 [00:00<00:23, 25.21it/s]

epoch 139 loss 0.13078089312650262 prediction 0.9866000062227249
I_ZX_bound 12.61844225883484 I_ZY_bound -0.11816245034628083


100%|██████████| 600/600 [00:20<00:00, 28.85it/s]
100%|██████████| 100/100 [00:02<00:00, 38.87it/s]
  1%|          | 4/600 [00:00<00:19, 30.53it/s]

epoch 140 loss 0.1297046753950417 prediction 0.9870000064373017
I_ZX_bound 12.428430557250977 I_ZY_bound -0.11727624435559847


100%|██████████| 600/600 [00:20<00:00, 29.90it/s]
100%|██████████| 100/100 [00:02<00:00, 37.80it/s]
  1%|          | 4/600 [00:00<00:18, 32.26it/s]

epoch 141 loss 0.13034743991680442 prediction 0.9868000054359436
I_ZX_bound 12.602331895828247 I_ZY_bound -0.11774510788731277


100%|██████████| 600/600 [00:20<00:00, 28.87it/s]
100%|██████████| 100/100 [00:02<00:00, 36.35it/s]
  1%|          | 4/600 [00:00<00:19, 31.36it/s]

epoch 142 loss 0.13132822851650416 prediction 0.9869000077247619
I_ZX_bound 12.539250917434693 I_ZY_bound -0.1187889767769957


100%|██████████| 600/600 [00:22<00:00, 26.90it/s]
100%|██████████| 100/100 [00:02<00:00, 36.67it/s]
  0%|          | 3/600 [00:00<00:21, 27.63it/s]

epoch 143 loss 0.13015869591385126 prediction 0.9869000077247619
I_ZX_bound 12.450511856079101 I_ZY_bound -0.11770818341174163


100%|██████████| 600/600 [00:20<00:00, 29.62it/s]
100%|██████████| 100/100 [00:02<00:00, 36.13it/s]
  1%|          | 4/600 [00:00<00:18, 32.21it/s]

epoch 144 loss 0.12999429383315145 prediction 0.9869000077247619
I_ZX_bound 12.475569267272949 I_ZY_bound -0.1175187237560749


100%|██████████| 600/600 [00:20<00:00, 28.99it/s]
100%|██████████| 100/100 [00:02<00:00, 37.70it/s]
  0%|          | 3/600 [00:00<00:20, 28.69it/s]

epoch 145 loss 0.13072982958517967 prediction 0.9868000072240829
I_ZX_bound 12.417693090438842 I_ZY_bound -0.11831213629047851


100%|██████████| 600/600 [00:20<00:00, 29.15it/s]
100%|██████████| 100/100 [00:02<00:00, 37.83it/s]
  1%|          | 4/600 [00:00<00:19, 29.91it/s]

epoch 146 loss 0.12938518566079438 prediction 0.9870000076293945
I_ZX_bound 12.479768924713134 I_ZY_bound -0.11690541623509489


100%|██████████| 600/600 [00:21<00:00, 27.83it/s]
100%|██████████| 100/100 [00:03<00:00, 28.02it/s]
  0%|          | 3/600 [00:00<00:31, 19.22it/s]

epoch 147 loss 0.13139229832217098 prediction 0.9865000069141387
I_ZX_bound 12.537410430908203 I_ZY_bound -0.11885488716012332


100%|██████████| 600/600 [00:25<00:00, 23.57it/s]
100%|██████████| 100/100 [00:03<00:00, 29.74it/s]
  1%|          | 4/600 [00:00<00:19, 30.90it/s]

epoch 148 loss 0.13187381255440414 prediction 0.9871000057458877
I_ZX_bound 12.598097295761109 I_ZY_bound -0.1192757147719385


100%|██████████| 600/600 [00:25<00:00, 23.64it/s]
100%|██████████| 100/100 [00:03<00:00, 25.89it/s]
  0%|          | 2/600 [00:00<00:31, 18.83it/s]

epoch 149 loss 0.1320219702925533 prediction 0.9868000072240829
I_ZX_bound 12.497786464691162 I_ZY_bound -0.1195241832640022


100%|██████████| 600/600 [00:24<00:00, 24.71it/s]
100%|██████████| 100/100 [00:03<00:00, 32.19it/s]
  1%|          | 4/600 [00:00<00:19, 30.44it/s]

epoch 150 loss 0.12986511036753653 prediction 0.9865000069141387
I_ZX_bound 12.30040246963501 I_ZY_bound -0.11756470704916865


100%|██████████| 600/600 [00:19<00:00, 30.18it/s]
100%|██████████| 100/100 [00:02<00:00, 36.23it/s]
  1%|          | 5/600 [00:00<00:14, 41.19it/s]

epoch 151 loss 0.12938205040059983 prediction 0.9869000065326691
I_ZX_bound 12.577552089691162 I_ZY_bound -0.11680449786770623


100%|██████████| 600/600 [00:19<00:00, 30.87it/s]
100%|██████████| 100/100 [00:03<00:00, 33.03it/s]
  0%|          | 3/600 [00:00<00:20, 29.25it/s]

epoch 152 loss 0.13288932479918003 prediction 0.9868000072240829
I_ZX_bound 12.457277126312256 I_ZY_bound -0.1204320473386906


100%|██████████| 600/600 [00:17<00:00, 33.86it/s]
100%|██████████| 100/100 [00:02<00:00, 44.04it/s]
  1%|          | 5/600 [00:00<00:14, 41.30it/s]

epoch 153 loss 0.13056164705194534 prediction 0.9868000054359436
I_ZX_bound 12.574729442596436 I_ZY_bound -0.11798691648436943


100%|██████████| 600/600 [00:16<00:00, 37.28it/s]
100%|██████████| 100/100 [00:02<00:00, 45.30it/s]
  1%|          | 4/600 [00:00<00:15, 38.45it/s]

epoch 154 loss 0.12881723696365951 prediction 0.9871000069379806
I_ZX_bound 12.394339485168457 I_ZY_bound -0.11642289593757596


100%|██████████| 600/600 [00:15<00:00, 37.61it/s]
100%|██████████| 100/100 [00:02<00:00, 42.09it/s]
  0%|          | 3/600 [00:00<00:22, 26.92it/s]

epoch 155 loss 0.12954459632746876 prediction 0.9873000067472458
I_ZX_bound 12.349160785675048 I_ZY_bound -0.11719543464714662


100%|██████████| 600/600 [00:17<00:00, 34.65it/s]
100%|██████████| 100/100 [00:02<00:00, 43.74it/s]
  1%|          | 4/600 [00:00<00:15, 39.21it/s]

epoch 156 loss 0.1311364653520286 prediction 0.9873000067472458
I_ZX_bound 12.413202095031739 I_ZY_bound -0.11872326208162122


100%|██████████| 600/600 [00:16<00:00, 36.95it/s]
100%|██████████| 100/100 [00:02<00:00, 40.74it/s]
  1%|          | 4/600 [00:00<00:15, 37.58it/s]

epoch 157 loss 0.13015943904407323 prediction 0.9866000068187714
I_ZX_bound 12.510103845596314 I_ZY_bound -0.11764933402446331


100%|██████████| 600/600 [00:17<00:00, 33.60it/s]
100%|██████████| 100/100 [00:02<00:00, 37.42it/s]
  1%|          | 4/600 [00:00<00:16, 35.53it/s]

epoch 158 loss 0.13087901490740478 prediction 0.9870000070333481
I_ZX_bound 12.562452421188354 I_ZY_bound -0.11831656206864864


100%|██████████| 600/600 [00:17<00:00, 34.71it/s]
100%|██████████| 100/100 [00:02<00:00, 40.07it/s]
  1%|          | 4/600 [00:00<00:15, 37.62it/s]

epoch 159 loss 0.13008984204381704 prediction 0.9875000065565109
I_ZX_bound 12.456722497940063 I_ZY_bound -0.11763311870279722


100%|██████████| 600/600 [00:17<00:00, 34.20it/s]
100%|██████████| 100/100 [00:02<00:00, 35.27it/s]
  0%|          | 3/600 [00:00<00:21, 28.11it/s]

epoch 160 loss 0.13180014761164785 prediction 0.9875000071525574
I_ZX_bound 12.568140335083008 I_ZY_bound -0.11923200711316895


100%|██████████| 600/600 [00:20<00:00, 28.65it/s]
100%|██████████| 100/100 [00:02<00:00, 35.82it/s]
  1%|          | 4/600 [00:00<00:17, 34.35it/s]

epoch 161 loss 0.1322067444678396 prediction 0.9870000064373017
I_ZX_bound 12.39418743133545 I_ZY_bound -0.11981255640625023


100%|██████████| 600/600 [00:21<00:00, 27.43it/s]
100%|██████████| 100/100 [00:03<00:00, 25.01it/s]
  0%|          | 2/600 [00:00<00:31, 18.81it/s]

epoch 162 loss 0.13130503050051628 prediction 0.9870000070333481
I_ZX_bound 12.444798603057862 I_ZY_bound -0.11886023125261999


100%|██████████| 600/600 [00:25<00:00, 23.73it/s]
100%|██████████| 100/100 [00:03<00:00, 26.90it/s]
  0%|          | 2/600 [00:00<00:40, 14.74it/s]

epoch 163 loss 0.13238344434648752 prediction 0.9874000072479248
I_ZX_bound 12.328780441284179 I_ZY_bound -0.12005466327304021


100%|██████████| 600/600 [00:24<00:00, 24.70it/s]
100%|██████████| 100/100 [00:03<00:00, 27.88it/s]
  0%|          | 3/600 [00:00<00:24, 24.64it/s]

epoch 164 loss 0.13273409653455018 prediction 0.9872000062465668
I_ZX_bound 12.4380837059021 I_ZY_bound -0.12029601092392113


100%|██████████| 600/600 [00:24<00:00, 24.84it/s]
100%|██████████| 100/100 [00:03<00:00, 32.69it/s]
  0%|          | 3/600 [00:00<00:24, 23.91it/s]

epoch 165 loss 0.1319392148591578 prediction 0.9871000063419342
I_ZX_bound 12.558723258972169 I_ZY_bound -0.11938049046089873


100%|██████████| 600/600 [00:25<00:00, 23.31it/s]
100%|██████████| 100/100 [00:03<00:00, 28.94it/s]
  0%|          | 3/600 [00:00<00:24, 24.62it/s]

epoch 166 loss 0.1343221411202103 prediction 0.9865000069141387
I_ZX_bound 12.412073011398315 I_ZY_bound -0.12191006759006996


100%|██████████| 600/600 [00:25<00:00, 23.62it/s]
100%|██████████| 100/100 [00:03<00:00, 30.23it/s]
  1%|          | 4/600 [00:00<00:17, 33.89it/s]

epoch 167 loss 0.13208948818035424 prediction 0.986700006723404
I_ZX_bound 12.524575920104981 I_ZY_bound -0.11956491093267686


100%|██████████| 600/600 [00:21<00:00, 28.20it/s]
100%|██████████| 100/100 [00:02<00:00, 35.62it/s]
  1%|          | 5/600 [00:00<00:14, 40.45it/s]

epoch 168 loss 0.1306948149483651 prediction 0.9872000068426132
I_ZX_bound 12.291062755584717 I_ZY_bound -0.11840375224826857


100%|██████████| 600/600 [00:20<00:00, 29.59it/s]
100%|██████████| 100/100 [00:02<00:00, 37.20it/s]
  1%|          | 4/600 [00:00<00:17, 34.05it/s]

epoch 169 loss 0.1324496597610414 prediction 0.9872000068426132
I_ZX_bound 12.362273092269897 I_ZY_bound -0.12008738618576899


100%|██████████| 600/600 [00:20<00:00, 28.60it/s]
100%|██████████| 100/100 [00:02<00:00, 34.73it/s]
  0%|          | 3/600 [00:00<00:21, 27.52it/s]

epoch 170 loss 0.13008811886422336 prediction 0.9880000072717666
I_ZX_bound 12.420158033370972 I_ZY_bound -0.11766795988485683


100%|██████████| 600/600 [00:20<00:00, 29.43it/s]
100%|██████████| 100/100 [00:02<00:00, 35.65it/s]
  0%|          | 3/600 [00:00<00:22, 26.68it/s]

epoch 171 loss 0.13201219560578464 prediction 0.9871000057458877
I_ZX_bound 12.328085212707519 I_ZY_bound -0.11968411037232726


100%|██████████| 600/600 [00:20<00:00, 28.97it/s]
100%|██████████| 100/100 [00:02<00:00, 39.04it/s]
  0%|          | 3/600 [00:00<00:22, 26.63it/s]

epoch 172 loss 0.1314890881255269 prediction 0.9867000061273575
I_ZX_bound 12.326765565872192 I_ZY_bound -0.11916232264309656


100%|██████████| 600/600 [00:20<00:00, 29.32it/s]
100%|██████████| 100/100 [00:02<00:00, 37.33it/s]
  0%|          | 3/600 [00:00<00:20, 29.70it/s]

epoch 173 loss 0.13253592316061258 prediction 0.9870000070333481
I_ZX_bound 12.43944727897644 I_ZY_bound -0.12009647543658503


100%|██████████| 600/600 [00:20<00:00, 28.62it/s]
100%|██████████| 100/100 [00:02<00:00, 34.36it/s]
  0%|          | 3/600 [00:00<00:24, 24.61it/s]

epoch 174 loss 0.1320995928440243 prediction 0.9873000073432923
I_ZX_bound 12.397684516906738 I_ZY_bound -0.11970190719584935


100%|██████████| 600/600 [00:20<00:00, 28.60it/s]
100%|██████████| 100/100 [00:02<00:00, 37.75it/s]
  0%|          | 3/600 [00:00<00:20, 28.81it/s]

epoch 175 loss 0.13249867300502957 prediction 0.9870000070333481
I_ZX_bound 12.482021007537842 I_ZY_bound -0.12001665107381995


100%|██████████| 600/600 [00:21<00:00, 28.48it/s]
100%|██████████| 100/100 [00:02<00:00, 36.21it/s]
  0%|          | 3/600 [00:00<00:19, 29.90it/s]

epoch 176 loss 0.1351298137754202 prediction 0.9866000062227249
I_ZX_bound 12.433639450073242 I_ZY_bound -0.12269617264450063


100%|██████████| 600/600 [00:19<00:00, 30.70it/s]
100%|██████████| 100/100 [00:03<00:00, 31.90it/s]
  1%|          | 4/600 [00:00<00:17, 33.97it/s]

epoch 177 loss 0.13278084944933652 prediction 0.9875000071525574
I_ZX_bound 12.394059619903565 I_ZY_bound -0.12038678842945956


100%|██████████| 600/600 [00:21<00:00, 28.19it/s]
100%|██████████| 100/100 [00:02<00:00, 34.69it/s]
  1%|          | 4/600 [00:00<00:17, 33.75it/s]

epoch 178 loss 0.1335048918426037 prediction 0.9869000065326691
I_ZX_bound 12.388135557174683 I_ZY_bound -0.12111675519612618


100%|██████████| 600/600 [00:20<00:00, 29.70it/s]
100%|██████████| 100/100 [00:02<00:00, 41.76it/s]
  1%|          | 4/600 [00:00<00:17, 33.67it/s]

epoch 179 loss 0.1335658698156476 prediction 0.9874000072479248
I_ZX_bound 12.414421701431275 I_ZY_bound -0.12115144692652394


100%|██████████| 600/600 [00:18<00:00, 32.28it/s]
100%|██████████| 100/100 [00:02<00:00, 33.91it/s]
  0%|          | 3/600 [00:00<00:21, 27.61it/s]

epoch 180 loss 0.1330539914406836 prediction 0.9869000071287155
I_ZX_bound 12.388353700637817 I_ZY_bound -0.12066563731408678


100%|██████████| 600/600 [00:21<00:00, 28.26it/s]
100%|██████████| 100/100 [00:03<00:00, 31.37it/s]
  0%|          | 3/600 [00:00<00:22, 26.27it/s]

epoch 181 loss 0.13119870169088244 prediction 0.9864000076055527
I_ZX_bound 12.491575813293457 I_ZY_bound -0.11870712567644659


100%|██████████| 600/600 [00:24<00:00, 24.52it/s]
100%|██████████| 100/100 [00:03<00:00, 25.75it/s]
  0%|          | 3/600 [00:00<00:20, 29.20it/s]

epoch 182 loss 0.13343329614028335 prediction 0.9873000061511994
I_ZX_bound 12.426725254058837 I_ZY_bound -0.1210065698454855


100%|██████████| 600/600 [00:22<00:00, 26.79it/s]
100%|██████████| 100/100 [00:02<00:00, 37.10it/s]
  0%|          | 3/600 [00:00<00:22, 26.71it/s]

epoch 183 loss 0.13268560657277703 prediction 0.9871000069379806
I_ZX_bound 12.258878393173218 I_ZY_bound -0.1204267274425365


100%|██████████| 600/600 [00:20<00:00, 29.06it/s]
100%|██████████| 100/100 [00:02<00:00, 37.06it/s]
  0%|          | 3/600 [00:00<00:26, 22.57it/s]

epoch 184 loss 0.13339881935156883 prediction 0.9871000063419342
I_ZX_bound 12.269152927398682 I_ZY_bound -0.121129665261833


100%|██████████| 600/600 [00:21<00:00, 28.37it/s]
100%|██████████| 100/100 [00:02<00:00, 38.39it/s]
  0%|          | 3/600 [00:00<00:20, 29.45it/s]

epoch 185 loss 0.13456007158383726 prediction 0.9874000066518783
I_ZX_bound 12.353243141174316 I_ZY_bound -0.12220682796672917


100%|██████████| 600/600 [00:20<00:00, 29.51it/s]
100%|██████████| 100/100 [00:02<00:00, 35.84it/s]
  0%|          | 3/600 [00:00<00:20, 29.68it/s]

epoch 186 loss 0.13171088722534477 prediction 0.9866000074148178
I_ZX_bound 12.373409156799317 I_ZY_bound -0.11933747742848937


100%|██████████| 600/600 [00:20<00:00, 28.85it/s]
100%|██████████| 100/100 [00:02<00:00, 33.44it/s]
  0%|          | 3/600 [00:00<00:23, 25.89it/s]

epoch 187 loss 0.13408181123435498 prediction 0.9866000068187714
I_ZX_bound 12.406891345977783 I_ZY_bound -0.12167491908883676


100%|██████████| 600/600 [00:21<00:00, 27.78it/s]
100%|██████████| 100/100 [00:02<00:00, 33.41it/s]
  0%|          | 3/600 [00:00<00:21, 28.30it/s]

epoch 188 loss 0.13465762057341635 prediction 0.9870000064373017
I_ZX_bound 12.41818263053894 I_ZY_bound -0.12223943727090955


100%|██████████| 600/600 [00:19<00:00, 30.36it/s]
100%|██████████| 100/100 [00:02<00:00, 34.10it/s]
  1%|          | 4/600 [00:00<00:18, 32.22it/s]

epoch 189 loss 0.1307276938855648 prediction 0.98680000603199
I_ZX_bound 12.3459951877594 I_ZY_bound -0.11838169888942503


100%|██████████| 600/600 [00:20<00:00, 29.11it/s]
100%|██████████| 100/100 [00:02<00:00, 34.48it/s]
  0%|          | 3/600 [00:00<00:26, 22.18it/s]

epoch 190 loss 0.13099086970090867 prediction 0.9870000070333481
I_ZX_bound 12.375435991287231 I_ZY_bound -0.11861543348059059


100%|██████████| 600/600 [00:22<00:00, 26.52it/s]
100%|██████████| 100/100 [00:03<00:00, 30.75it/s]
  0%|          | 2/600 [00:00<00:36, 16.41it/s]

epoch 191 loss 0.13188214906491338 prediction 0.9873000073432923
I_ZX_bound 12.282616844177246 I_ZY_bound -0.11959953256417066


100%|██████████| 600/600 [00:23<00:00, 25.79it/s]
100%|██████████| 100/100 [00:04<00:00, 24.63it/s]
  0%|          | 2/600 [00:00<00:45, 13.25it/s]

epoch 192 loss 0.1332544678542763 prediction 0.9868000072240829
I_ZX_bound 12.3098432636261 I_ZY_bound -0.12094462414039299


100%|██████████| 600/600 [00:25<00:00, 23.96it/s]
100%|██████████| 100/100 [00:03<00:00, 29.56it/s]
  0%|          | 3/600 [00:00<00:23, 25.90it/s]

epoch 193 loss 0.13375953430309892 prediction 0.9869000065326691
I_ZX_bound 12.329252986907958 I_ZY_bound -0.12143028109916486


100%|██████████| 600/600 [00:24<00:00, 24.27it/s]
100%|██████████| 100/100 [00:03<00:00, 30.85it/s]
  1%|          | 4/600 [00:00<00:19, 30.73it/s]

epoch 194 loss 0.13389398456551135 prediction 0.9870000058412551
I_ZX_bound 12.326471071243287 I_ZY_bound -0.12156751267029904


100%|██████████| 600/600 [00:24<00:00, 24.23it/s]
100%|██████████| 100/100 [00:03<00:00, 29.03it/s]
  0%|          | 2/600 [00:00<00:41, 14.42it/s]

epoch 195 loss 0.13496678703464568 prediction 0.9867000073194504
I_ZX_bound 12.42399959564209 I_ZY_bound -0.12254278624022845


100%|██████████| 600/600 [00:24<00:00, 24.83it/s]
100%|██████████| 100/100 [00:03<00:00, 27.92it/s]
  0%|          | 2/600 [00:00<00:39, 15.28it/s]

epoch 196 loss 0.13485620527528228 prediction 0.9869000053405762
I_ZX_bound 12.436168403625489 I_ZY_bound -0.1224200362560805


100%|██████████| 600/600 [00:25<00:00, 23.41it/s]
100%|██████████| 100/100 [00:02<00:00, 39.20it/s]
  1%|          | 4/600 [00:00<00:19, 30.32it/s]

epoch 197 loss 0.13256442662328483 prediction 0.9873000061511994
I_ZX_bound 12.387525262832641 I_ZY_bound -0.1201769008784322


100%|██████████| 600/600 [00:22<00:00, 26.64it/s]
100%|██████████| 100/100 [00:02<00:00, 37.83it/s]
  0%|          | 2/600 [00:00<00:32, 18.15it/s]

epoch 198 loss 0.1305179003905505 prediction 0.9872000068426132
I_ZX_bound 12.354644947052002 I_ZY_bound -0.11816325525054708


100%|██████████| 600/600 [00:18<00:00, 32.65it/s]
100%|██████████| 100/100 [00:02<00:00, 42.78it/s]

epoch 199 loss 0.13265600335784256 prediction 0.9874000066518783
I_ZX_bound 12.436276140213012 I_ZY_bound -0.12021972775517498





In [38]:
model_avg = VIB(X_dim=784, y_dim=10, beta = beta, num_samples=samples_amount).cuda()

Use the averaged model:

In [50]:
for name, param in model_avg.named_parameters():
    if param.requires_grad:
        param.data = ema.shadow[name]

In [61]:
loss_by_epoch_test = []
accuracy_by_epoch_test = []
I_ZX_bound_by_epoch_test = []
I_ZY_bound_by_epoch_test = []


for x_batch, y_batch in tqdm.tqdm(test_loader):
    x_batch = x_batch.cuda()
    y_batch = y_batch.cuda()

    loss, I_ZY_bound, I_ZX_bound = model.batch_loss(samples_amount, x_batch, y_batch)

    logits = model.forward(x_batch)
    prediction = torch.max(logits, dim=1)[1]
    accuracy = torch.mean((prediction == y_batch).float())


    I_ZX_bound_by_epoch_test.append(I_ZX_bound.item())
    I_ZY_bound_by_epoch_test.append(I_ZY_bound.item())

    loss_by_epoch_test.append(loss.item())
    accuracy_by_epoch_test.append(accuracy.item())
    
print('loss', np.mean(loss_by_epoch_test), 
      'prediction', np.mean(accuracy_by_epoch_test))

print('I_ZX_bound', np.mean(I_ZX_bound_by_epoch_test), 
      'I_ZY_bound', np.mean(I_ZY_bound_by_epoch_test))

100%|██████████| 100/100 [00:02<00:00, 46.04it/s]

loss 0.13123689855448903 prediction 0.9873000073432923
I_ZX_bound 12.436276140213012 I_ZY_bound -0.11880062255891971



