In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from classificationutils.resnet import ResNet50

# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# then plain CPU.
# NOTE(review): the MPS backend is documented as lacking float64 support —
# confirm the "mps" branch is usable together with the default dtype above.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"\n Using {device} device")
# `torch.version.cuda` is None on builds without CUDA support.
print(f"CUDA version: {torch.version.cuda}")


 Using cpu device
CUDA version: None


In [2]:
# Per-channel normalisation statistics used for the SVHN images.
# NOTE(review): presumably measured on the SVHN training split — confirm.
SVHN_MEAN = (0.4376821, 0.4437697, 0.47280442)
SVHN_STD = (0.19803012, 0.20101562, 0.19703614)

# Per-channel normalisation statistics used for the CIFAR-10 images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (with 4 px padding) and random horizontal
# flip as augmentation, followed by tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ]
)

# In-distribution data: SVHN train / test splits (downloaded on first use).
svhn_root = "data/SVHN"
training_data = datasets.SVHN(
    root=svhn_root, split='train', download=True, transform=transform_train
)
test_data = datasets.SVHN(
    root=svhn_root, split='test', download=True, transform=transform_test
)

# Out-of-distribution data: the CIFAR-10 test split, normalised with its own
# statistics rather than the SVHN ones.
transform_test_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)
ood_test_data = datasets.CIFAR10(
    root="data/CIFAR10", train=False, download=True, transform=transform_test_cifar
)

# Values passed to the model constructor and to NUQLS further down.
n_output = 10  # number of classes
n_channels = 3  # number of input channels

In [3]:
# Instantiate the network, then restore its weights from the saved checkpoint.
net = ResNet50(in_channels=n_channels, num_classes=n_output)

checkpoint_path = 'data/resnet50_trained_svhn.pt'
state_dict = torch.load(
    checkpoint_path,
    weights_only=True,  # restrict unpickling to tensors / plain containers
    map_location=torch.device(device),  # remap storages onto the selected device
)

# Left as the final expression so the key-matching report is displayed.
net.load_state_dict(state_dict)


<All keys matched successfully>

In [None]:
from nuqls.posterior import Nuqls

# Wrap the loaded network in a NUQLS posterior for the classification task.
nuqls_posterior = Nuqls(net, task='classification')

# Hyperparameters forwarded unchanged to `Nuqls.train`.
# NOTE(review): `S` is presumably the number of posterior samples and `mu`
# the momentum — confirm against the NUQLS documentation.
train_settings = dict(
    train_bs=10,
    n_output=n_output,
    S=10,
    scale=0.0025,
    lr=1e-2,
    epochs=2,
    mu=0.9,
    verbose=True,
    extra_verbose=True,
)
loss, acc = nuqls_posterior.train(train=training_data, **train_settings)

# Batch size shared by both evaluation passes.
eval_batch_size = 152

# In-distribution (SVHN test) predictions and their variance along dim 0.
# NOTE(review): assumes logits are (samples, inputs, classes) — TODO confirm.
id_logits = nuqls_posterior.test(test_data, test_bs=eval_batch_size)
id_predictions = id_logits.softmax(dim=2)
id_variance = id_predictions.var(0)

# Same quantities for the out-of-distribution (CIFAR-10 test) data.
ood_logits = nuqls_posterior.test(ood_test_data, test_bs=eval_batch_size)
ood_predictions = ood_logits.softmax(dim=2)
ood_variance = ood_predictions.var(0)

NUQLS is using device cpu.


  0%|          | 1/7326 [00:27<55:52:36, 27.46s/it, min_loss_ma=0.0349, max_loss_batch=0.267, min_acc_batch=0.9, max_acc_batch=1, resid_norm=4.2e-5, gpu_mem=0]
  0%|          | 0/2 [00:27<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# `from classificationutils.resnet import ResNet50` does not bind the name
# `classificationutils`, so the submodule has to be imported explicitly here;
# without this line the cell fails with NameError on a fresh kernel.
import classificationutils.metrics

# Collect the in-distribution and out-of-distribution predictive probabilities
# (moved to the CPU) into the structure consumed by the plotting helper.
nuqls_variance = classificationutils.metrics.sort_probabilies(
    id_predictions.to('cpu'),
    ood_predictions.to('cpu'),
    test_data=test_data,
)
# Extend that structure with the baseline entry.
# NOTE(review): what the baseline contains is defined in
# classificationutils.metrics — not visible from this notebook.
nuqls_variance = classificationutils.metrics.add_baseline(
    nuqls_variance, test_data, ood_test_data
)

# Draw the plot and write it to disk as a PDF.
# NOTE(review): the output directory is assumed to exist already — confirm
# whether `plot_vmsp` creates it.
classificationutils.metrics.plot_vmsp(
    prob_dict=nuqls_variance,
    title='SVHN ResNet50',
    save_fig="examples/images/vmsp_plot.pdf",
)