In [1]:
#!nvidia-smi
#using a GeForce GTX1080 Ti for reproducibility for all timing experiments

In [2]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy
from utils.LB_utils import * 
from utils.load_not_MNIST import notMNIST
import os
import time
import matplotlib.pyplot as plt
from laplace import Laplace

s = 1
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_status = torch.cuda.is_available()
print("device: ", device)
print("cuda status: ", cuda_status)

device:  cuda
cuda status:  True


In [4]:
### define network
class ConvNet(nn.Module):
    
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(16, 32, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Flatten(),
            torch.nn.Linear(4 * 4 * 32, num_classes)
        )
    def forward(self, x):
        out = self.net(x)
        return out

In [5]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [6]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

MNIST_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

MNIST_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [7]:
mnist_model = ConvNet().to(device)
loss_function = torch.nn.CrossEntropyLoss()

mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)

In [8]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    max_len = len(train_loader)

    for iter in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            
            x, y = x.to(device), y.to(device)
            
            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    torch.save(mnist_model.state_dict(), path)

In [9]:
#train(mnist_model, MNIST_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [10]:
#predict in distribution
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)
#MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer.pth"

mnist_model = ConvNet().to(device)
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

max_len = len(MNIST_test_loader)
for batch_idx, (x, y) in enumerate(MNIST_test_loader):

    x, y = x.to(device), y.to(device)
    output = mnist_model(x)

    accuracy = get_accuracy(output, y)
    if batch_idx % 10 == 0:
        print(
            "Batch {}/{} \t".format(batch_idx, max_len) + 
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )
    acc.append(accuracy)

avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: pretrained_weights/MNIST_pretrained_10_classes_last_layer_s1.pth
Batch 0/79 	Accuracy 100%
Batch 10/79 	Accuracy 96%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 98%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 99%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 98%
overall test accuracy on MNIST: 98.84 %


In [11]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [12]:
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [13]:
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [14]:
#root = os.path.abspath('~/data')
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
notMNIST_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


# MAP estimate

In [15]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [16]:
MNIST_test_in_MAP = predict_MAP(mnist_model, MNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_notMNIST_MAP = predict_MAP(mnist_model, notMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader, device=device).cpu().numpy()

In [17]:
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(MNIST_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_KMNIST_MAP, targets_KMNIST)

In [18]:
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'MNIST', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'MNIST', test='FMNIST', method='MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'MNIST', test='notMNIST', method='MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'MNIST', test='KMNIST', method='MAP')

[In, MAP, MNIST] Accuracy: 0.988; average entropy: 0.047;     MMC: 0.986; Prob @ correct: 0.100
[Out-FMNIST, MAP, MNIST] Accuracy: 0.073; Average entropy: 0.972;    MMC: 0.663; AUROC: 0.977; Prob @ correct: 0.100
[Out-notMNIST, MAP, MNIST] Accuracy: 0.148; Average entropy: 0.620;    MMC: 0.774; AUROC: 0.914; Prob @ correct: 0.100
[Out-KMNIST, MAP, MNIST] Accuracy: 0.093; Average entropy: 0.750;    MMC: 0.723; AUROC: 0.959; Prob @ correct: 0.100


In [19]:
num_samples = 100

# Diag Hessian Sampling estimate

In [20]:
la_diag = Laplace(mnist_model, 'classification', 
                     subset_of_weights='last_layer', 
                     hessian_structure='diag',
                     prior_precision=1e-0) # 5e-4 # Choose prior precision according to weight decay
la_diag.fit(MNIST_train_loader)

In [21]:
MNIST_test_in_D = predict_samples(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_D = predict_samples(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_D = predict_samples(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_D = predict_samples(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.1628369900000006
time:  0.9792774170000023
time:  0.4947420930000028
time:  0.9655261910000021


In [22]:
# compute average log-likelihood for Diag
print(torch.distributions.Categorical(torch.tensor(MNIST_test_in_D)).log_prob(torch.tensor(targets)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_D)).log_prob(torch.tensor(targets_FMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_D)).log_prob(torch.tensor(targets_notMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_D)).log_prob(torch.tensor(targets_KMNIST)).mean())

tensor(-0.0605)
tensor(-3.9607)
tensor(-4.3349)
tensor(-4.9334)


In [23]:
import utils.scoring as scoring

In [24]:
#compute the Expected confidence estimate
print(scoring.expected_calibration_error(targets, MNIST_test_in_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_D))

0.03493014141414141
0.46846589898989915
0.409833633277876
0.43425194949494955


In [25]:
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(MNIST_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_KMNIST_D, targets_KMNIST)

In [26]:
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'MNIST', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, 'MNIST', test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, 'MNIST', test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, 'MNIST', test='KMNIST', method='Diag')

[In, Diag, MNIST] Accuracy: 0.989; average entropy: 0.160;     MMC: 0.956; Prob @ correct: 0.100
[Out-fmnist, Diag, MNIST] Accuracy: 0.073; Average entropy: 1.313;    MMC: 0.541; AUROC: 0.975; Prob @ correct: 0.100
[Out-notMNIST, Diag, MNIST] Accuracy: 0.149; Average entropy: 1.212;    MMC: 0.559; AUROC: 0.963; Prob @ correct: 0.100
[Out-KMNIST, Diag, MNIST] Accuracy: 0.095; Average entropy: 1.272;    MMC: 0.529; AUROC: 0.973; Prob @ correct: 0.100


# KFAC Laplace Approximation (sampling)

In [27]:
la_kron = Laplace(mnist_model, 'classification', 
                     subset_of_weights='last_layer', 
                     hessian_structure='kron',
                     prior_precision=5e-0) # 5e-4 # Choose prior precision according to weight decay
la_kron.fit(MNIST_train_loader)

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  /opt/conda/conda-bld/pytorch_1639180549130/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2499.)
  L, W = torch.symeig(M, eigenvectors=True)


In [28]:
MNIST_test_in_KFAC = predict_samples(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_KFAC = predict_samples(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_KFAC = predict_samples(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_KFAC = predict_samples(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.2494627059999992
time:  1.2474731460000008
time:  1.0096681889999992
time:  1.2346936679999985


In [29]:
# compute average log-likelihood for KFAC
print(torch.distributions.Categorical(torch.tensor(MNIST_test_in_KFAC)).log_prob(torch.tensor(targets)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_KFAC)).log_prob(torch.tensor(targets_FMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_KFAC)).log_prob(torch.tensor(targets_notMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_KFAC)).log_prob(torch.tensor(targets_KMNIST)).mean())

tensor(-0.0487)
tensor(-3.7331)
tensor(-4.3672)
tensor(-4.8779)


In [30]:
# compute ECE for KFAC
print(scoring.expected_calibration_error(targets, MNIST_test_in_KFAC))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_KFAC))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_KFAC))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_KFAC))

0.024422959595959588
0.42929906060606066
0.40477670855100895
0.4167330707070709


In [31]:
acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC = get_in_dist_values(MNIST_test_in_KFAC, targets)
acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC, MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_FMNIST_KFAC, targets_FMNIST)
acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC, MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_notMNIST_KFAC, targets_notMNIST)
acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC, MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_KMNIST_KFAC, targets_KMNIST)

In [32]:
print_in_dist_values(acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC, 'MNIST', 'KFAC')
print_out_dist_values(acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC, MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC, 'MNIST', test='fmnist', method='KFAC')
print_out_dist_values(acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC, MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC, 'MNIST', test='notMNIST', method='KFAC')
print_out_dist_values(acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC, MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC, 'MNIST', test='KMNIST', method='KFAC')

[In, KFAC, MNIST] Accuracy: 0.988; average entropy: 0.111;     MMC: 0.967; Prob @ correct: 0.100
[Out-fmnist, KFAC, MNIST] Accuracy: 0.072; Average entropy: 1.402;    MMC: 0.501; AUROC: 0.985; Prob @ correct: 0.100
[Out-notMNIST, KFAC, MNIST] Accuracy: 0.148; Average entropy: 1.228;    MMC: 0.551; AUROC: 0.969; Prob @ correct: 0.100
[Out-KMNIST, KFAC, MNIST] Accuracy: 0.094; Average entropy: 1.315;    MMC: 0.511; AUROC: 0.982; Prob @ correct: 0.100


# Laplace Bridge (Diagonal)

