In [None]:
!mkdir -p data && cd data && curl -O "http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat"

!pip install pytorch-ignite

#const

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  112M  100  112M    0     0  11.9M      0  0:00:09  0:00:09 --:--:-- 15.9M


In [None]:
import random
import numpy as np

import torch
import torch.utils.data
from torch.nn import functional as F

from ignite.engine import Events, Engine
from ignite.metrics import Accuracy, Loss

from ignite.contrib.handlers.tqdm_logger import ProgressBar

from utils.evaluate_ood import (
    get_fashionmnist_mnist_ood,
    get_fashionmnist_notmnist_ood,
)
from utils.datasets import FastFashionMNIST, get_FashionMNIST
from utils.cnn_duq import CNN_DUQ


In [None]:
# Most recently trained model. `train_model` rebinds this module-level name
# (via `global model`) so other cells can inspect it; None until first run.
model = None


def train_model(l_gradient_penalty, length_scale, final_model, epochs, cn):
    """Train a DUQ CNN on FashionMNIST with a two-sided gradient penalty.

    Training loss per batch:
        BCE(y_pred, y) + l_gradient_penalty * mean((||d sum_c y_pred / dx||_2 - cn) ** 2)

    Args:
        l_gradient_penalty: Weight of the gradient-penalty term in the loss.
        length_scale: Length scale passed to ``CNN_DUQ``.
        final_model: If True, train on the full FashionMNIST train split and
            use the test split as the validation set. Otherwise hold out
            ``val_size`` randomly chosen training images for validation.
        epochs: Number of training epochs.
        cn: Target gradient norm of the two-sided penalty.

    Returns:
        Tuple ``(model, val_accuracy, test_accuracy)``.

    Side effects:
        Rebinds the module-level ``model`` and moves it to CUDA. Every 5
        epochs it prints validation metrics and MNIST / NotMNIST AUROC.
    """
    global model

    input_size = 28
    num_classes = 10
    embedding_size = 256
    learnable_length_scale = False
    gamma = 0.999  # passed to CNN_DUQ; unrelated to the LR-scheduler gamma below
    val_size = 5000  # held-out training images when final_model is False

    # In-distribution data (FashionMNIST). OOD AUROCs are computed by the
    # utils.evaluate_ood helpers inside `log_results`.
    dataset = FastFashionMNIST("data/", train=True, download=True)
    test_dataset = FastFashionMNIST("data/", train=False, download=True)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        # Random train/validation split derived from the actual dataset size.
        idx = list(range(len(dataset)))
        random.shuffle(idx)
        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:-val_size])
        val_dataset = torch.utils.data.Subset(dataset, indices=idx[-val_size:])

    dl_train = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True
    )
    dl_val = torch.utils.data.DataLoader(
        val_dataset, batch_size=2000, shuffle=False, num_workers=0
    )
    dl_test = torch.utils.data.DataLoader(
        test_dataset, batch_size=2000, shuffle=False, num_workers=0
    )

    # Model
    model = CNN_DUQ(
        input_size,
        num_classes,
        embedding_size,
        learnable_length_scale,
        length_scale,
        gamma,
    ).cuda()

    # Optimiser
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4
    )

    # Evaluator outputs are (y_pred, y_onehot, x, y_pred.sum(1)); these pick
    # the pieces each metric needs.
    def output_transform_bce(output):
        y_pred, y, _, _ = output
        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, _, _ = output
        return y_pred, torch.argmax(y, dim=1)

    def output_transform_gp(output):
        _, _, x, y_pred_sum = output
        return x, y_pred_sum

    def calc_gradient_penalty(x, y_pred_sum):
        """Two-sided penalty ``mean((||d y_pred_sum / dx||_2 - cn) ** 2)``.

        ``create_graph=True`` keeps the penalty differentiable so it can be
        back-propagated as part of the training loss.
        """
        gradients = torch.autograd.grad(
            outputs=y_pred_sum,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred_sum),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Per-sample L2 norm over all non-batch dimensions
        grad_norm = gradients.flatten(start_dim=1).norm(2, dim=1)

        return ((grad_norm - cn) ** 2).mean()

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        y = F.one_hot(y, num_classes=num_classes).float()
        x, y = x.cuda(), y.cuda()

        # Inputs must track gradients for the gradient penalty
        x.requires_grad_(True)

        z, y_pred = model(x)

        loss = F.binary_cross_entropy(y_pred, y)
        loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))

        # Stop tracking input gradients before the centroid update below
        x.requires_grad_(False)

        loss.backward()
        optimizer.step()

        # Centroid update runs without autograd and in eval mode (presumably
        # to freeze dropout/batch-norm behaviour — confirm against CNN_DUQ).
        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        y = F.one_hot(y, num_classes=num_classes).float()
        x, y = x.cuda(), y.cuda()

        # No torch.no_grad() here: the gradient-penalty metric differentiates
        # y_pred.sum(1) with respect to x.
        x.requires_grad_(True)

        z, y_pred = model(x)

        return y_pred, y, x, y_pred.sum(1)

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    Accuracy(output_transform=output_transform_acc).attach(evaluator, "accuracy")
    Loss(F.binary_cross_entropy, output_transform=output_transform_bce).attach(
        evaluator, "bce"
    )
    Loss(calc_gradient_penalty, output_transform=output_transform_gp).attach(
        evaluator, "gradient_penalty"
    )

    # LR decays by 0.2 after epochs 10 and 20 (scheduler stepped once per epoch)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20], gamma=0.2
    )

    pbar = ProgressBar()
    pbar.attach(trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        scheduler.step()

        # Log every 5 epochs
        if trainer.state.epoch % 5 == 0:
            evaluator.run(dl_val)

            # AUROC for FashionMNIST vs. MNIST / NotMNIST
            _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)
            _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)
            metrics = evaluator.state.metrics

            print(
                f"Validation Results - Epoch: {trainer.state.epoch} "
                f"Val_Acc: {metrics['accuracy']:.4f} "
                f"BCE: {metrics['bce']:.2f} "
                f"GP: {metrics['gradient_penalty']:.6f} "
                f"AUROC MNIST: {roc_auc_mnist:.4f} "
                f"AUROC NotMNIST: {roc_auc_notmnist:.4f} "
            )
            print(f"Sigma: {model.sigma}")

    # Train
    trainer.run(dl_train, max_epochs=epochs)

    # Validation
    evaluator.run(dl_val)
    val_accuracy = evaluator.state.metrics["accuracy"]

    # Test
    evaluator.run(dl_test)
    test_accuracy = evaluator.state.metrics["accuracy"]

    return model, val_accuracy, test_accuracy

