In [None]:
import argparse
import copy
import json
import pathlib
import random

import torch
import torch.nn.functional as F
import torch.utils.data

!pip install pytorch-ignite
from ignite.engine import Events, Engine
from ignite.metrics import Accuracy, Average, Loss
from ignite.contrib.handlers import ProgressBar

from utils.resnet_duq import ResNet_DUQ
from utils.datasets import all_datasets
from utils.evaluate_ood import get_cifar_svhn_ood, get_auroc_classification

# Module-level handles populated by `main` so later cells can inspect them:
# `model` is rebound to the network being trained, and `results` collects one
# dict of evaluation metrics per evaluated epoch.
model = None
results = []
def main(
    batch_size,
    epochs,
    length_scale,
    centroid_size,
    model_output_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
):
    """Train a DUQ ResNet on CIFAR-10 with pytorch-ignite.

    Side effects:
        * rebinds the module-level ``model`` to the network being trained;
        * appends one evaluation record per evaluated epoch (every 10th epoch
          and every epoch after 70) to the module-level ``results`` list;
        * saves ``saved_models/<epoch>.pt`` checkpoints after epoch 70,
          creating the directory if needed.

    Args:
        batch_size: training mini-batch size (the incomplete last batch is dropped).
        epochs: number of training epochs.
        length_scale: forwarded to ``ResNet_DUQ``.
        centroid_size: forwarded to ``ResNet_DUQ``.
        model_output_size: forwarded to ``ResNet_DUQ``.
        learning_rate: initial SGD learning rate; multiplied by 0.2 after
            epochs 25, 50 and 75.
        l_gradient_penalty: weight of the two-sided input-gradient penalty
            ``mean((||grad_x sum(y_pred)||_2 - 1) ** 2)``; 0 disables it.
        gamma: forwarded to ``ResNet_DUQ`` (presumably the centroid update
            momentum -- confirm in utils.resnet_duq).
        weight_decay: SGD weight decay.
        final_model: if True, train on the full training set and use the test
            set for validation; otherwise hold out a random 20% of the
            training set for validation.
    """
    input_size, num_classes, dataset, test_dataset = all_datasets["CIFAR10"]()

    # Random permutation of training indices, used for the train/val split.
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])

        # Assigning `.transform` on a Subset has no effect: Subset forwards
        # indexing to the wrapped dataset, which applies its own transform, so
        # validation would see training-time augmentation. Instead, wrap a
        # shallow copy (sharing the underlying data) whose transform is the
        # test-time one. Assumes the dataset applies `self.transform` in
        # `__getitem__`, as torchvision datasets do -- confirm in utils.datasets.
        val_base = copy.copy(dataset)
        val_base.transform = test_dataset.transform
        val_dataset = torch.utils.data.Subset(val_base, idx[train_size:])

    global model
    model = ResNet_DUQ(
        input_size, num_classes, centroid_size, model_output_size, length_scale, gamma
    ).cuda()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[25, 50, 75], gamma=0.2
    )

    def bce_loss_fn(y_pred, y):
        """Binary cross-entropy averaged over all (batch, class) entries."""
        return F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0]
        )

    def calc_gradients_input(x, y_pred):
        """Gradient of ``sum(y_pred)`` w.r.t. ``x``, flattened from dim 1 on."""
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,  # keep the graph so the penalty is differentiable
        )[0]
        return gradients.flatten(start_dim=1)

    def calc_gradient_penalty(x, y_pred):
        """Two-sided penalty: batch mean of (||gradient||_2 - 1) ** 2."""
        grad_norm = calc_gradients_input(x, y_pred).norm(2, dim=1)
        return ((grad_norm - 1) ** 2).mean()

    def step(engine, batch):
        """One SGD step followed by an embedding update; returns the scalar loss."""
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            # The penalty differentiates the output w.r.t. the input.
            x.requires_grad_(True)

        _, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)
        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        # NOTE(review): embeddings are updated in eval mode without gradients;
        # the update rule itself lives in ResNet_DUQ.update_embeddings.
        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    trainer = Engine(step)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        """Evaluate periodically, advance the LR schedule, checkpoint late epochs."""
        epoch = trainer.state.epoch

        if epoch % 10 == 0 or epoch > 70:
            # CIFAR-10 test accuracy and CIFAR-10 vs SVHN OOD AUROC.
            test_accuracy, ood_auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {test_accuracy}, AUROC: {ood_auroc}")
            val_accuracy, uncertainty_auroc = get_auroc_classification(
                val_dataset, model
            )
            print(f"AUROC - uncertainty: {uncertainty_auroc}")
            results.append(
                {
                    "epoch": epoch,
                    "Test accuracy": test_accuracy,
                    "Ood/roc_auc": ood_auroc,
                    "val_acc": val_accuracy,
                    "auroc-uncertainity": uncertainty_auroc,
                }
            )

        scheduler.step()

        if epoch > 70:
            # torch.save does not create missing parent directories.
            pathlib.Path("saved_models").mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), f"saved_models/{epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)