In [33]:
MNIST_test_in_LB_D = predict_LB(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_D = predict_LB(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_D = predict_LB(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_D = predict_LB(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

step0, f_mu[0]:  tensor([ -3.6094,  -2.8937,   1.1239,   1.7301,  -9.0167,  -3.9072, -17.8319,
         13.2881,  -5.2003,   0.8784], device='cuda:0')
step0, f_var[0]:  tensor([[ 9.1984,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 19.1064,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.7570,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  3.2656,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  8.3827,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.0570,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 21.9312,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -7.5057, -11.0894,  -6.9948,  -1.1693,   1.8382,  -0.7919, -13.6171,
          1.7905,  -1.8001,  10.2949], device='cuda:0')
step0, f_var[0]:  tensor([[ 8.6273,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 22.3177,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  4.7154,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  3.1492,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.1418,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.8023,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 10.9577,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -2.9080, -11.5330,  -6.2933,  -1.2008,   0.4824,  -4.2224, -13.5404,
         -0.4246,  -4.2598,  11.8633], device='cuda:0')
step0, f_var[0]:  tensor([[10.3306,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 43.7253,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  6.6535,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  5.5132,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  4.3741,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.5211,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 12.8647,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -6.8779,  -3.1840,   6.9303,   2.6190,  -6.5269, -11.0182, -20.9386,
         15.0805,  -5.3011,  -5.8907], device='cuda:0')
step0, f_var[0]:  tensor([[13.9824,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 24.1308,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.7583,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  4.9627,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  9.5200,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 11.0143,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 24.7522,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ 13.0570, -15.2212,  -5.3315,  -9.7353, -11.3007,   1.5989,   2.8546,
         -5.1484,  -3.9605,  -2.5558], device='cuda:0')
step0, f_var[0]:  tensor([[ 9.2979,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 90.9389,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000, 11.0750,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, 10.4740,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 29.4905,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  8.2674,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 13.9678,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -6.4632, -11.9964,  -2.7995,  12.3160, -12.9172,  -0.6458, -12.5930,
         -5.7308,  -1.9132,  -1.0479], device='cuda:0')
step0, f_var[0]:  tensor([[12.7646,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 40.7592,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  6.5587,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  4.6353,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 10.4913,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.4143,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 15.2625,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -4.3906,  -0.5647,   0.2772,   2.9066,  -8.6856,  -6.7296, -21.3044,
          9.9737,   1.2952,   0.1870], device='cuda:0')
step0, f_var[0]:  tensor([[ 7.7132,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  7.7809,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.5320,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.7889,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  5.1250,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.9528,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 14.9752,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-9.4226, -5.4002, -5.1446, -1.1133, -4.4139, -3.2032, -4.0904, -7.1229,
         9.0483, -4.0936], device='cuda:0')
step0, f_var[0]:  tensor([[14.2476,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 13.1167,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  5.0924,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  4.0849,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  4.4622,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.5069,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  8.1407,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.683

step0, f_mu[0]:  tensor([-0.5414,  2.4982,  5.6497, -3.9596, -3.2051, -2.3778,  2.6156, -4.8353,
        -3.4431, -8.3073], device='cuda:0')
step0, f_var[0]:  tensor([[3.8732, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 5.9758, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 1.4662, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 1.9571, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 2.4289, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5984, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.2627, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.9983, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([-0.9382, -2.6360,  3.8689, -4.3189,  1.7096, -5.7209,  0.0865, -2.0182,
        -3.1189, -7.4798], device='cuda:0')
step0, f_var[0]:  tensor([[ 6.6884,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 16.1162,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.0243,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  3.5364,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.9954,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.2245,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.4157,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.697

step0, f_mu[0]:  tensor([-1.2731, -1.1971,  2.5475, -1.2031,  1.6786, -4.9856, -1.2786, -0.9200,
        -0.9669, -3.7939], device='cuda:0')
step0, f_var[0]:  tensor([[1.9952, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 5.8097, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.8548, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 1.0526, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9362, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.6894, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.8958, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0640, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([ 3.7312, -5.8886,  1.3341, -4.7002, -0.4306, -0.9747, -1.8511, -2.4067,
        -1.8593, -2.3314], device='cuda:0')
step0, f_var[0]:  tensor([[ 3.5641,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 11.7142,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.0854,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.2958,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.2217,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.5702,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.0896,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.502

step0, f_mu[0]:  tensor([ 1.8873, -4.4262, -0.1514, -3.7928, -1.4979, -3.1665,  2.5059, -5.6268,
         3.1870, -1.9999], device='cuda:0')
step0, f_var[0]:  tensor([[ 2.1662,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 12.5507,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.8211,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.8872,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.8327,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.5218,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.0048,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.482

step0, f_mu[0]:  tensor([ 0.9880, -4.8211, -3.0523, -2.3960,  0.9615, -2.1299,  0.4115, -1.2594,
         1.7008, -0.2908], device='cuda:0')
step0, f_var[0]:  tensor([[1.2306, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 7.3702, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 1.1184, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 1.1096, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9016, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0448, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.2711, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.2471, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([-2.9745,  0.9333, -2.2593, -2.8744,  0.5681, -1.2326, -2.7290, -1.3555,
         1.9594, -3.4845], device='cuda:0')
step0, f_var[0]:  tensor([[2.0770, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 1.5294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.9073, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9260, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.8605, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0209, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.6224, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0794, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([ 1.8767, -3.8146,  3.6277, -3.0152, -8.3996, -2.9255,  1.8478,  0.2345,
         1.2518, -3.0471], device='cuda:0')
step0, f_var[0]:  tensor([[ 7.9338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 15.1925,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.3516,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  3.2767,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.1756,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.6347,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  7.8280,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.226

step0, f_mu[0]:  tensor([-0.3161,  1.1241,  4.4552, -2.3734, -2.8883, -1.8350, -0.9560, -5.2699,
        -0.5355, -7.4712], device='cuda:0')
step0, f_var[0]:  tensor([[3.0780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 4.6790, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 1.1219, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 1.4665, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 2.1366, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.8509, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.3258, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.3924, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([ 0.8253, -3.5672,  1.4977,  2.7449, -8.0567, -0.3746, -6.9156,  1.1370,
        -4.2357, -0.8467], device='cuda:0')
step0, f_var[0]:  tensor([[ 2.5163,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 14.6274,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.1939,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2852,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.8519,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.8051,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  7.5871,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.856

step0, f_mu[0]:  tensor([  2.9235,   1.8687,   4.0521,  -4.6138, -10.7372,   1.6388,  -4.7731,
        -11.3678,  -2.9958,  -3.0683], device='cuda:0')
step0, f_var[0]:  tensor([[11.5369,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 49.9274,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  5.2459,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  6.4539,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 21.4694,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  5.7712,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 13.1942,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-1.6839, -6.1704, -0.0551,  0.6824, -1.5555, -1.4691, -1.4595, -4.1280,
         1.5008, -0.9792], device='cuda:0')
step0, f_var[0]:  tensor([[ 3.8700,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 16.0336,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.1659,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.8278,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.5506,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.9651,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.6699,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.964

step0, f_mu[0]:  tensor([ -7.6060,   6.6851,  -5.7761,   4.2777,   3.4026,  -2.2205, -10.9503,
         -0.7301,  -5.1101,  -4.1047], device='cuda:0')
step0, f_var[0]:  tensor([[ 8.4319,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  9.2616,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.8632,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.1113,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  4.0704,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.5048,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 18.1038,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -6.0624,  -9.8615,   0.6624,  -4.1802, -12.2806,   2.9086,  -2.8105,
         -0.1123,   2.6432,  -1.1027], device='cuda:0')
step0, f_var[0]:  tensor([[15.1003,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 45.7181,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  5.8053,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  6.3649,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 10.2832,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.0158,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 12.6698,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -2.9086,  -9.5438,  -7.7035, -10.5020,   2.6750,   2.4523,   5.8661,
         -1.1683,   2.9740,  -7.1775], device='cuda:0')
step0, f_var[0]:  tensor([[15.0628,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 58.7164,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  8.8683,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  7.7647,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  9.0219,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  5.4391,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  9.1423,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-11.6415,  -3.3189,  -0.4398,  -5.9514,   9.5302,  -3.5312, -10.5293,
         -2.9876,  -2.6965,  -0.4369], device='cuda:0')
step0, f_var[0]:  tensor([[13.0884,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 33.8368,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  5.7229,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  5.6059,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  5.7768,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.4285,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 10.6200,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-10.0379,   4.8380,  -5.9691,  -1.9520,  -1.5475,   0.4216,  -0.8145,
         -6.0389,   0.8775,  -7.4013], device='cuda:0')
step0, f_var[0]:  tensor([[16.3770,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  9.0680,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  5.4310,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  4.2781,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.8267,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.0944,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  7.3386,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([  0.0973,  -1.6363,   5.1905, -12.9470,  -4.7264,  -5.0242,   5.8247,
         -2.4916,  -0.6598,  -7.9636], device='cuda:0')
step0, f_var[0]:  tensor([[20.6347,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 44.2885,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  8.5379,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  9.2137,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  9.5986,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  9.3668,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 16.2881,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -0.0439,  -6.8772,  -3.1380,   4.2916, -11.6949,   4.0695,  -2.0283,
         -6.8331,   0.3951,   0.2598], device='cuda:0')
step0, f_var[0]:  tensor([[ 5.7993,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 28.0917,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.8117,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.7940,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  7.8293,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.9601,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  9.4794,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-12.3636,   0.9306,   0.7348,   9.1763,  -7.6719,   0.1563, -19.4335,
         -3.5332,  -3.4975,   0.8652], device='cuda:0')
step0, f_var[0]:  tensor([[10.4669,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  8.7018,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.4456,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.8076,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  8.9634,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.1672,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 19.1650,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([ -1.0859,   0.1049,  -0.8562, -10.8087,   1.9027,  -6.0715,  -3.1705,
          2.8520,  -2.5605,  -6.0328], device='cuda:0')
step0, f_var[0]:  tensor([[ 9.3253,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 41.9320,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  6.8290,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  7.2269,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 11.1040,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  8.3863,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 15.0602,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-0.8280, -2.6134,  0.2171, -1.5644, -6.4498, -2.6862, -3.1101, -4.9353,
         5.0877, -3.7339], device='cuda:0')
step0, f_var[0]:  tensor([[4.6454, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 8.5892, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 2.3127, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 2.2332, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 4.2975, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5347, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 5.5034, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 3.9254, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([-4.5932,  1.7606, -3.2526, -6.9594,  0.8613,  0.6563, -6.1546, -1.6668,
         1.5682, -3.8757], device='cuda:0')
step0, f_var[0]:  tensor([[6.2097, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 7.2077, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 3.1321, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 2.9994, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 3.1573, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5974, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.9684, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 4.4768, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

step0, f_mu[0]:  tensor([-4.7593, -8.0671, -0.8799, -5.7087, -1.0580, -5.6500,  3.3550, -4.2667,
         2.0627, -4.9913], device='cuda:0')
step0, f_var[0]:  tensor([[13.2982,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 41.5799,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  6.2717,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  6.4315,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  7.1126,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.9516,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 10.4285,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  8.654

step0, f_mu[0]:  tensor([ 2.2914, -0.5672, -0.1310, -7.0403,  0.4212, -0.8954,  1.7621, -1.5900,
         0.1058, -4.6421], device='cuda:0')
step0, f_var[0]:  tensor([[ 2.4297,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 15.8350,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.1857,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.2434,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.5646,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.9150,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.2008,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.616

step0, f_mu[0]:  tensor([-3.3540, -4.5152, -2.5444, -3.8279,  3.1069, -4.1957, -5.1523, -0.0104,
        -0.5599, -3.0251], device='cuda:0')
step0, f_var[0]:  tensor([[ 7.7231,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 12.1979,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.5464,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  3.6092,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.5903,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  5.0948,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  6.4599,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.343

step0, f_mu[0]:  tensor([  4.6695, -12.3448,   5.8070,   0.7010, -14.7909,  -7.0551,  -1.2691,
         -4.9956,   4.4095,  -3.3876], device='cuda:0')
step0, f_var[0]:  tensor([[ 6.9713,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 27.4533,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  3.9389,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  4.4555,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  8.2284,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  5.3663,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 11.4618,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

step0, f_mu[0]:  tensor([-2.1541, -8.2685, -4.1841,  2.8451, -3.9501, -1.8986, -0.4488, -1.9184,
        -0.3712,  1.2918], device='cuda:0')
step0, f_var[0]:  tensor([[ 4.8074,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 21.2241,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.7811,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.5758,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  3.2657,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.8287,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  4.6244,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.183

step0, f_mu[0]:  tensor([-1.5137, -6.8763, -2.8725, -0.3809, -1.8610, -2.9026, -3.5073,  0.0340,
         0.7767, -0.3569], device='cuda:0')
step0, f_var[0]:  tensor([[ 4.7254,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, 10.8481,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  2.4781,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  2.3369,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.5971,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.7414,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  5.1281,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.472

In [34]:
# # compute average log-likelihood for LB
print(torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_D)).log_prob(torch.tensor(targets)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_D)).log_prob(torch.tensor(targets_FMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_D)).log_prob(torch.tensor(targets_notMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_D)).log_prob(torch.tensor(targets_KMNIST)).mean())

tensor(-0.0409)
tensor(-3.6545)
tensor(-5.5852)
tensor(-5.5203)


In [35]:
#compute ECE for LB
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_D))

0.011439090909090939
0.48319700000000004
0.6082117478998488
0.5931453232323233


In [36]:
acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D = get_in_dist_values(MNIST_test_in_LB_D, targets)
acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_FMNIST_LB_D, targets_FMNIST)
acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_notMNIST_LB_D, targets_notMNIST)
acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_KMNIST_LB_D, targets_KMNIST)

In [37]:
print_in_dist_values(acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D, 'MNIST', 'LB_D')
print_out_dist_values(acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D, 'MNIST', test='fmnist', method='LB_D')
print_out_dist_values(acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D, 'MNIST', test='notMNIST', method='LB_D')
print_out_dist_values(acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D, 'MNIST', test='KMNIST', method='LB_D')

[In, LB_D, MNIST] Accuracy: 0.987; average entropy: 0.061;     MMC: 0.982; Prob @ correct: 0.100
[Out-fmnist, LB_D, MNIST] Accuracy: 0.082; Average entropy: 1.324;    MMC: 0.565; AUROC: 0.982; Prob @ correct: 0.100
[Out-notMNIST, LB_D, MNIST] Accuracy: 0.152; Average entropy: 0.737;    MMC: 0.735; AUROC: 0.931; Prob @ correct: 0.100
[Out-KMNIST, LB_D, MNIST] Accuracy: 0.099; Average entropy: 0.866;    MMC: 0.692; AUROC: 0.960; Prob @ correct: 0.100


# KFAC Laplace Bridge

In [38]:
MNIST_test_in_LB_KFAC = predict_LB(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_KFAC = predict_LB(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_KFAC = predict_LB(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_KFAC = predict_LB(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

step0, f_mu[0]:  tensor([ -3.6094,  -2.8937,   1.1239,   1.7301,  -9.0167,  -3.9072, -17.8319,
         13.2881,  -5.2003,   0.8784], device='cuda:0')
step0, f_var[0]:  tensor([[42.6037, 40.4586, 40.5774, 40.5005, 40.5371, 40.5741, 40.8008, 40.5139,
         40.6334, 40.5754],
        [40.4586, 42.3545, 40.7232, 40.5542, 40.6854, 40.5083, 40.5141, 40.6926,
         40.6774, 40.6066],
        [40.5774, 40.7232, 41.9556, 40.7366, 40.6178, 40.5562, 40.4891, 40.7880,
         40.7114, 40.6196],
        [40.5005, 40.5542, 40.7366, 41.9227, 40.5880, 40.8414, 40.4990, 40.7278,
         40.7166, 40.6885],
        [40.5371, 40.6854, 40.6178, 40.5880, 41.8064, 40.5738, 40.6068, 40.7205,
         40.6691, 40.9704],
        [40.5741, 40.5083, 40.5562, 40.8414, 40.5738, 42.0326, 40.7329, 40.5634,
         40.7450, 40.6472],
        [40.8008, 40.5141, 40.4891, 40.4990, 40.6068, 40.7329, 42.4786, 40.4548,
         40.6701, 40.5288],
        [40.5139, 40.6926, 40.7880, 40.7278, 40.7205, 40.5634, 40.45

step0, f_mu[0]:  tensor([ 8.4595, -6.3176, -0.2372, -6.8238, -0.2402, -8.3703,  2.6686, -7.6730,
        -4.2665, -0.5893], device='cuda:0')
step0, f_var[0]:  tensor([[54.3705, 48.9291, 49.2248, 49.0313, 49.1222, 49.2127, 49.7731, 49.0660,
         49.3633, 49.2192],
        [48.9291, 53.7535, 49.5854, 49.1618, 49.4913, 49.0497, 49.0675, 49.5080,
         49.4736, 49.2922],
        [49.2248, 49.5854, 52.7533, 49.6214, 49.3213, 49.1674, 49.0035, 49.7507,
         49.5599, 49.3246],
        [49.0313, 49.1618, 49.6214, 52.6686, 49.2457, 49.8865, 49.0263, 49.5987,
         49.5720, 49.5003],
        [49.1222, 49.4913, 49.3213, 49.2457, 52.3732, 49.2115, 49.2980, 49.5795,
         49.4520, 50.2181],
        [49.2127, 49.0497, 49.1674, 49.8865, 49.2115, 52.9466, 49.6105, 49.1856,
         49.6445, 49.3974],
        [49.7732, 49.0675, 49.0035, 49.0263, 49.2980, 49.6105, 54.0596, 48.9178,
         49.4554, 49.1005],
        [49.0660, 49.5080, 49.7507, 49.5987, 49.5795, 49.1856, 48.9178, 52.490

step0, f_var[0]:  tensor([[27.2891, 24.1199, 24.2923, 24.1794, 24.2324, 24.2852, 24.6119, 24.1997,
         24.3731, 24.2890],
        [24.1199, 26.9294, 24.5026, 24.2555, 24.4477, 24.1902, 24.2006, 24.4574,
         24.4374, 24.3315],
        [24.2924, 24.5026, 26.3461, 24.5236, 24.3485, 24.2588, 24.1633, 24.5990,
         24.4877, 24.3505],
        [24.1794, 24.2555, 24.5236, 26.2966, 24.3044, 24.6782, 24.1765, 24.5103,
         24.4948, 24.4529],
        [24.2324, 24.4477, 24.3485, 24.3044, 26.1244, 24.2845, 24.3350, 24.4991,
         24.4248, 24.8716],
        [24.2852, 24.1902, 24.2588, 24.6782, 24.2845, 26.4588, 24.5172, 24.2694,
         24.5371, 24.3929],
        [24.6119, 24.2006, 24.1633, 24.1765, 24.3350, 24.5172, 27.1078, 24.1133,
         24.4268, 24.2198],
        [24.1997, 24.4574, 24.5990, 24.5103, 24.4991, 24.2694, 24.1133, 26.1926,
         24.3841, 24.6475],
        [24.3731, 24.4374, 24.4877, 24.4948, 24.4248, 24.5371, 24.4268, 24.3841,
         25.8048, 24.5017],
 

step0, f_mu[0]:  tensor([-1.7800,  7.8760, -1.3957, -7.2499,  0.1324, -3.8599, -0.0693, -1.4312,
        -0.0469, -5.0897], device='cuda:0')
step0, f_var[0]:  tensor([[21.5345, 20.6934, 20.7418, 20.7108, 20.7258, 20.7412, 20.8343, 20.7160,
         20.7648, 20.7410],
        [20.6934, 21.4323, 20.8017, 20.7333, 20.7862, 20.7142, 20.7159, 20.7894,
         20.7825, 20.7548],
        [20.7418, 20.8017, 21.2712, 20.8065, 20.7591, 20.7341, 20.7061, 20.8272,
         20.7959, 20.7601],
        [20.7108, 20.7333, 20.8065, 21.2584, 20.7473, 20.8484, 20.7105, 20.8031,
         20.7983, 20.7872],
        [20.7259, 20.7862, 20.7591, 20.7473, 21.2125, 20.7412, 20.7537, 20.8005,
         20.7794, 20.8982],
        [20.7412, 20.7142, 20.7342, 20.8484, 20.7412, 21.3022, 20.8055, 20.7370,
         20.8095, 20.7704],
        [20.8343, 20.7159, 20.7061, 20.7105, 20.7537, 20.8055, 21.4833, 20.6921,
         20.7795, 20.7227],
        [20.7160, 20.7894, 20.8272, 20.8031, 20.8005, 20.7370, 20.6921, 21.230

step0, f_var[0]:  tensor([[61.9231, 57.7097, 57.9391, 57.7889, 57.8594, 57.9296, 58.3642, 57.8159,
         58.0465, 57.9347],
        [57.7097, 61.4446, 58.2187, 57.8901, 58.1457, 57.8032, 57.8170, 58.1586,
         58.1320, 57.9913],
        [57.9391, 58.2187, 60.6688, 58.2466, 58.0139, 57.8945, 57.7674, 58.3469,
         58.1989, 58.0165],
        [57.7889, 57.8901, 58.2466, 60.6031, 57.9552, 58.4523, 57.7851, 58.2290,
         58.2084, 58.1527],
        [57.8594, 58.1457, 58.0139, 57.9553, 60.3742, 57.9287, 57.9958, 58.2142,
         58.1153, 58.7093],
        [57.9296, 57.8032, 57.8945, 58.4523, 57.9287, 60.8188, 58.2381, 57.9087,
         58.2645, 58.0729],
        [58.3642, 57.8170, 57.7674, 57.7851, 57.9958, 58.2381, 61.6819, 57.7009,
         58.1179, 57.8426],
        [57.8159, 58.1586, 58.3469, 58.2290, 58.2142, 57.9087, 57.7009, 60.4649,
         58.0612, 58.4114],
        [58.0464, 58.1320, 58.1989, 58.2084, 58.1153, 58.2645, 58.1179, 58.0612,
         59.9492, 58.2175],
 

step0, f_var[0]:  tensor([[10.3158,  9.6562,  9.6930,  9.6691,  9.6804,  9.6918,  9.7619,  9.6733,
          9.7102,  9.6923],
        [ 9.6562, 10.2388,  9.7380,  9.6856,  9.7263,  9.6715,  9.6734,  9.7285,
          9.7239,  9.7019],
        [ 9.6930,  9.7380, 10.1152,  9.7422,  9.7054,  9.6863,  9.6656,  9.7581,
          9.7344,  9.7060],
        [ 9.6691,  9.6856,  9.7422, 10.1050,  9.6962,  9.7747,  9.6686,  9.7395,
          9.7360,  9.7273],
        [ 9.6804,  9.7263,  9.7054,  9.6962, 10.0690,  9.6918,  9.7020,  9.7373,
          9.7213,  9.8145],
        [ 9.6918,  9.6715,  9.6863,  9.7747,  9.6918, 10.1390,  9.7411,  9.6886,
          9.7448,  9.7145],
        [ 9.7619,  9.6734,  9.6656,  9.6686,  9.7020,  9.7411, 10.2771,  9.6549,
          9.7216,  9.6779],
        [ 9.6733,  9.7285,  9.7581,  9.7395,  9.7373,  9.6886,  9.6549, 10.0832,
          9.7129,  9.7680],
        [ 9.7102,  9.7239,  9.7344,  9.7360,  9.7213,  9.7448,  9.7216,  9.7129,
         10.0018,  9.7372],
 

step0, f_mu[0]:  tensor([ -6.6810,  -6.1620,  -2.4085,  13.6550, -10.2654,  -2.3310, -17.4746,
         -6.1109,  -5.4316,   1.9322], device='cuda:0')
step0, f_var[0]:  tensor([[52.7796, 47.3584, 47.6517, 47.4591, 47.5492, 47.6385, 48.1929, 47.4940,
         47.7887, 47.6459],
        [47.3584, 52.1688, 48.0085, 47.5875, 47.9152, 47.4770, 47.4958, 47.9314,
         47.8984, 47.7169],
        [47.6517, 48.0085, 51.1746, 48.0451, 47.7461, 47.5932, 47.4318, 48.1738,
         47.9845, 47.7490],
        [47.4591, 47.5875, 48.0451, 51.0896, 47.6706, 48.3094, 47.4538, 48.0224,
         47.9963, 47.9246],
        [47.5492, 47.9152, 47.7461, 47.6706, 50.7950, 47.6371, 47.7243, 48.0029,
         47.8767, 48.6414],
        [47.6385, 47.4770, 47.5932, 48.3094, 47.6371, 51.3668, 48.0337, 47.6114,
         48.0687, 47.8225],
        [48.1929, 47.4958, 47.4318, 47.4538, 47.7243, 48.0337, 52.4715, 47.3466,
         47.8803, 47.5274],
        [47.4941, 47.9314, 48.1738, 48.0224, 48.0029, 47.6114, 47.34

step0, f_var[0]:  tensor([[54.0193, 50.0557, 50.2718, 50.1306, 50.1970, 50.2633, 50.6731, 50.1558,
         50.3731, 50.2678],
        [50.0557, 53.5681, 50.5355, 50.2262, 50.4667, 50.1441, 50.1568, 50.4790,
         50.4536, 50.3215],
        [50.2718, 50.5355, 52.8380, 50.5615, 50.3427, 50.2302, 50.1102, 50.6560,
         50.5165, 50.3452],
        [50.1306, 50.2262, 50.5615, 52.7763, 50.2876, 50.7549, 50.1270, 50.5450,
         50.5255, 50.4732],
        [50.1970, 50.4667, 50.3427, 50.2876, 52.5610, 50.2625, 50.3253, 50.5311,
         50.4380, 50.9962],
        [50.2633, 50.1441, 50.2302, 50.7549, 50.2625, 52.9790, 50.5537, 50.2436,
         50.5783, 50.3980],
        [50.6731, 50.1568, 50.1102, 50.1271, 50.3253, 50.5537, 53.7920, 50.0476,
         50.4403, 50.1812],
        [50.1558, 50.4790, 50.6560, 50.5451, 50.5311, 50.2436, 50.0476, 52.6463,
         50.3872, 50.7164],
        [50.3731, 50.4536, 50.5165, 50.5255, 50.4380, 50.5783, 50.4403, 50.3872,
         52.1614, 50.5340],
 

step0, f_var[0]:  tensor([[60.4692, 53.9518, 54.3053, 54.0735, 54.1820, 54.2899, 54.9587, 54.1153,
         54.4705, 54.2984],
        [53.9518, 59.7327, 54.7356, 54.2287, 54.6232, 54.0952, 54.1173, 54.6428,
         54.6025, 54.3846],
        [54.3053, 54.7356, 58.5355, 54.7792, 54.4196, 54.2355, 54.0404, 54.9341,
         54.7061, 54.4233],
        [54.0735, 54.2287, 54.7792, 58.4336, 54.3289, 55.0971, 54.0672, 54.7520,
         54.7204, 54.6342],
        [54.1820, 54.6232, 54.4196, 54.3289, 58.0795, 54.2883, 54.3927, 54.7288,
         54.5765, 55.4956],
        [54.2899, 54.0952, 54.2355, 55.0971, 54.2883, 58.7669, 54.7658, 54.2574,
         54.8074, 54.5112],
        [54.9587, 54.1173, 54.0404, 54.0672, 54.3927, 54.7658, 60.0978, 53.9379,
         54.5808, 54.1559],
        [54.1153, 54.6428, 54.9341, 54.7520, 54.7288, 54.2574, 53.9379, 58.2197,
         54.4928, 55.0343],
        [54.4705, 54.6025, 54.7061, 54.7204, 54.5765, 54.8074, 54.5808, 54.4928,
         57.4228, 54.7349],
 

step0, f_mu[0]:  tensor([ -9.0579,   1.6324,  17.5860,  -2.3321,  -8.1634, -13.4058, -21.3528,
          5.7247,  -6.7004,  -9.0593], device='cuda:0')
step0, f_var[0]:  tensor([[68.7074, 63.6448, 63.9170, 63.7385, 63.8220, 63.9052, 64.4202, 63.7707,
         64.0444, 63.9119],
        [63.6448, 68.1401, 64.2484, 63.8580, 64.1620, 63.7553, 63.7722, 64.1770,
         64.1463, 63.9778],
        [63.9170, 64.2484, 67.2178, 64.2823, 64.0050, 63.8630, 63.7131, 64.4019,
         64.2263, 64.0075],
        [63.7385, 63.8580, 64.2823, 67.1389, 63.9349, 64.5274, 63.7336, 64.2613,
         64.2371, 64.1707],
        [63.8220, 64.1620, 64.0050, 63.9349, 66.8647, 63.9038, 63.9846, 64.2430,
         64.1260, 64.8370],
        [63.9052, 63.7553, 63.8630, 64.5274, 63.9038, 67.3959, 64.2717, 63.8799,
         64.3044, 64.0759],
        [64.4203, 63.7722, 63.7131, 63.7336, 63.9846, 64.2717, 68.4213, 63.6341,
         64.1295, 63.8017],
        [63.7707, 64.1770, 64.4019, 64.2613, 64.2430, 63.8799, 63.63

step0, f_var[0]:  tensor([[27.9698, 22.9717, 23.2302, 23.0571, 23.1359, 23.2123, 23.6946, 23.0903,
         23.3496, 23.2244],
        [22.9717, 27.4370, 23.5408, 23.1652, 23.4590, 23.0713, 23.0933, 23.4712,
         23.4479, 23.2787],
        [23.2302, 23.5408, 26.5482, 23.5779, 23.3066, 23.1706, 23.0339, 23.6937,
         23.5270, 23.3070],
        [23.0571, 23.1652, 23.5779, 26.4683, 23.2375, 23.8180, 23.0503, 23.5567,
         23.5358, 23.4694],
        [23.1359, 23.4590, 23.3066, 23.2375, 26.1976, 23.2102, 23.2945, 23.5369,
         23.4268, 24.1314],
        [23.2123, 23.0713, 23.1706, 23.8180, 23.2102, 26.7203, 23.5644, 23.1873,
         23.6028, 23.3792],
        [23.6946, 23.0933, 23.0339, 23.0503, 23.2945, 23.5644, 27.6993, 22.9587,
         23.4319, 23.1152],
        [23.0903, 23.4712, 23.6937, 23.5567, 23.5369, 23.1873, 22.9587, 26.3055,
         23.3619, 23.7743],
        [23.3496, 23.4479, 23.5270, 23.5358, 23.4268, 23.6028, 23.4319, 23.3619,
         25.7023, 23.5503],
 

step0, f_mu[0]:  tensor([ 1.3164, -3.6425, -0.9631, -3.2580,  0.0189, -1.3374,  2.2960, -3.8168,
         2.1994, -2.2047], device='cuda:0')
step0, f_var[0]:  tensor([[10.6980,  5.6354,  5.8808,  5.7136,  5.7877,  5.8586,  6.3109,  5.7471,
          5.9934,  5.8751],
        [ 5.6354, 10.1970,  6.1722,  5.8119,  6.0953,  5.7260,  5.7511,  6.1049,
          6.0883,  5.9185],
        [ 5.8808,  6.1722,  9.3423,  6.2122,  5.9475,  5.8168,  5.6926,  6.3246,
          6.1668,  5.9449],
        [ 5.7136,  5.8119,  6.2122,  9.2612,  5.8789,  6.4464,  5.7053,  6.1910,
          6.1731,  6.1073],
        [ 5.7877,  6.0953,  5.9475,  5.8789,  8.9913,  5.8551,  5.9426,  6.1685,
          6.0657,  6.7681],
        [ 5.8586,  5.7260,  5.8168,  6.4464,  5.8551,  9.5075,  6.1959,  5.8332,
          6.2399,  6.0213],
        [ 6.3109,  5.7511,  5.6926,  5.7053,  5.9426,  6.1959, 10.4422,  5.6210,
          6.0729,  5.7663],
        [ 5.7471,  6.1049,  6.3246,  6.1910,  6.1685,  5.8332,  5.6210,  9.100

step0, f_var[0]:  tensor([[16.8208, 11.1695, 11.4605, 11.2654, 11.3540, 11.4399, 11.9822, 11.3029,
         11.5948, 11.4539],
        [11.1695, 16.2216, 11.8097, 11.3867, 11.7177, 11.2812, 11.3063, 11.7312,
         11.7055, 11.5143],
        [11.4605, 11.8097, 15.2205, 11.8519, 11.5459, 11.3928, 11.2394, 11.9823,
         11.7948, 11.5461],
        [11.2654, 11.3867, 11.8519, 15.1301, 11.4679, 12.1225, 11.2576, 11.8279,
         11.8045, 11.7295],
        [11.3540, 11.7177, 11.5459, 11.4679, 14.8243, 11.4373, 11.5328, 11.8052,
         11.6815, 12.4772],
        [11.4399, 11.2812, 11.3928, 12.1225, 11.4373, 15.4142, 11.8364, 11.4115,
         11.8801, 11.6280],
        [11.9822, 11.3063, 11.2394, 11.2576, 11.5328, 11.8364, 16.5165, 11.1546,
         11.6874, 11.3305],
        [11.3029, 11.7312, 11.9823, 11.8279, 11.8052, 11.4115, 11.1546, 14.9464,
         11.6082, 12.0737],
        [11.5948, 11.7055, 11.7948, 11.8045, 11.6815, 11.8801, 11.6874, 11.6082,
         14.2660, 11.8212],
 

step0, f_mu[0]:  tensor([-0.6002,  2.4073,  4.1011, -2.3116, -4.0396, -2.1470, -2.5534, -4.2742,
        -0.3344, -6.4289], device='cuda:0')
step0, f_var[0]:  tensor([[19.8613, 14.8057, 15.0650, 14.8912, 14.9702, 15.0468, 15.5301, 14.9246,
         15.1847, 15.0592],
        [14.8057, 19.3273, 15.3762, 14.9994, 15.2943, 14.9054, 14.9276, 15.3064,
         15.2834, 15.1130],
        [15.0650, 15.3762, 18.4355, 15.4138, 15.1412, 15.0047, 14.8681, 15.5301,
         15.3631, 15.1413],
        [14.8912, 14.9994, 15.4138, 18.3548, 15.0716, 15.6549, 14.8843, 15.3924,
         15.3716, 15.3049],
        [14.9702, 15.2943, 15.1412, 15.0716, 18.0819, 15.0444, 15.1296, 15.3721,
         15.2619, 15.9719],
        [15.0468, 14.9054, 15.0047, 15.6549, 15.0444, 18.6080, 15.3999, 15.0214,
         15.4391, 15.2144],
        [15.5301, 14.9276, 14.8681, 14.8843, 15.1296, 15.3999, 19.5902, 14.7926,
         15.2673, 14.9491],
        [14.9246, 15.3064, 15.5301, 15.3924, 15.3721, 15.0214, 14.7926, 18.191

step0, f_var[0]:  tensor([[6.6078, 3.2154, 3.3798, 3.2674, 3.3171, 3.3643, 3.6667, 3.2901, 3.4551,
         3.3759],
        [3.2154, 6.2726, 3.5747, 3.3328, 3.5232, 3.2755, 3.2930, 3.5294, 3.5188,
         3.4042],
        [3.3798, 3.5747, 5.6985, 3.6020, 3.4238, 3.3361, 3.2534, 3.6775, 3.5717,
         3.4220],
        [3.2674, 3.3328, 3.6020, 5.6437, 3.3777, 3.7596, 3.2616, 3.5876, 3.5758,
         3.5314],
        [3.3171, 3.5232, 3.4238, 3.3776, 5.4621, 3.3619, 3.4212, 3.5725, 3.5036,
         3.9766],
        [3.3643, 3.2755, 3.3361, 3.7596, 3.3619, 5.8096, 3.5907, 3.3472, 3.6209,
         3.4738],
        [3.6667, 3.2930, 3.2534, 3.2616, 3.4212, 3.5907, 6.4364, 3.2054, 3.5085,
         3.3026],
        [3.2901, 3.5294, 3.6775, 3.5876, 3.5725, 3.3472, 3.2054, 5.5358, 3.4597,
         3.7345],
        [3.4551, 3.5188, 3.5717, 3.5758, 3.5036, 3.6209, 3.5085, 3.4597, 5.1370,
         3.5885],
        [3.3759, 3.4042, 3.4220, 3.5314, 3.9766, 3.4738, 3.3026, 3.7345, 3.5885,
         5

step0, f_var[0]:  tensor([[18.9253, 13.6965, 13.9779, 13.7923, 13.8786, 13.9636, 14.4940, 13.8264,
         14.1089, 13.9720],
        [13.6965, 18.3407, 14.3193, 13.9143, 14.2297, 13.8090, 13.8285, 14.2448,
         14.2143, 14.0384],
        [13.9779, 14.3193, 17.3840, 14.3553, 14.0668, 13.9200, 13.7662, 14.4792,
         14.2976, 14.0694],
        [13.7923, 13.9143, 14.3553, 17.3016, 13.9940, 14.6104, 13.7867, 14.3332,
         14.3086, 14.2391],
        [13.8786, 14.2297, 14.0668, 13.9940, 17.0172, 13.9623, 14.0474, 14.3142,
         14.1932, 14.9322],
        [13.9636, 13.8090, 13.9200, 14.6104, 13.9623, 17.5691, 14.3437, 13.9376,
         14.3787, 14.1412],
        [14.4940, 13.8285, 13.7662, 13.7867, 14.0474, 14.3437, 18.6299, 13.6845,
         14.1971, 13.8574],
        [13.8264, 14.2448, 14.4792, 14.3332, 14.3142, 13.9375, 13.6845, 17.1297,
         14.1257, 14.5604],
        [14.1089, 14.2143, 14.2976, 14.3086, 14.1932, 14.3787, 14.1971, 14.1257,
         16.4905, 14.3209],
 

step0, f_mu[0]:  tensor([ 2.7414, -1.9384,  2.2260, -2.6212, -3.6777, -3.8004, -1.5636, -5.6638,
         0.9994, -2.1862], device='cuda:0')
step0, f_var[0]:  tensor([[25.0928, 17.9938, 18.3542, 18.1121, 18.2217, 18.3280, 18.9990, 18.1588,
         18.5207, 18.3462],
        [17.9938, 24.3512, 18.7864, 18.2618, 18.6726, 18.1317, 18.1633, 18.6890,
         18.6582, 18.4194],
        [18.3542, 18.7864, 23.1090, 18.8396, 18.4592, 18.2689, 18.0801, 19.0018,
         18.7697, 18.4585],
        [18.1121, 18.2618, 18.8396, 22.9958, 18.3617, 19.1760, 18.1021, 18.8096,
         18.7811, 18.6877],
        [18.2217, 18.6726, 18.4592, 18.3617, 22.6136, 18.3243, 18.4442, 18.7805,
         18.6278, 19.6219],
        [18.3280, 18.1317, 18.2689, 19.1760, 18.3243, 23.3492, 18.8196, 18.2923,
         18.8756, 18.5618],
        [18.9990, 18.1633, 18.0801, 18.1021, 18.4442, 18.8196, 24.7160, 17.9752,
         18.6358, 18.1921],
        [18.1588, 18.6890, 19.0018, 18.8096, 18.7805, 18.2923, 17.9752, 22.767

step0, f_var[0]:  tensor([[12.5849,  6.6105,  6.9012,  6.7031,  6.7908,  6.8747,  7.4102,  6.7427,
          7.0344,  6.8943],
        [ 6.6105, 11.9916,  7.2462,  6.8193,  7.1551,  6.7176,  6.7475,  7.1664,
          7.1468,  6.9457],
        [ 6.9012,  7.2462, 10.9791,  7.2936,  6.9799,  6.8252,  6.6781,  7.4267,
          7.2397,  6.9770],
        [ 6.7031,  6.8193,  7.2936, 10.8829,  6.8987,  7.5711,  6.6931,  7.2684,
          7.2473,  7.1692],
        [ 6.7908,  7.1551,  6.9799,  6.8987, 10.5636,  6.8706,  6.9743,  7.2420,
          7.1202,  7.9516],
        [ 6.8747,  6.7176,  6.8252,  7.5711,  6.8706, 11.1749,  7.2743,  6.8446,
          7.3264,  7.0674],
        [ 7.4102,  6.7475,  6.6781,  6.6931,  6.9743,  7.2743, 12.2819,  6.5933,
          7.1286,  6.7654],
        [ 6.7427,  7.1664,  7.4267,  7.2684,  7.2420,  6.8446,  6.5933, 10.6931,
          7.0431,  7.5265],
        [ 7.0344,  7.1468,  7.2397,  7.2473,  7.1202,  7.3264,  7.1286,  7.0431,
          9.9913,  7.2691],
 

step0, f_mu[0]:  tensor([ 0.8565, -2.8002, -0.5210, -3.0929, -1.0463, -2.5566,  3.2171, -6.3834,
         3.0237, -2.6628], device='cuda:0')
step0, f_var[0]:  tensor([[16.8474,  9.3357,  9.6941,  9.4486,  9.5565,  9.6592, 10.3173,  9.4984,
          9.8581,  9.6856],
        [ 9.3357, 16.1179, 10.1182,  9.5901, 10.0061,  9.4660,  9.5048, 10.0192,
          9.9975,  9.7453],
        [ 9.6941, 10.1182, 14.8642, 10.1788,  9.7887,  9.5971,  9.4183, 10.3440,
         10.1136,  9.7838],
        [ 9.4486,  9.5901, 10.1788, 14.7433,  9.6872, 10.5239,  9.4354, 10.1472,
         10.1220, 10.0245],
        [ 9.5565, 10.0061,  9.7887,  9.6872, 14.3436,  9.6536,  9.7847, 10.1130,
          9.9635, 11.0040],
        [ 9.6592,  9.4660,  9.5971, 10.5239,  9.6536, 15.1065, 10.1533,  9.6214,
         10.2211,  9.8988],
        [10.3173,  9.5048,  9.4183,  9.4354,  9.7847, 10.1533, 16.4742,  9.3136,
          9.9749,  9.5244],
        [ 9.4984, 10.0192, 10.3440, 10.1472, 10.1130,  9.6214,  9.3136, 14.506

step0, f_var[0]:  tensor([[21.8925, 14.1653, 14.5728, 14.3016, 14.4262, 14.5476, 15.3111, 14.3527,
         14.7615, 14.5638],
        [14.1653, 21.0498, 15.0644, 14.4746, 14.9350, 14.3246, 14.3568, 14.9553,
         14.9154, 14.6540],
        [14.5728, 15.0644, 19.6554, 15.1201, 14.6966, 14.4832, 14.2645, 15.3012,
         15.0382, 14.6989],
        [14.3016, 14.4746, 15.1201, 19.5324, 14.5894, 15.4947, 14.2921, 15.0872,
         15.0531, 14.9501],
        [14.4262, 14.9350, 14.6966, 14.5894, 19.1127, 14.5450, 14.6737, 15.0579,
         14.8834, 15.9755],
        [14.5476, 14.3246, 14.4832, 15.4947, 14.5450, 19.9254, 15.1008, 14.5090,
         15.1569, 14.8081],
        [15.3111, 14.3568, 14.2645, 14.2921, 14.6737, 15.1008, 21.4656, 14.1460,
         14.8903, 14.3944],
        [14.3527, 14.9553, 15.3012, 15.0872, 15.0579, 14.5090, 14.1460, 19.2791,
         14.7832, 15.4238],
        [14.7615, 14.9154, 15.0382, 15.0531, 14.8834, 15.1569, 14.8903, 14.7832,
         18.3397, 15.0735],
 

step0, f_mu[0]:  tensor([ 0.3529,  2.2088,  7.0155, -3.2133, -1.8579, -3.9718, -0.7929, -4.7171,
        -2.6040, -8.8867], device='cuda:0')
step0, f_var[0]:  tensor([[25.7884, 19.7394, 20.0536, 19.8437, 19.9396, 20.0329, 20.6201, 19.8836,
         20.1989, 20.0467],
        [19.7394, 25.1399, 20.4317, 19.9760, 20.3322, 19.8613, 19.8870, 20.3474,
         20.3181, 20.1139],
        [20.0536, 20.4317, 24.0619, 20.4760, 20.1475, 19.9824, 19.8154, 20.6163,
         20.4138, 20.1483],
        [19.8437, 19.9760, 20.4760, 23.9655, 20.0639, 20.7665, 19.8359, 20.4504,
         20.4246, 20.3445],
        [19.9396, 20.3322, 20.1475, 20.0639, 23.6380, 20.0304, 20.1317, 20.4266,
         20.2927, 21.1445],
        [20.0329, 19.8613, 19.9824, 20.7665, 20.0304, 24.2705, 20.4601, 20.0026,
         20.5056, 20.2349],
        [20.6201, 19.8870, 19.8154, 19.8358, 20.1317, 20.4601, 25.4595, 19.7240,
         20.2987, 19.9146],
        [19.8836, 20.3474, 20.6163, 20.4504, 20.4266, 20.0026, 19.7240, 23.768

step0, f_mu[0]:  tensor([ 1.8639, -2.7518, -1.9417, -3.7276, -3.3728, -1.1933, -4.3931,  1.3613,
        -4.2703, -1.7572], device='cuda:0')
step0, f_var[0]:  tensor([[30.2165, 15.7321, 16.4157, 15.9456, 16.1511, 16.3459, 17.5977, 16.0418,
         16.7279, 16.3993],
        [15.7321, 28.8281, 17.2225, 16.2129, 17.0090, 15.9782, 16.0548, 17.0330,
         16.9948, 16.5084],
        [16.4157, 17.2225, 26.4305, 17.3407, 16.5926, 16.2263, 15.8885, 17.6573,
         17.2182, 16.5815],
        [15.9456, 16.2129, 17.3407, 26.1968, 16.3973, 18.0027, 15.9194, 17.2800,
         17.2330, 17.0454],
        [16.1511, 17.0090, 16.5926, 16.3973, 25.4273, 16.3344, 16.5888, 17.2126,
         16.9286, 18.9322],
        [16.3459, 15.9782, 16.2263, 18.0027, 16.3344, 26.8938, 17.2902, 16.2729,
         17.4239, 16.8055],
        [17.5977, 16.0548, 15.8885, 15.9194, 16.5888, 17.2902, 29.5053, 15.6887,
         16.9515, 16.0889],
        [16.0418, 17.0330, 17.6573, 17.2800, 17.2126, 16.2729, 15.6887, 25.742

step0, f_mu[0]:  tensor([ 1.7994,  0.7980,  0.6430,  1.5861, -5.1248, -0.5785, -1.2213, -1.7858,
        -2.7957,  0.3945], device='cuda:0')
step0, f_var[0]:  tensor([[7.7522, 2.7278, 2.9777, 2.8086, 2.8843, 2.9571, 3.4201, 2.8418, 3.0927,
         2.9719],
        [2.7278, 7.2400, 3.2760, 2.9104, 3.1972, 2.8215, 2.8455, 3.2078, 3.1887,
         3.0195],
        [2.9777, 3.2760, 6.3738, 3.3147, 3.0480, 2.9155, 2.7867, 3.4281, 3.2672,
         3.0466],
        [2.8086, 2.9104, 3.3147, 6.2933, 2.9793, 3.5507, 2.8008, 3.2935, 3.2745,
         3.2085],
        [2.8843, 3.1972, 3.0480, 2.9793, 6.0238, 2.9542, 3.0403, 3.2722, 3.1667,
         3.8682],
        [2.9571, 2.8215, 2.9155, 3.5507, 2.9542, 6.5414, 3.2995, 2.9320, 3.3412,
         3.1211],
        [3.4201, 2.8455, 2.7867, 2.8008, 3.0403, 3.2995, 7.4913, 2.7138, 3.1730,
         2.8632],
        [2.8418, 3.2078, 3.4281, 3.2935, 3.2722, 2.9320, 2.7138, 6.1324, 3.1019,
         3.5108],
        [3.0927, 3.1887, 3.2672, 3.2745, 3.1667, 

step0, f_mu[0]:  tensor([-2.4701, -6.4974, -9.5541, -8.5480,  6.6171, -2.6692,  1.4985, -0.4203,
         0.7185, -5.7228], device='cuda:0')
step0, f_var[0]:  tensor([[64.8142, 43.7263, 44.8011, 44.0806, 44.4075, 44.7254, 46.7290, 44.2188,
         45.2981, 44.7776],
        [43.7263, 62.6005, 46.0914, 44.5288, 45.7521, 44.1397, 44.2316, 45.8018,
         45.7076, 44.9989],
        [44.8011, 46.0914, 58.9012, 46.2481, 45.1168, 44.5500, 43.9848, 46.7310,
         46.0387, 45.1156],
        [44.0806, 44.5288, 46.2481, 58.5655, 44.8272, 47.2484, 44.0515, 46.1593,
         46.0734, 45.7962],
        [44.4075, 45.7521, 45.1168, 44.8272, 57.4297, 44.7149, 45.0695, 46.0734,
         45.6175, 48.5706],
        [44.7254, 44.1397, 44.5500, 47.2484, 44.7148, 59.6164, 46.1898, 44.6195,
         46.3539, 45.4209],
        [46.7290, 44.2316, 43.9848, 44.0515, 45.0695, 46.1898, 63.6899, 43.6719,
         45.6407, 44.3199],
        [44.2188, 45.8018, 46.7310, 46.1593, 46.0734, 44.6195, 43.6719, 57.885

step0, f_mu[0]:  tensor([-4.0005, -9.5164, -6.1010,  0.7874, -6.6176,  1.8897, -6.7852,  0.0236,
         0.5745, -1.3376], device='cuda:0')
step0, f_var[0]:  tensor([[74.0684, 49.8566, 51.0585, 50.2433, 50.6073, 50.9559, 53.1779, 50.4047,
         51.6102, 51.0303],
        [49.8566, 71.6090, 52.4899, 50.7299, 52.1117, 50.3047, 50.4229, 52.1614,
         52.0726, 51.2542],
        [51.0584, 52.4899, 67.4377, 52.6791, 51.3921, 50.7547, 50.1388, 53.2257,
         52.4519, 51.3846],
        [50.2433, 50.7299, 52.6791, 67.0478, 51.0609, 53.8177, 50.2048, 52.5765,
         52.4861, 52.1673],
        [50.6073, 52.1117, 51.3921, 51.0609, 65.7465, 50.9415, 51.3595, 52.4727,
         51.9663, 55.3551],
        [50.9559, 50.3047, 50.7547, 53.8177, 50.9415, 68.2448, 52.6040, 50.8342,
         52.8087, 51.7470],
        [53.1779, 50.4229, 50.1388, 50.2048, 51.3595, 52.6040, 72.8145, 49.7882,
         51.9975, 50.5048],
        [50.4047, 52.1614, 53.2257, 52.5764, 52.4727, 50.8342, 49.7882, 66.270

step0, f_mu[0]:  tensor([-1.4143,  4.2749, -2.9512, -5.8806,  2.5261, -3.0280, -0.7995, -3.4435,
         0.9578, -4.7155], device='cuda:0')
step0, f_var[0]:  tensor([[23.4808, 17.4220, 17.7386, 17.5275, 17.6242, 17.7184, 18.3110, 17.5675,
         17.8852, 17.7316],
        [17.4220, 22.8266, 18.1201, 17.6614, 18.0197, 17.5453, 17.5707, 18.0353,
         18.0050, 17.8006],
        [17.7386, 18.1201, 21.7419, 18.1640, 17.8341, 17.6679, 17.4989, 18.3051,
         18.1009, 17.8353],
        [17.5275, 17.6614, 18.1640, 21.6456, 17.7503, 18.4558, 17.5199, 18.1384,
         18.1121, 18.0318],
        [17.6242, 18.0197, 17.8341, 17.7503, 21.3173, 17.7161, 17.8171, 18.1149,
         17.9797, 18.8334],
        [17.7184, 17.5453, 17.6679, 18.4558, 17.7161, 21.9518, 18.1485, 17.6881,
         18.1932, 17.9214],
        [18.3110, 17.5707, 17.4989, 17.5199, 17.8171, 18.1485, 23.1492, 17.4067,
         17.9854, 17.5992],
        [17.5675, 18.0353, 18.3051, 18.1384, 18.1149, 17.6881, 17.4067, 21.448

step0, f_mu[0]:  tensor([-3.8305, -9.8095,  0.7674, -1.5661, -2.6121,  0.3019, -0.8441, -3.3783,
         2.2514, -1.8043], device='cuda:0')
step0, f_var[0]:  tensor([[40.0768, 26.0700, 26.7714, 26.2985, 26.5113, 26.7170, 28.0194, 26.3906,
         27.0948, 26.7558],
        [26.0700, 38.6366, 27.6103, 26.5869, 27.3894, 26.3359, 26.4001, 27.4200,
         27.3636, 26.8929],
        [26.7715, 27.6103, 36.2122, 27.7168, 26.9720, 26.6008, 26.2369, 28.0341,
         27.5826, 26.9689],
        [26.2985, 26.5869, 27.7168, 35.9885, 26.7805, 28.3755, 26.2778, 27.6579,
         27.6035, 27.4201],
        [26.5113, 27.3894, 26.9720, 26.7805, 35.2366, 26.7091, 26.9471, 27.5989,
         27.3028, 29.2584],
        [26.7170, 26.3359, 26.6008, 28.3755, 26.7091, 36.6809, 27.6757, 26.6467,
         27.7894, 27.1749],
        [28.0194, 26.4001, 26.2369, 26.2778, 26.9471, 27.6757, 39.3440, 26.0325,
         27.3197, 26.4527],
        [26.3906, 27.4200, 28.0341, 27.6579, 27.5989, 26.6467, 26.0325, 35.539

step0, f_mu[0]:  tensor([-8.0556, -1.2172,  2.6364, -2.6413,  3.5363,  4.4929, -7.6903, -2.7329,
        -5.2011, -1.5280], device='cuda:0')
step0, f_var[0]:  tensor([[39.5949, 30.5384, 31.0171, 30.7002, 30.8467, 30.9908, 31.8911, 30.7590,
         31.2399, 31.0073],
        [30.5384, 38.6021, 31.5966, 30.9060, 31.4447, 30.7283, 30.7629, 31.4695,
         31.4202, 31.1165],
        [31.0171, 31.5966, 36.9696, 31.6601, 31.1660, 30.9154, 30.6563, 31.8721,
         31.5634, 31.1688],
        [30.7002, 30.9060, 31.6601, 36.8268, 31.0406, 32.0970, 30.6899, 31.6221,
         31.5812, 31.4617],
        [30.8467, 31.4448, 31.1660, 31.0406, 36.3364, 30.9877, 31.1359, 31.5879,
         31.3830, 32.6568],
        [30.9908, 30.7283, 30.9154, 32.0970, 30.9877, 37.2852, 31.6387, 30.9455,
         31.7020, 31.2948],
        [31.8911, 30.7629, 30.6563, 30.6899, 31.1359, 31.6387, 39.0928, 30.5172,
         31.3906, 30.8098],
        [30.7590, 31.4695, 31.8721, 31.6221, 31.5879, 30.9455, 30.5172, 36.531

step0, f_mu[0]:  tensor([-2.6489, -2.7560, -0.7307, -0.7400,  0.1172,  1.0975,  1.7386, -2.5281,
        -2.2741, -2.9902], device='cuda:0')
step0, f_var[0]:  tensor([[7.1988, 5.3198, 5.4163, 5.3517, 5.3811, 5.4098, 5.5900, 5.3640, 5.4609,
         5.4142],
        [5.3198, 6.9998, 5.5323, 5.3923, 5.5018, 5.3571, 5.3651, 5.5064, 5.4976,
         5.4345],
        [5.4163, 5.5323, 6.6684, 5.5461, 5.4450, 5.3942, 5.3431, 5.5893, 5.5271,
         5.4450],
        [5.3517, 5.3923, 5.5461, 6.6386, 5.4191, 5.6355, 5.3493, 5.5382, 5.5303,
         5.5056],
        [5.3811, 5.5018, 5.4450, 5.4191, 6.5373, 5.4089, 5.4403, 5.5306, 5.4896,
         5.7528],
        [5.4098, 5.3571, 5.3942, 5.6355, 5.4089, 6.7325, 5.5410, 5.4004, 5.5553,
         5.4720],
        [5.5900, 5.3651, 5.3431, 5.3493, 5.4403, 5.5410, 7.0978, 5.3150, 5.4916,
         5.3734],
        [5.3640, 5.5064, 5.5893, 5.5382, 5.5306, 5.4004, 5.3150, 6.5779, 5.4654,
         5.6194],
        [5.4609, 5.4976, 5.5271, 5.5303, 5.4896, 

step0, f_var[0]:  tensor([[20.1079,  7.1109,  7.7423,  7.3112,  7.5018,  7.6832,  8.8451,  7.3979,
          8.0313,  7.7272],
        [ 7.1109, 18.8205,  8.4910,  7.5625,  8.2929,  7.3423,  7.4087,  8.3171,
          8.2759,  7.8369],
        [ 7.7423,  8.4910, 16.6181,  8.5950,  7.9118,  7.5754,  7.2573,  8.8845,
          8.4784,  7.9051],
        [ 7.3112,  7.5625,  8.5950, 16.4081,  7.7349,  9.1995,  7.2891,  8.5400,
          8.4945,  8.3242],
        [ 7.5018,  8.2929,  7.9118,  7.7349, 15.7123,  7.6743,  7.9011,  8.4821,
          8.2177, 10.0299],
        [ 7.6832,  7.3423,  7.5754,  9.1995,  7.6743, 17.0441,  8.5522,  7.6177,
          8.6671,  8.1030],
        [ 8.8451,  7.4087,  7.2573,  7.2891,  7.9011,  8.5522, 19.4501,  7.0729,
          8.2362,  7.4462],
        [ 7.3979,  8.3171,  8.8845,  8.5400,  8.4821,  7.6177,  7.0729, 15.9943,
          8.0495,  9.1027],
        [ 8.0313,  8.2759,  8.4784,  8.4945,  8.2177,  8.6671,  8.2362,  8.0495,
         14.4658,  8.5425],
 

step0, f_mu[0]:  tensor([ -2.9086,  -9.5438,  -7.7035, -10.5020,   2.6750,   2.4523,   5.8661,
         -1.1683,   2.9740,  -7.1775], device='cuda:0')
step0, f_var[0]:  tensor([[50.6579, 35.2192, 35.9944, 35.4730, 35.7084, 35.9367, 37.3786, 35.5739,
         36.3524, 35.9774],
        [35.2192, 49.0640, 36.9230, 35.7937, 36.6788, 35.5149, 35.5838, 36.7134,
         36.6488, 36.1321],
        [35.9944, 36.9230, 46.3893, 37.0388, 36.2187, 35.8089, 35.4045, 37.3886,
         36.8896, 36.2160],
        [35.4730, 35.7937, 37.0388, 46.1440, 36.0080, 37.7641, 35.4509, 36.9742,
         36.9134, 36.7119],
        [35.7084, 36.6788, 36.2187, 36.0080, 45.3169, 35.9283, 36.1884, 36.9100,
         36.5822, 38.7324],
        [35.9367, 35.5149, 35.8089, 37.7641, 35.9283, 46.9063, 36.9948, 35.8594,
         37.1176, 36.4410],
        [37.3786, 35.5838, 35.4045, 35.4509, 36.1884, 36.9948, 49.8476, 35.1787,
         36.6004, 35.6441],
        [35.5739, 36.7134, 37.3886, 36.9742, 36.9099, 35.8594, 35.17

step0, f_mu[0]:  tensor([ -7.8167, -11.0350,  -6.9416,  -6.8650,  -2.0037,   5.4131,   2.3968,
         -2.2979,   1.4035,  -4.3296], device='cuda:0')
step0, f_var[0]:  tensor([[70.2749, 49.4192, 50.4656, 49.7609, 50.0786, 50.3861, 52.3307, 49.8978,
         50.9484, 50.4424],
        [49.4192, 68.1249, 51.7180, 50.1925, 51.3884, 49.8171, 49.9114, 51.4347,
         51.3489, 50.6492],
        [50.4656, 51.7180, 64.5114, 51.8756, 50.7666, 50.2131, 49.6687, 52.3483,
         51.6747, 50.7625],
        [49.7609, 50.1925, 51.8756, 64.1790, 50.4816, 52.8565, 49.7306, 51.7881,
         51.7065, 51.4337],
        [50.0786, 51.3884, 50.7666, 50.4816, 63.0602, 50.3745, 50.7275, 51.7008,
         51.2587, 54.1681],
        [50.3861, 49.8171, 50.2131, 52.8565, 50.3745, 65.2099, 51.8152, 50.2814,
         51.9829, 51.0679],
        [52.3307, 49.9114, 49.6687, 49.7306, 50.7275, 51.8152, 69.1814, 49.3639,
         51.2835, 49.9915],
        [49.8978, 51.4347, 52.3483, 51.7881, 51.7008, 50.2814, 49.36

step0, f_mu[0]:  tensor([-11.6415,  -3.3189,  -0.4398,  -5.9514,   9.5302,  -3.5312, -10.5293,
         -2.9876,  -2.6965,  -0.4369], device='cuda:0')
step0, f_var[0]:  tensor([[48.6324, 38.7656, 39.2829, 38.9402, 39.0983, 39.2542, 40.2267, 39.0038,
         39.5238, 39.2725],
        [38.7656, 47.5598, 39.9089, 39.1623, 39.7451, 38.9706, 39.0082, 39.7715,
         39.7190, 39.3893],
        [39.2829, 39.9090, 45.7942, 39.9783, 39.4433, 39.1721, 38.8929, 40.2078,
         39.8744, 39.4456],
        [38.9402, 39.1623, 39.9783, 45.6389, 39.3072, 40.4511, 38.9288, 39.9371,
         39.8931, 39.7636],
        [39.0983, 39.7451, 39.4433, 39.3072, 45.1062, 39.2504, 39.4117, 39.8992,
         39.6782, 41.0613],
        [39.2542, 38.9706, 39.1721, 40.4511, 39.2504, 46.1354, 39.9546, 39.2048,
         40.0242, 39.5832],
        [40.2267, 39.0082, 38.8929, 38.9288, 39.4117, 39.9546, 48.0898, 38.7426,
         39.6869, 39.0581],
        [39.0038, 39.7715, 40.2078, 39.9371, 39.8992, 39.2048, 38.74

step0, f_mu[0]:  tensor([  2.3143,   0.6924,   3.5298, -18.0034,  11.2822,  -6.3529,   1.9438,
         -4.0761,  -5.1001,  -8.7115], device='cuda:0')
step0, f_var[0]:  tensor([[69.7957, 54.6307, 55.4279, 54.8985, 55.1422, 55.3811, 56.8770, 54.9976,
         55.7981, 55.4111],
        [54.6307, 68.1453, 56.3909, 55.2388, 56.1382, 54.9445, 55.0048, 56.1784,
         56.0990, 55.5891],
        [55.4279, 56.3909, 65.4211, 56.4990, 55.6724, 55.2546, 54.8259, 56.8530,
         56.3390, 55.6762],
        [54.8986, 55.2388, 56.4990, 65.1809, 55.4626, 57.2297, 54.8804, 56.4351,
         56.3679, 56.1673],
        [55.1422, 56.1382, 55.6724, 55.4626, 64.3595, 55.3755, 55.6258, 56.3770,
         56.0362, 58.1711],
        [55.3811, 54.9445, 55.2546, 57.2297, 55.3755, 65.9479, 56.4617, 55.3051,
         56.5705, 55.8894],
        [56.8770, 55.0048, 54.8259, 54.8804, 55.6258, 56.4617, 68.9602, 54.5941,
         56.0497, 55.0801],
        [54.9976, 56.1784, 56.8530, 56.4351, 56.3770, 55.3051, 54.59

step0, f_mu[0]:  tensor([ -9.6062, -16.4530,   2.9010,   4.1319,  -8.5450,   7.5816, -11.4210,
         -1.8115,   1.0709,  -6.8856], device='cuda:0')
step0, f_var[0]:  tensor([[110.0123,  72.2790,  74.1320,  72.8711,  73.4315,  73.9668,  77.3850,
          73.1227,  74.9817,  74.0883],
        [ 72.2790, 106.2269,  76.3344,  73.6155,  75.7523,  72.9643,  73.1525,
          75.8262,  75.6969,  74.4222],
        [ 74.1320,  76.3344,  99.7799,  76.6328,  74.6384,  73.6534,  72.7115,
          77.4790,  76.2865,  74.6227],
        [ 72.8711,  73.6155,  76.6328,  99.1714,  74.1235,  78.3972,  72.8091,
          76.4729,  76.3364,  75.8409],
        [ 73.4315,  75.7523,  74.6384,  74.1235,  97.1473,  73.9425,  74.5972,
          76.3080,  75.5296,  80.8010],
        [ 73.9668,  72.9643,  73.6534,  78.3972,  73.9425, 101.0270,  76.5121,
          73.7767,  76.8384,  75.1924],
        [ 77.3850,  73.1525,  72.7115,  72.8091,  74.5972,  76.5121, 108.0803,
          72.1707,  75.5808,  73.2710]

step0, f_mu[0]:  tensor([-1.2440,  3.8976, -1.2188, -5.9078,  3.5948, -3.9232, -1.0289, -1.1143,
         0.7693, -3.8575], device='cuda:0')
step0, f_var[0]:  tensor([[23.2517, 18.0691, 18.3405, 18.1600, 18.2429, 18.3239, 18.8326, 18.1939,
         18.4664, 18.3347],
        [18.0691, 22.6903, 18.6680, 18.2753, 18.5819, 18.1755, 18.1965, 18.5954,
         18.5690, 18.3945],
        [18.3405, 18.6680, 21.7618, 18.7052, 18.4231, 18.2808, 18.1353, 18.8259,
         18.6509, 18.4242],
        [18.1600, 18.2753, 18.7052, 21.6796, 18.3514, 18.9546, 18.1536, 18.6833,
         18.6606, 18.5921],
        [18.2429, 18.5819, 18.4231, 18.3514, 21.3990, 18.3220, 18.4079, 18.6633,
         18.5474, 19.2768],
        [18.3239, 18.1755, 18.2808, 18.9546, 18.3220, 21.9414, 18.6922, 18.2980,
         18.7299, 18.4975],
        [18.8326, 18.1965, 18.1353, 18.1536, 18.4079, 18.6922, 22.9673, 18.0564,
         18.5522, 18.2216],
        [18.1939, 18.5954, 18.8259, 18.6833, 18.6633, 18.2980, 18.0564, 21.510

step0, f_mu[0]:  tensor([-8.1286, -1.5620,  3.3657, -4.2695,  3.6764, -0.8785, -1.5476, -0.6552,
         0.1172, -6.8738], device='cuda:0')
step0, f_var[0]:  tensor([[20.0496, 11.2105, 11.6480, 11.3511, 11.4836, 11.6105, 12.4192, 11.4100,
         11.8489, 11.6377],
        [11.2105, 19.1543, 12.1691, 11.5281, 12.0314, 11.3734, 11.4167, 12.0494,
         12.0173, 11.7189],
        [11.6480, 12.1691, 17.6352, 12.2381, 11.7693, 11.5372, 11.3131, 12.4372,
         12.1555, 11.7664],
        [11.3511, 11.5281, 12.2381, 17.4931, 11.6486, 12.6529, 11.3370, 12.2007,
         12.1679, 12.0517],
        [11.4836, 12.0314, 11.7693, 11.6486, 17.0187, 11.6052, 11.7576, 12.1627,
         11.9784, 13.2136],
        [11.6105, 11.3734, 11.5372, 12.6529, 11.6052, 17.9291, 12.2107, 11.5661,
         12.2854, 11.8986],
        [12.4192, 11.4167, 11.3131, 11.3370, 11.7576, 12.2107, 19.5931, 11.1855,
         11.9899, 11.4462],
        [11.4100, 12.0494, 12.4372, 12.2007, 12.1627, 11.5661, 11.1855, 17.210

step0, f_mu[0]:  tensor([-5.0628, -1.3775,  0.1699, -4.6908,  5.2528, -3.4406, -0.0125, -1.4704,
         0.2413, -2.3427], device='cuda:0')
step0, f_var[0]:  tensor([[27.2352, 13.7885, 14.4372, 13.9942, 14.1899, 14.3765, 15.5701, 14.0833,
         14.7344, 14.4219],
        [13.7885, 25.9126, 15.2064, 14.2523, 15.0031, 14.0262, 14.0944, 15.0278,
         14.9859, 14.5339],
        [14.4372, 15.2064, 23.6490, 15.3136, 14.6112, 14.2653, 13.9389, 15.6115,
         15.1944, 14.6037],
        [13.9942, 14.2523, 15.3136, 23.4327, 14.4290, 15.9351, 13.9713, 15.2571,
         15.2105, 15.0354],
        [14.1899, 15.0031, 14.6113, 14.4290, 22.7155, 14.3670, 14.6007, 15.1968,
         14.9255, 16.7926],
        [14.3765, 14.0262, 14.2653, 15.9351, 14.3670, 24.0867, 15.2693, 14.3088,
         15.3882, 14.8081],
        [15.5701, 14.0944, 13.9389, 13.9713, 14.6007, 15.2693, 26.5593, 13.7495,
         14.9450, 14.1325],
        [14.0833, 15.0278, 15.6115, 15.2571, 15.1968, 14.3088, 13.7495, 23.007

step0, f_mu[0]:  tensor([-6.9772,  0.7497,  0.9581, -3.2071, -2.8345, -3.2445,  3.2356, -1.9079,
        -5.1500, -5.8293], device='cuda:0')
step0, f_var[0]:  tensor([[35.6336, 21.4591, 22.1716, 21.6920, 21.9084, 22.1180, 23.4428, 21.7850,
         22.5005, 22.1559],
        [21.4591, 34.1691, 23.0248, 21.9864, 22.8003, 21.7305, 21.7942, 22.8320,
         22.7731, 22.2974],
        [22.1716, 23.0248, 31.7093, 23.1318, 22.3771, 22.0003, 21.6292, 23.4535,
         22.9948, 22.3745],
        [21.6920, 21.9864, 23.1318, 31.4833, 22.1832, 23.7991, 21.6715, 23.0722,
         23.0165, 22.8309],
        [21.9084, 22.8003, 22.3771, 22.1832, 30.7224, 22.1101, 22.3500, 23.0130,
         22.7119, 24.6907],
        [22.1180, 21.7305, 22.0003, 23.7991, 22.1101, 32.1848, 23.0909, 22.0468,
         23.2045, 22.5819],
        [23.4428, 21.7942, 21.6292, 21.6715, 22.3500, 23.0909, 34.8889, 21.4216,
         22.7286, 21.8492],
        [21.7850, 22.8320, 23.4535, 23.0722, 23.0130, 22.0468, 21.4216, 31.029

step0, f_mu[0]:  tensor([-1.1464, -0.4868,  0.4405, -4.5710,  3.0008, -4.4484, -2.4601,  2.8324,
        -4.2149, -4.1581], device='cuda:0')
step0, f_var[0]:  tensor([[19.9562, 12.1252, 12.5220, 12.2552, 12.3758, 12.4926, 13.2307, 12.3069,
         12.7051, 12.5131],
        [12.1252, 19.1402, 12.9974, 12.4195, 12.8722, 12.2766, 12.3119, 12.8901,
         12.8566, 12.5930],
        [12.5220, 12.9974, 17.7718, 13.0563, 12.6370, 12.4275, 12.2201, 13.2351,
         12.9795, 12.6362],
        [12.2552, 12.4195, 13.0563, 17.6468, 12.5296, 13.4272, 12.2441, 13.0233,
         12.9920, 12.8890],
        [12.3758, 12.8722, 12.6370, 12.5296, 17.2253, 12.4886, 12.6211, 12.9910,
         12.8230, 13.9194],
        [12.4926, 12.2766, 12.4275, 13.4272, 12.4886, 18.0365, 13.0341, 12.4533,
         13.0962, 12.7504],
        [13.2307, 12.3119, 12.2201, 12.2441, 12.6211, 13.0341, 19.5414, 12.1045,
         12.8319, 12.3432],
        [12.3069, 12.8901, 13.2351, 13.0233, 12.9910, 12.4533, 12.1045, 17.394

step0, f_var[0]:  tensor([[34.7407, 20.8771, 21.5646, 21.0987, 21.3070, 21.5069, 22.7790, 21.1907,
         21.8806, 21.5487],
        [20.8771, 33.3329, 22.3841, 21.3779, 22.1678, 21.1342, 21.2009, 22.1964,
         22.1450, 21.6776],
        [21.5646, 22.3841, 30.9482, 22.4918, 21.7565, 21.3918, 21.0389, 22.8043,
         22.3617, 21.7521],
        [21.0987, 21.3779, 22.4918, 30.7256, 21.5671, 23.1422, 21.0770, 22.4332,
         22.3813, 22.1993],
        [21.3070, 22.1678, 21.7565, 21.5671, 29.9816, 21.4986, 21.7369, 22.3737,
         22.0842, 24.0209],
        [21.5069, 21.1342, 21.3918, 23.1422, 21.4986, 31.4094, 22.4492, 21.4372,
         22.5656, 21.9589],
        [22.7790, 21.2009, 21.0389, 21.0770, 21.7369, 22.4492, 34.0232, 20.8384,
         22.1020, 21.2484],
        [21.1907, 22.1964, 22.8043, 22.4332, 22.3737, 21.4372, 20.8384, 30.2817,
         21.9050, 23.0334],
        [21.8806, 22.1450, 22.3617, 22.3813, 22.0842, 22.5656, 22.1020, 21.9050,
         28.6398, 22.4289],
 

step0, f_mu[0]:  tensor([ -6.1974,  -8.1680,  -3.6531,   8.6992,  -4.7258,   0.1661, -13.3705,
        -10.3417,  -3.3970,   4.0022], device='cuda:0')
step0, f_var[0]:  tensor([[51.6801, 44.0883, 44.4895, 44.2235, 44.3462, 44.4666, 45.2202, 44.2731,
         44.6759, 44.4811],
        [44.0883, 50.8489, 44.9747, 44.3953, 44.8474, 44.2467, 44.2766, 44.8679,
         44.8271, 44.5718],
        [44.4895, 44.9747, 49.4792, 45.0284, 44.6134, 44.4033, 44.1868, 45.2062,
         44.9475, 44.6157],
        [44.2235, 44.3953, 45.0284, 49.3591, 44.5082, 45.3953, 44.2146, 44.9964,
         44.9623, 44.8618],
        [44.3462, 44.8473, 44.6134, 44.5082, 48.9476, 44.4640, 44.5890, 44.9677,
         44.7959, 45.8657],
        [44.4666, 44.2467, 44.4033, 45.3953, 44.4640, 49.7442, 45.0101, 44.4287,
         45.0638, 44.7219],
        [45.2202, 44.2766, 44.1868, 44.2146, 44.5890, 45.0101, 51.2595, 44.0701,
         44.8024, 44.3152],
        [44.2731, 44.8679, 45.2062, 44.9964, 44.9677, 44.4287, 44.07

step0, f_mu[0]:  tensor([-1.1404, -0.6112,  0.0638, -5.7022, -3.3961,  0.4619, -0.3255, -1.1643,
        -0.4621, -2.7602], device='cuda:0')
step0, f_var[0]:  tensor([[20.8192,  9.4801, 10.0187,  9.6485,  9.8105,  9.9640, 10.9504,  9.7242,
         10.2646, 10.0056],
        [ 9.4801, 19.7250, 10.6544,  9.8593, 10.4861,  9.6741,  9.7344, 10.5051,
         10.4746, 10.0924],
        [10.0187, 10.6544, 17.8370, 10.7472, 10.1583,  9.8700,  9.6033, 10.9963,
         10.6501, 10.1502],
        [ 9.6485,  9.8593, 10.7472, 17.6534, 10.0050, 11.2682,  9.6280, 10.6993,
         10.6622, 10.5145],
        [ 9.8105, 10.4861, 10.1583, 10.0050, 17.0493,  9.9551, 10.1548, 10.6470,
         10.4229, 11.9967],
        [ 9.9640,  9.6741,  9.8700, 11.2682,  9.9551, 18.2020, 10.7077,  9.9067,
         10.8122, 10.3257],
        [10.9504,  9.7344,  9.6033,  9.6280, 10.1548, 10.7077, 20.2588,  9.4459,
         10.4405,  9.7617],
        [ 9.7242, 10.5051, 10.9963, 10.6993, 10.6470,  9.9067,  9.4459, 17.295

step0, f_mu[0]:  tensor([ -0.8369,  -2.9291,   1.3599,  -1.4025, -10.6595,   0.7165,  -3.9299,
         -1.4033,   1.3914,  -4.9146], device='cuda:0')
step0, f_var[0]:  tensor([[19.4675, 15.5034, 15.7128, 15.5741, 15.6382, 15.7011, 16.0947, 15.5999,
         15.8102, 15.7085],
        [15.5034, 19.0334, 15.9662, 15.6640, 15.8997, 15.5863, 15.6016, 15.9105,
         15.8891, 15.7561],
        [15.7128, 15.9662, 18.3190, 15.9940, 15.7778, 15.6681, 15.5549, 16.0868,
         15.9517, 15.7790],
        [15.5741, 15.6640, 15.9940, 18.2565, 15.7229, 16.1852, 15.5695, 15.9774,
         15.9595, 15.9072],
        [15.6382, 15.8997, 15.7778, 15.7229, 18.0419, 15.6998, 15.7647, 15.9624,
         15.8728, 16.4303],
        [15.7011, 15.5863, 15.6681, 16.1853, 15.6998, 18.4572, 15.9846, 15.6813,
         16.0124, 15.8342],
        [16.0947, 15.6016, 15.5549, 15.5695, 15.7648, 15.9846, 19.2479, 15.4941,
         15.8761, 15.6220],
        [15.5999, 15.9105, 16.0868, 15.9774, 15.9624, 15.6813, 15.49

step0, f_mu[0]:  tensor([ -6.0668,  -2.5466,   1.1892,  -3.9820,  -3.0327,  -4.5590, -11.1181,
          3.4039,   3.6676, -11.0982], device='cuda:0')
step0, f_var[0]:  tensor([[57.8597, 43.0514, 43.8040, 43.3005, 43.5295, 43.7533, 45.1586, 43.3965,
         44.1529, 43.7882],
        [43.0514, 56.3075, 44.7089, 43.6162, 44.4715, 43.3428, 43.4049, 44.5069,
         44.4396, 43.9449],
        [43.8040, 44.7089, 53.7204, 44.8176, 44.0274, 43.6303, 43.2334, 45.1555,
         44.6710, 44.0262],
        [43.3005, 43.6162, 44.8176, 53.4860, 43.8247, 45.5162, 43.2807, 44.7558,
         44.6951, 44.5020],
        [43.5295, 44.4715, 44.0274, 43.8247, 52.6910, 43.7456, 43.9923, 44.6953,
         44.3764, 46.4414],
        [43.7533, 43.3428, 43.6303, 45.5162, 43.7456, 54.2201, 44.7774, 43.6789,
         44.8911, 44.2391],
        [45.1586, 43.4049, 43.2334, 43.2807, 43.9923, 44.7774, 57.0720, 43.0144,
         44.3925, 43.4683],
        [43.3965, 44.5069, 45.1555, 44.7558, 44.6953, 43.6789, 43.01

step0, f_mu[0]:  tensor([-3.3718, -4.2952,  4.6727, -0.3328, -1.6820, -7.8870, -4.4332, -2.2254,
         0.3621, -5.0100], device='cuda:0')
step0, f_var[0]:  tensor([[34.7160, 24.4937, 25.0190, 24.6669, 24.8269, 24.9820, 25.9613, 24.7345,
         25.2616, 25.0072],
        [24.4937, 33.6340, 25.6496, 24.8861, 25.4835, 24.6957, 24.7407, 25.5080,
         25.4614, 25.1164],
        [25.0190, 25.6496, 31.8270, 25.7256, 25.1735, 24.8969, 24.6200, 25.9612,
         25.6226, 25.1737],
        [24.6669, 24.8861, 25.7256, 31.6638, 25.0326, 26.2140, 24.6529, 25.6823,
         25.6401, 25.5049],
        [24.8269, 25.4835, 25.1735, 25.0326, 31.1115, 24.9774, 25.1497, 25.6413,
         25.4180, 26.8549],
        [24.9820, 24.6957, 24.8969, 26.2140, 24.9774, 32.1767, 25.6976, 24.9308,
         25.7767, 25.3215],
        [25.9613, 24.7407, 24.6200, 24.6529, 25.1497, 25.6976, 34.1665, 24.4670,
         25.4288, 24.7844],
        [24.7345, 25.5080, 25.9612, 25.6823, 25.6413, 24.9308, 24.4670, 31.332

step0, f_mu[0]:  tensor([ -2.8484,   3.3361,  -0.6029, -10.5454,   4.8237,  -1.3438,  -0.5975,
         -0.3477,  -1.2627,  -6.8201], device='cuda:0')
step0, f_var[0]:  tensor([[34.2250, 21.7356, 22.3676, 21.9429, 22.1350, 22.3213, 23.4978, 22.0249,
         22.6595, 22.3537],
        [21.7356, 32.9248, 23.1253, 22.2051, 22.9259, 21.9773, 22.0328, 22.9545,
         22.9009, 22.4812],
        [22.3676, 23.1253, 30.7456, 23.2189, 22.5513, 22.2175, 21.8869, 23.5037,
         23.0967, 22.5498],
        [21.9429, 22.2051, 23.2189, 30.5465, 22.3801, 23.8093, 21.9252, 23.1664,
         23.1165, 22.9526],
        [22.1350, 22.9259, 22.5513, 22.3801, 29.8748, 22.3147, 22.5256, 23.1148,
         22.8472, 24.5939],
        [22.3213, 21.9773, 22.2175, 23.8093, 22.3147, 31.1669, 23.1836, 22.2586,
         23.2824, 22.7317],
        [23.4978, 22.0328, 21.8869, 21.9252, 22.5256, 23.1836, 33.5642, 21.7028,
         22.8615, 22.0829],
        [22.0249, 22.9545, 23.5037, 23.1664, 23.1148, 22.2586, 21.70

step0, f_mu[0]:  tensor([-4.4973, -6.2037, -3.3750,  4.2487, -5.1242,  0.5486, -8.6510, -0.8557,
        -2.2925,  0.8212], device='cuda:0')
step0, f_var[0]:  tensor([[32.7729, 25.4034, 25.7790, 25.5271, 25.6414, 25.7524, 26.4524, 25.5755,
         25.9526, 25.7708],
        [25.4034, 31.9994, 26.2298, 25.6837, 26.1112, 25.5477, 25.5800, 26.1286,
         26.0958, 25.8479],
        [25.7790, 26.2298, 30.7064, 26.2847, 25.8892, 25.6911, 25.4937, 26.4535,
         26.2116, 25.8888],
        [25.5271, 25.6837, 26.2847, 30.5890, 25.7879, 26.6344, 25.5169, 26.2537,
         26.2236, 26.1267],
        [25.6414, 26.1112, 25.8892, 25.7879, 30.1919, 25.7487, 25.8728, 26.2236,
         26.0643, 27.0968],
        [25.7524, 25.5477, 25.6911, 26.6344, 25.7487, 30.9564, 26.2642, 25.7154,
         26.3217, 25.9956],
        [26.4524, 25.5800, 25.4937, 25.5169, 25.8728, 26.2642, 32.3801, 25.3843,
         26.0724, 25.6107],
        [25.5755, 26.1286, 26.4535, 26.2536, 26.2236, 25.7154, 25.3843, 30.351

step0, f_mu[0]:  tensor([ -5.7261,  -0.1351,   6.5100,   0.2276,  -3.6600, -12.8357,  -5.8143,
         -0.4559,   2.9008, -10.1335], device='cuda:0')
step0, f_var[0]:  tensor([[46.8736, 34.9512, 35.5618, 35.1539, 35.3398, 35.5215, 36.6626, 35.2314,
         35.8448, 35.5489],
        [34.9512, 45.6135, 36.2965, 35.4108, 36.1037, 35.1882, 35.2380, 36.1328,
         36.0769, 35.6779],
        [35.5618, 36.2965, 43.5171, 36.3835, 35.7441, 35.4225, 35.0991, 36.6570,
         36.2638, 35.7441],
        [35.1539, 35.4108, 36.3835, 43.3283, 35.5805, 36.9489, 35.1382, 36.3337,
         36.2840, 36.1279],
        [35.3398, 36.1037, 35.7441, 35.5805, 42.6870, 35.5158, 35.7142, 36.2857,
         36.0264, 37.6926],
        [35.5215, 35.1882, 35.4225, 36.9489, 35.5158, 43.9222, 36.3518, 35.4617,
         36.4421, 35.9147],
        [36.6626, 35.2380, 35.0991, 35.1382, 35.7142, 36.3518, 46.2344, 34.9215,
         36.0389, 35.2906],
        [35.2314, 36.1328, 36.6570, 36.3337, 36.2857, 35.4617, 34.92

step0, f_mu[0]:  tensor([-7.4551, -1.8317, -5.6707, -4.0425,  3.0625,  3.0362, -6.1932, -4.9783,
        -2.5884, -2.1859], device='cuda:0')
step0, f_var[0]:  tensor([[47.2520, 31.7689, 32.5486, 32.0234, 32.2601, 32.4891, 33.9378, 32.1254,
         32.9082, 32.5312],
        [31.7689, 45.6502, 33.4817, 32.3448, 33.2359, 32.0651, 32.1357, 33.2704,
         33.2065, 32.6853],
        [32.5486, 33.4817, 42.9573, 33.5992, 32.7726, 32.3603, 31.9547, 33.9513,
         33.4493, 32.7699],
        [32.0234, 32.3448, 33.5991, 42.7096, 32.5604, 34.3302, 32.0007, 33.5339,
         33.4731, 33.2696],
        [32.2601, 33.2359, 32.7726, 32.5604, 41.8768, 32.4805, 32.7436, 33.4692,
         33.1396, 35.3062],
        [32.4891, 32.0651, 32.3603, 34.3302, 32.4805, 43.4779, 33.5541, 32.4112,
         33.6790, 32.9972],
        [33.9378, 32.1357, 31.9547, 32.0007, 32.7436, 33.5542, 46.4372, 31.7275,
         33.1579, 32.1953],
        [32.1254, 33.2704, 33.9513, 33.5339, 33.4692, 32.4112, 31.7275, 42.212

In [39]:
# compute average log-likelihood for LB KFAC
print(torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_KFAC)).log_prob(torch.tensor(targets)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_KFAC)).log_prob(torch.tensor(targets_FMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_KFAC)).log_prob(torch.tensor(targets_notMNIST)).mean())
print(torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_KFAC)).log_prob(torch.tensor(targets_KMNIST)).mean())

tensor(-0.0342)
tensor(-3.8378)
tensor(-6.0142)
tensor(-6.0235)


In [40]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_KFAC))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_KFAC))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_KFAC))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_KFAC))

0.011948898989899039
0.5424864242424243
0.6367586083004797
0.6168500707070707


In [41]:
acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC = get_in_dist_values(MNIST_test_in_LB_KFAC, targets)
acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_FMNIST_LB_KFAC, targets_FMNIST)
acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_notMNIST_LB_KFAC, targets_notMNIST)
acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_KMNIST_LB_KFAC, targets_KMNIST)

In [42]:
print_in_dist_values(acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC, 'MNIST', 'LB_KFAC')
print_out_dist_values(acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC, 'MNIST', test='fmnist', method='LB_KFAC')
print_out_dist_values(acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC, 'MNIST', test='notMNIST', method='LB_KFAC')
print_out_dist_values(acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC, 'MNIST', test='KMNIST', method='LB_KFAC')

[In, LB_KFAC, MNIST] Accuracy: 0.988; average entropy: 0.051;     MMC: 0.985; Prob @ correct: 0.100
[Out-fmnist, LB_KFAC, MNIST] Accuracy: 0.075; Average entropy: 1.202;    MMC: 0.617; AUROC: 0.978; Prob @ correct: 0.100
[Out-notMNIST, LB_KFAC, MNIST] Accuracy: 0.147; Average entropy: 0.685;    MMC: 0.757; AUROC: 0.912; Prob @ correct: 0.100
[Out-KMNIST, LB_KFAC, MNIST] Accuracy: 0.094; Average entropy: 0.806;    MMC: 0.711; AUROC: 0.957; Prob @ correct: 0.100


In [43]:
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

# Compare to extended MacKay approach

In [None]:
MNIST_test_in_EMK_D = predict_extended_MacKay(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_EMK_D = predict_extended_MacKay(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_EMK_D = predict_extended_MacKay(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_EMK_D = predict_extended_MacKay(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

In [None]:
acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK = get_in_dist_values(MNIST_test_in_EMK_D, targets)
acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK, MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_D, MNIST_test_out_FMNIST_EMK_D, targets_FMNIST)
acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK, MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_D, MNIST_test_out_notMNIST_EMK_D, targets_notMNIST)
acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK, MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_D, MNIST_test_out_KMNIST_EMK_D, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK, 'MNIST', 'EMK')
print_out_dist_values(acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK, MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK, 'MNIST', test='fmnist', method='EMK')
print_out_dist_values(acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK, MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK, 'MNIST', test='notMNIST', method='EMK')
print_out_dist_values(acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK, MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK, 'MNIST', test='KMNIST', method='EMK')

### EMK KFAC

In [None]:
MNIST_test_in_EMK_K = predict_extended_MacKay(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_EMK_K = predict_extended_MacKay(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_EMK_K = predict_extended_MacKay(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_EMK_K = predict_extended_MacKay(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

In [None]:
acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK = get_in_dist_values(MNIST_test_in_EMK_K, targets)
acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK, MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_K, MNIST_test_out_FMNIST_EMK_K, targets_FMNIST)
acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK, MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_K, MNIST_test_out_notMNIST_EMK_K, targets_notMNIST)
acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK, MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK = get_out_dist_values(MNIST_test_in_EMK_K, MNIST_test_out_KMNIST_EMK_K, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK, 'MNIST', 'EMK')
print_out_dist_values(acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK, MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK, 'MNIST', test='fmnist', method='EMK')
print_out_dist_values(acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK, MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK, 'MNIST', test='notMNIST', method='EMK')
print_out_dist_values(acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK, MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK, 'MNIST', test='KMNIST', method='EMK')

# Compare to Second-order Delta Posterior Predictive

In [None]:
MNIST_test_in_SODPP_D = predict_second_order_dpp(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_SODPP_D = predict_second_order_dpp(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_SODPP_D = predict_second_order_dpp(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_SODPP_D = predict_second_order_dpp(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

In [None]:
acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP = get_in_dist_values(MNIST_test_in_SODPP_D, targets)
acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_FMNIST_SODPP_D, targets_FMNIST)
acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_notMNIST_SODPP_D, targets_notMNIST)
acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_KMNIST_SODPP_D, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP, 'MNIST', 'SODPP')
print_out_dist_values(acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP, 'MNIST', test='fmnist', method='SODPP')
print_out_dist_values(acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP, 'MNIST', test='notMNIST', method='SODPP')
print_out_dist_values(acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP, 'MNIST', test='KMNIST', method='SODPP')

### KFAC SODPP

In [None]:
MNIST_test_in_SODPP_K = predict_second_order_dpp(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_SODPP_K = predict_second_order_dpp(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_SODPP_K = predict_second_order_dpp(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_SODPP_K = predict_second_order_dpp(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

In [None]:
acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP = get_in_dist_values(MNIST_test_in_SODPP_K, targets)
acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_FMNIST_SODPP_K, targets_FMNIST)
acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_notMNIST_SODPP_K, targets_notMNIST)
acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_KMNIST_SODPP_K, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP, 'MNIST', 'SODPP')
print_out_dist_values(acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP, 'MNIST', test='fmnist', method='SODPP')
print_out_dist_values(acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP, 'MNIST', test='notMNIST', method='SODPP')
print_out_dist_values(acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP, 'MNIST', test='KMNIST', method='SODPP')

In [None]:
break

# Experiments on Rotated MNIST

In [None]:
import torchvision.transforms as transforms
# rotate 15 degrees
MNIST_transform_r15 = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.ToTensor()
    ])

MNIST_test_r15 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r15)

mnist_test_loader_r15 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r15,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 30 degrees
MNIST_transform_r30 = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    ])

MNIST_test_r30 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r30)

mnist_test_loader_r30 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r30,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 45 degrees
MNIST_transform_r45 = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    ])

MNIST_test_r45 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r45)

mnist_test_loader_r45 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r45,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 60 degrees
MNIST_transform_r60 = transforms.Compose([
    transforms.RandomRotation(60),
    transforms.ToTensor(),
    ])

MNIST_test_r60 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r60)

mnist_test_loader_r60 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r60,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 75 degrees
MNIST_transform_r75 = transforms.Compose([
    transforms.RandomRotation(75),
    transforms.ToTensor(),
    ])

MNIST_test_r75 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r75)

mnist_test_loader_r75 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r75,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 90 degrees
MNIST_transform_r90 = transforms.Compose([
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    ])

MNIST_test_r90 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r90)

mnist_test_loader_r90 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r90,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 105 degrees
MNIST_transform_r105 = transforms.Compose([
    transforms.RandomRotation(105),
    transforms.ToTensor(),
    ])

MNIST_test_r105 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r105)

mnist_test_loader_r105 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r105,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 120 degrees
MNIST_transform_r120 = transforms.Compose([
    transforms.RandomRotation(120),
    transforms.ToTensor(),
    ])

MNIST_test_r120 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r120)

mnist_test_loader_r120 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r120,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 135 degrees
MNIST_transform_r135 = transforms.Compose([
    transforms.RandomRotation(135),
    transforms.ToTensor(),
    ])

MNIST_test_r135 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r135)

mnist_test_loader_r135 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r135,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 150 degrees
MNIST_transform_r150= transforms.Compose([
    transforms.RandomRotation(150),
    transforms.ToTensor(),
    ])

MNIST_test_r150 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r150)

mnist_test_loader_r150 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r150,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 165 degrees
MNIST_transform_r165 = transforms.Compose([
    transforms.RandomRotation(165),
    transforms.ToTensor(),
    ])

MNIST_test_r165 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r165)

mnist_test_loader_r165 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r165,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

# rotate 180 degrees
MNIST_transform_r180 = transforms.Compose([
    transforms.RandomRotation(180),
    transforms.ToTensor(),
    ])

MNIST_test_r180 = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform_r180)

mnist_test_loader_r180 = torch.utils.data.dataloader.DataLoader(
    MNIST_test_r180,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False
)

In [None]:
from sklearn.metrics import brier_score_loss

## helper function: given a dataloader compute accuracy and brier score
def get_acc_brier(dataloader, targets, num_samples=1000):
    
    # compute sampling results
    mnist_rotated_D = predict_diagonal_sampling(mnist_model, dataloader, M_W_post_D, M_b_post_D, C_W_post_D, C_b_post_D, verbose=False, cuda=True, timing=False, n_samples=num_samples).cpu().numpy()
    
    # compute LB results
    mnist_rotated_LB = predict_LB(mnist_model, dataloader, M_W_post_D, M_b_post_D, C_W_post_D, C_b_post_D, verbose=False, cuda=True, timing=False).cpu().numpy()
    mnist_rotated_LBn = mnist_rotated_LB/mnist_rotated_LB.sum(1).reshape(-1,1)
    
    # accuracy for sampling and LB
    acc_D = np.mean(np.argmax(mnist_rotated_D, 1) == targets)
    acc_LB = np.mean(np.argmax(mnist_rotated_LBn, 1) == targets)
    
    # get brier score for sampling and LB
    pred_at_true_D = np.array([mnist_rotated_D[i, j] for i, j in enumerate(targets)])
    pred_at_true_LBn = np.array([mnist_rotated_LBn[i, j] for i, j in enumerate(targets)])
    
    brier_D = brier_score_loss(np.ones_like(pred_at_true_D), pred_at_true_D)
    brier_LB = brier_score_loss(np.ones_like(pred_at_true_LBn), pred_at_true_LBn)
    
    return(acc_D, acc_LB, brier_D, brier_LB)

In [None]:
# just a test
get_acc_brier(mnist_test_loader_r15, targets)

In [None]:
## predict on all distributions and compute accuracy and brier score

dataloader_list = [mnist_test_loader, mnist_test_loader_r15, mnist_test_loader_r30, mnist_test_loader_r45,
                  mnist_test_loader_r60, mnist_test_loader_r75, mnist_test_loader_r90, mnist_test_loader_r105,
                  mnist_test_loader_r120, mnist_test_loader_r135, mnist_test_loader_r150, mnist_test_loader_r165, 
                  mnist_test_loader_r180]

Acc_D_list = []
Acc_LB_list = []
Brier_D_list = []
Brier_LB_list = []

for i, loader_ in enumerate(dataloader_list):
    print(i)
    Acc_D_, Acc_LB_, Brier_D_, Brier_LB_ = get_acc_brier(loader_, targets)
    Acc_D_list.append(Acc_D_)
    Acc_LB_list.append(Acc_LB_)
    Brier_D_list.append(Brier_D_)
    Brier_LB_list.append(Brier_LB_)
    

In [None]:
# Make inline plots vector graphics
import matplotlib
from IPython.display import set_matplotlib_formats

set_matplotlib_formats("pdf", "svg")

matplotlib.rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
plt.rcParams["text.usetex"] = True
plt.rcParams["text.latex.preamble"] = r"\usepackage{amsfonts} \usepackage{amsmath}"

In [None]:
print(Acc_D_list)
print(Acc_LB_list)
print(Brier_D_list)
print(Brier_LB_list)

In [None]:
#### compare over 5 seeds: 123, 124, 125, 126, 127
Acc_D_all = np.array([
    [0.9906, 0.981, 0.9401, 0.8593, 0.7388, 0.6375, 0.5627, 0.5049, 0.4703, 0.4533, 0.4363, 0.4287, 0.4268],
    [0.9889, 0.9814, 0.9433, 0.8518, 0.7337, 0.6403, 0.561, 0.5016, 0.4669, 0.4306, 0.4292, 0.4104, 0.4183],
    [0.9904, 0.9829, 0.946, 0.8582, 0.747, 0.6495, 0.5602, 0.5075, 0.4684, 0.4437, 0.4334, 0.4245, 0.4217],
    [0.9875, 0.9776, 0.9353, 0.8422, 0.7212, 0.6377, 0.55, 0.4896, 0.4497, 0.4328, 0.416, 0.4117, 0.4095],
    [0.9894, 0.9816, 0.9451, 0.8522, 0.7403, 0.6391, 0.5565, 0.4994, 0.4656, 0.4308, 0.4284, 0.4221, 0.4152]
])

Acc_LB_all = np.array([
    [0.9907, 0.983, 0.9394, 0.8526, 0.7419, 0.6386, 0.5703, 0.5152, 0.4702, 0.4484, 0.4335, 0.4236, 0.4182],
    [0.9889, 0.9804, 0.9471, 0.8547, 0.7411, 0.6296, 0.5624, 0.4938, 0.4609, 0.4404, 0.4295, 0.4172, 0.4152],
    [0.9905, 0.9831, 0.9436, 0.8616, 0.7461, 0.6517, 0.5647, 0.5127, 0.4673, 0.4454, 0.4273, 0.4291, 0.415],
    [0.9876, 0.9799, 0.9398, 0.851, 0.7352, 0.6368, 0.5445, 0.4992, 0.4632, 0.4269, 0.42, 0.417, 0.4181],
    [0.9893, 0.9808, 0.939, 0.8465, 0.7338, 0.6312, 0.5573, 0.505, 0.4539, 0.4458, 0.4158, 0.4189, 0.4145]
])

Brier_D_all = np.array([
    [0.008377700125184932, 0.016529366393931315, 0.05263197122933036, 0.12622978433992207, 0.23428155265648049, 0.33001102154197504, 0.40364649390174306, 0.45833621095510935, 0.48925685879783176, 0.5084873524892697, 0.5265005890140535, 0.5351897862160594, 0.535655688821819],
    [0.010744372283432188, 0.01861786528128727, 0.04921392785665237, 0.130553247122029, 0.23708335123859933, 0.3264322724129021, 0.4020700676185214, 0.45858313992431765, 0.49138653020376494, 0.5267105075863001, 0.5302083038117827, 0.5471840382483096, 0.5406172102829678],
    [0.008842406659404037, 0.01589756693721347, 0.04813633596723046, 0.12450638545589118, 0.2251983306557004, 0.31596819055291797, 0.39958471677594976, 0.4511044612864316, 0.4908296869743894, 0.513979687459884, 0.5228012975599322, 0.5331324584577392, 0.5366587666968312],
    [0.010730850229531219, 0.019535984307980356, 0.054532256239393395, 0.13633893796280153, 0.24614911889753252, 0.32741910714203704, 0.4095507809925442, 0.4699407662727923, 0.5086184210935191, 0.525650285971151, 0.544528220493102, 0.5479489062984142, 0.5516986011718078],
    [0.009795412763015064, 0.01817015340220003, 0.04894315094823046, 0.12908701202114847, 0.23282846945381916, 0.3285278649762006, 0.4081318056690933, 0.46182442301821763, 0.4973149243671581, 0.5269344762781906, 0.5322562363481955, 0.5380640556117866, 0.5471252220414157]
])

Brier_LB_all = np.array([
    [0.007844242911062461, 0.015166313465830961, 0.05119071505966292, 0.13220723722716274, 0.23246172960926015, 0.3326357106190896, 0.4017439850542811, 0.45393964088759164, 0.49749242841385766, 0.5211752780472316, 0.5353101111950346, 0.5431750248692432, 0.5509519585301749],
    [0.009505463873917114, 0.016971207786010672, 0.0470778754224552, 0.12817579122361247, 0.2349776937700805, 0.339070595213455, 0.40568261118347, 0.4740733457035303, 0.5030623144832805, 0.5253332179429556, 0.5360191577762338, 0.549535112754868, 0.5511099660875561],
    [0.0076984171595750415, 0.014756499162356471, 0.0469422462627717, 0.12101585517824154, 0.22767872259093588, 0.31762503643888157, 0.4011728288118217, 0.4516875142101383, 0.4954514924412558, 0.5178717196068767, 0.5371807295796656, 0.5349908080424384, 0.5497019253514567],
    [0.01001831719153492, 0.016729483129191477, 0.05172484471829214, 0.13075484635707604, 0.23891760292945455, 0.3334325462508124, 0.42297503706889894, 0.4695227387634461, 0.5033610244309235, 0.5387111053297119, 0.5473753173717428, 0.5506346709903308, 0.5503556560492305],
    [0.008755419081030589, 0.01618864623953837, 0.05156303842716825, 0.13417919609729662, 0.23963138466766964, 0.34131713744443276, 0.4110757904444741, 0.46371903749768173, 0.5107108175950072, 0.5222953217647085, 0.5486314035009504, 0.5470332878374046, 0.5531479907833964]   
])

In [None]:
Acc_D_all_mean = np.mean(Acc_D_all, axis=0)
Acc_D_all_std = np.std(Acc_D_all, axis=0)

Acc_LB_all_mean = np.mean(Acc_LB_all, axis=0)
Acc_LB_all_std = np.std(Acc_LB_all, axis=0)

Brier_D_all_mean = np.mean(Brier_D_all, axis=0)
Brier_D_all_std = np.std(Brier_D_all, axis=0)

Brier_LB_all_mean = np.mean(Brier_LB_all, axis=0)
Brier_LB_all_std = np.std(Brier_LB_all, axis=0)

In [None]:
# make a figure for accuracy
x_labels = ['test', '15$^{\circ}$', '30$^{\circ}$', '45$^{\circ}$', '60$^{\circ}$', '75$^{\circ}$', '90$^{\circ}$',
            '105$^{\circ}$', '120$^{\circ}$', '135$^{\circ}$', '150$^{\circ}$', '165$^{\circ}$', '180$^{\circ}$']

plt.figure(figsize=(5, 1.5))
plt.plot(x_labels, Acc_D_all_mean, color='blue', label='MCMC')
plt.plot(x_labels, Acc_LB_all_mean, color='firebrick', label='LB')
plt.ylim(0,1)
plt.ylabel('Accuracy', size=15)
#plt.legend()
plt.xticks(x_labels, x_labels, rotation='30', size=13)

plt.tight_layout()

plt.savefig('figures/LB_vs_MCMC_Acc.pdf')
plt.show();

In [None]:
# make a figure for the brier score

plt.figure(figsize=(5, 1.5))
plt.plot(x_labels, Brier_D_list, color='blue', label='MCMC')
plt.plot(x_labels, Brier_LB_list, color='firebrick', label='LB')
plt.ylim(0,0.8)
plt.ylabel('Brier', size=15)
plt.legend()
plt.xticks(x_labels, x_labels, rotation='30', size=13)

plt.tight_layout()

plt.savefig('figures/LB_vs_MCMC_Brier.pdf')
plt.show();

In [None]:
# make a subplot

fig, ax = plt.subplots(2, 1, figsize=(5, 3))

ax[0].plot(x_labels, Acc_D_all_mean, color='blue', label='MCMC')
ax[0].plot(x_labels, Acc_LB_all_mean, color='firebrick', label='LB')
ax[0].set_ylim(0,1)
ax[0].set_ylabel('Accuracy', size=15)
ax[0].set_xticklabels([])

ax[1].plot(x_labels, Brier_D_list, color='blue', label='MCMC')
ax[1].plot(x_labels, Brier_LB_list, color='firebrick', label='LB')
ax[1].set_ylim(0,0.8)
ax[1].set_ylabel('Brier', size=15)
ax[1].legend()
ax[1].set_xticklabels(x_labels, rotation=30, size=13)


plt.tight_layout()

plt.savefig('figures/LB_vs_MCMC_Acc_Brier.pdf')
plt.show();