In [None]:
if __name__ == "__main__":
    # The return value is not used below. The call is kept as in the original
    # run (presumably to make sure FashionMNIST is available — TODO confirm).
    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()

    # Sweep configuration.
    # NOTE(review): the saved outputs of this cell were produced with a
    # different `const` list (they start at 0.1 and include 2 and 10), so they
    # are stale relative to this code. Re-run the cell to refresh them.
    const = [0.01, 0.05, 0.1, 0.5, 1, 3]  # target gradient norms (`cn`)
    l_gradient_penalties = [0.05, 1]
    length_scale = 0.1
    epochs = 30

    repetition = 1  # Increase for multiple repetitions
    final_model = True  # set true for final model to train on full train set
    seed = 0

    # Seed every RNG used by the training code (Python `random` for the data
    # split, torch for initialisation/shuffling) so the sweep is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    results = {}

    for l_gradient_penalty in l_gradient_penalties:
        for cn in const:
            config_key = f"lgp{l_gradient_penalty}_ls{length_scale}_gpconst{cn}"

            val_accuracies = []
            test_accuracies = []
            roc_aucs_mnist = []
            roc_aucs_notmnist = []

            for rep in range(repetition):
                print(f" ### NEW MODEL ### gp={l_gradient_penalty}, constant = {cn}")
                model, val_accuracy, test_accuracy = train_model(
                    l_gradient_penalty, length_scale, final_model, epochs, cn
                )
                _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)
                _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)

                val_accuracies.append(val_accuracy)
                test_accuracies.append(test_accuracy)
                roc_aucs_mnist.append(roc_auc_mnist)
                roc_aucs_notmnist.append(roc_auc_notmnist)

                # Persist every trained model of the sweep. Previously only
                # the very last one was saved, so the others were lost.
                torch.save(
                    model.state_dict(), f"DUQ_FM_{epochs}_{config_key}_rep{rep}.pt"
                )

            # Mean statistics over repetitions
            results[config_key] = [
                ("val acc", np.mean(val_accuracies)),
                ("test acc", np.mean(test_accuracies)),
                ("M auroc", np.mean(roc_aucs_mnist)),
                ("NM auroc", np.mean(roc_aucs_notmnist)),
            ]

    # Keep the original checkpoint name for the last model of the sweep
    torch.save(model.state_dict(), "DUQ_FM_30_FULL.pt")
    print(results)


 ### NEW MODEL ### gp=0.05, constant = 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.9047 BCE: 0.06 GP: 0.024281 AUROC MNIST: 0.9472 AUROC NotMNIST: 0.96 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.9051 BCE: 0.05 GP: 0.023031 AUROC MNIST: 0.9580 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9201 BCE: 0.04 GP: 0.034128 AUROC MNIST: 0.9460 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9258 BCE: 0.04 GP: 0.035562 AUROC MNIST: 0.9527 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9235 BCE: 0.04 GP: 0.042006 AUROC MNIST: 0.9557 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9250 BCE: 0.04 GP: 0.046738 AUROC MNIST: 0.9586 AUROC NotMNIST: 0.96 