if __name__ == "__main__":
    # Final-model run: train on the full CIFAR-10 training set, no gradient penalty.
    main(
        batch_size=128,
        epochs=75,
        length_scale=0.1,
        centroid_size=512,
        model_output_size=512,
        learning_rate=0.05,
        l_gradient_penalty=0,
        gamma=0.999,
        weight_decay=5e-4,
        final_model=True,
    )
    

Collecting pytorch-ignite
[?25l  Downloading https://files.pythonhosted.org/packages/14/98/0a5b83d82ff245d3de5f09808fb80ff0ed03f6b10933979e6018b1dd0eaa/pytorch_ignite-0.4.2-py2.py3-none-any.whl (175kB)
[K     |█▉                              | 10kB 16.5MB/s eta 0:00:01[K     |███▊                            | 20kB 2.9MB/s eta 0:00:01[K     |█████▋                          | 30kB 3.6MB/s eta 0:00:01[K     |███████▌                        | 40kB 3.9MB/s eta 0:00:01[K     |█████████▍                      | 51kB 3.4MB/s eta 0:00:01[K     |███████████▏                    | 61kB 3.7MB/s eta 0:00:01[K     |█████████████                   | 71kB 4.2MB/s eta 0:00:01[K     |███████████████                 | 81kB 4.4MB/s eta 0:00:01[K     |████████████████▉               | 92kB 4.6MB/s eta 0:00:01[K     |██████████████████▊             | 102kB 4.4MB/s eta 0:00:01[K     |████████████████████▌           | 112kB 4.4MB/s eta 0:00:01[K     |██████████████████████▍         | 12

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/CIFAR10/cifar-10-python.tar.gz to ./data/CIFAR10
Files already downloaded and verified


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/SVHN/train_32x32.mat


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data/SVHN/test_32x32.mat


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))







Test Accuracy: 0.7873, AUROC: 0.8532417543792256
AUROC - uncertainty: 0.8273596518210369


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.8551, AUROC: 0.8223703941303012
AUROC - uncertainty: 0.8699435748598572


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9245, AUROC: 0.9021758585586969
AUROC - uncertainty: 0.8964874516026202


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9109, AUROC: 0.8884765538567917
AUROC - uncertainty: 0.8860511409455677


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9183, AUROC: 0.9011189631991395
AUROC - uncertainty: 0.880859488243336


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9429, AUROC: 0.879922247618316
AUROC - uncertainty: 0.8972922156353716


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.943, AUROC: 0.8845017862630609
AUROC - uncertainty: 0.8942571300999052


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9437, AUROC: 0.8805954748002459
AUROC - uncertainty: 0.8964849819246302


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.943, AUROC: 0.8867428318992009
AUROC - uncertainty: 0.9031084072854458


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat
Test Accuracy: 0.9419, AUROC: 0.8871043427320221
AUROC - uncertainty: 0.9039587284572747


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=390.0), HTML(value='')), layout=Layout(di…

Buffered data was truncated after reaching the output size limit.

In [None]:
# Persist the final weights and display the per-epoch evaluation records.
assert hasattr(model, "state_dict"), "run main() in the cell above first"
torch.save(model.state_dict(), "DUQ_CIFAR_75.pt")
results

[{'epoch': 10, 'Test accuracy': 0.7873, 'Ood/roc_auc': 0.8532417543792256, 'val_acc': 0.7873, 'auroc-uncertainity': 0.8273596518210369}, {'epoch': 20, 'Test accuracy': 0.8551, 'Ood/roc_auc': 0.8223703941303012, 'val_acc': 0.8551, 'auroc-uncertainity': 0.8699435748598572}, {'epoch': 30, 'Test accuracy': 0.9245, 'Ood/roc_auc': 0.9021758585586969, 'val_acc': 0.9245, 'auroc-uncertainity': 0.8964874516026202}, {'epoch': 40, 'Test accuracy': 0.9109, 'Ood/roc_auc': 0.8884765538567917, 'val_acc': 0.9109, 'auroc-uncertainity': 0.8860511409455677}, {'epoch': 50, 'Test accuracy': 0.9183, 'Ood/roc_auc': 0.9011189631991395, 'val_acc': 0.9183, 'auroc-uncertainity': 0.880859488243336}, {'epoch': 60, 'Test accuracy': 0.9429, 'Ood/roc_auc': 0.879922247618316, 'val_acc': 0.9429, 'auroc-uncertainity': 0.8972922156353716}, {'epoch': 66, 'Test accuracy': 0.943, 'Ood/roc_auc': 0.8845017862630609, 'val_acc': 0.943, 'auroc-uncertainity': 0.8942571300999052}, {'epoch': 67, 'Test accuracy': 0.9437, 'Ood/roc_auc