Sigma: 0.1
 ### NEW MODEL ### gp=0.05, constant = 0.5


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.9096 BCE: 0.06 GP: 0.046325 AUROC MNIST: 0.9053 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.9167 BCE: 0.05 GP: 0.051560 AUROC MNIST: 0.9335 AUROC NotMNIST: 0.92 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9251 BCE: 0.04 GP: 0.061618 AUROC MNIST: 0.9440 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9261 BCE: 0.04 GP: 0.064240 AUROC MNIST: 0.9517 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9272 BCE: 0.04 GP: 0.071237 AUROC MNIST: 0.9538 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9262 BCE: 0.04 GP: 0.076361 AUROC MNIST: 0.9574 AUROC NotMNIST: 0.95 
Sigma: 0.1
 ### NEW MODEL ### gp=0.05, constant = 1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.8928 BCE: 0.06 GP: 0.078873 AUROC MNIST: 0.9550 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.9147 BCE: 0.05 GP: 0.069315 AUROC MNIST: 0.9157 AUROC NotMNIST: 0.93 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9207 BCE: 0.05 GP: 0.050877 AUROC MNIST: 0.9440 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9215 BCE: 0.05 GP: 0.047834 AUROC MNIST: 0.9533 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9211 BCE: 0.04 GP: 0.052089 AUROC MNIST: 0.9422 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9213 BCE: 0.04 GP: 0.058031 AUROC MNIST: 0.9452 AUROC NotMNIST: 0.94 
Sigma: 0.1
 ### NEW MODEL ### gp=0.05, constant = 2


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.8787 BCE: 0.08 GP: 0.564870 AUROC MNIST: 0.8430 AUROC NotMNIST: 0.92 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.8970 BCE: 0.06 GP: 0.778026 AUROC MNIST: 0.8382 AUROC NotMNIST: 0.92 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9164 BCE: 0.05 GP: 0.699279 AUROC MNIST: 0.9320 AUROC NotMNIST: 0.93 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9171 BCE: 0.05 GP: 1.083244 AUROC MNIST: 0.9278 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9176 BCE: 0.05 GP: 0.973148 AUROC MNIST: 0.9442 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9177 BCE: 0.05 GP: 1.043002 AUROC MNIST: 0.9436 AUROC NotMNIST: 0.94 
Sigma: 0.1
 ### NEW MODEL ### gp=0.05, constant = 10


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.1008 BCE: 9.99 GP: 100.000008 AUROC MNIST: 0.4954 AUROC NotMNIST: 0.50 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.1000 BCE: 9.37 GP: 100.000008 AUROC MNIST: 0.5249 AUROC NotMNIST: 0.22 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.1115 BCE: 3.82 GP: 54.039555 AUROC MNIST: 0.6300 AUROC NotMNIST: 0.67 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.1000 BCE: 0.33 GP: 100.000008 AUROC MNIST: 0.5000 AUROC NotMNIST: 0.50 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.1000 BCE: 0.33 GP: 100.000008 AUROC MNIST: 0.5000 AUROC NotMNIST: 0.50 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.1000 BCE: 0.33 GP: 100.000008 AUROC MNIST: 0.5000 AUROC NotMNIST: 0.50 
Sigma: 0.1
 ### NEW MODEL ### gp=1, constant = 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.9011 BCE: 0.06 GP: 0.002114 AUROC MNIST: 0.9303 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.9102 BCE: 0.05 GP: 0.002376 AUROC MNIST: 0.9413 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9198 BCE: 0.04 GP: 0.002815 AUROC MNIST: 0.9409 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9227 BCE: 0.04 GP: 0.003549 AUROC MNIST: 0.9372 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9223 BCE: 0.04 GP: 0.003420 AUROC MNIST: 0.9451 AUROC NotMNIST: 0.95 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9234 BCE: 0.04 GP: 0.004105 AUROC MNIST: 0.9440 AUROC NotMNIST: 0.95 
Sigma: 0.1
 ### NEW MODEL ### gp=1, constant = 0.5


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.9037 BCE: 0.06 GP: 0.039189 AUROC MNIST: 0.9286 AUROC NotMNIST: 0.93 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.9160 BCE: 0.05 GP: 0.050697 AUROC MNIST: 0.9271 AUROC NotMNIST: 0.93 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9215 BCE: 0.04 GP: 0.060481 AUROC MNIST: 0.9541 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.9217 BCE: 0.04 GP: 0.062513 AUROC MNIST: 0.9438 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 25 Val_Acc: 0.9219 BCE: 0.04 GP: 0.067642 AUROC MNIST: 0.9498 AUROC NotMNIST: 0.94 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 30 Val_Acc: 0.9229 BCE: 0.04 GP: 0.065219 AUROC MNIST: 0.9517 AUROC NotMNIST: 0.94 
Sigma: 0.1
 ### NEW MODEL ### gp=1, constant = 1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 5 Val_Acc: 0.8723 BCE: 0.08 GP: 0.150696 AUROC MNIST: 0.9208 AUROC NotMNIST: 0.88 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 10 Val_Acc: 0.8940 BCE: 0.07 GP: 0.271609 AUROC MNIST: 0.8768 AUROC NotMNIST: 0.84 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 15 Val_Acc: 0.9062 BCE: 0.06 GP: 0.302272 AUROC MNIST: 0.9065 AUROC NotMNIST: 0.89 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))

Validation Results - Epoch: 20 Val_Acc: 0.8988 BCE: 0.06 GP: 0.352902 AUROC MNIST: 0.8965 AUROC NotMNIST: 0.86 
Sigma: 0.1


HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))