In [69]:
from torchvision.datasets import SVHN, CIFAR10
import torch
import torchvision.transforms as transforms
from neural_mean_discrepancy import Neural_Mean_Discrepancy
from utils import get_conv_layer_names
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

## Neural Mean Discrepancy demo

#### 1. Retrieve datasets. 

We will use CIFAR10 as the in-distribution (i.d.) dataset and SVHN as the out-of-distribution (o.o.d.) dataset.

In [46]:
# Fixed seed for the val/test splits below, so the partitions (and therefore
# every downstream number) are identical after Restart Kernel -> Run All.
SPLIT_SEED = 0

# i.d. dataset: CIFAR10, converted to tensors and normalised per channel.
CIFAR10_transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
CIFAR10_train_dataset = CIFAR10(root='data/cifar10', download=True, train=True, transform=CIFAR10_transforms)
CIFAR10_test_dataset = CIFAR10(root='data/cifar10', download=True, train=False, transform=CIFAR10_transforms)
# Divide the CIFAR10 test split 5000/5000 into a validation subset and a
# held-out test subset. The generator makes the random partition reproducible.
CIFAR10_val_dataset, CIFAR10_test_dataset = torch.utils.data.random_split(
    CIFAR10_test_dataset, [5000, 5000],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

# o.o.d. dataset: SVHN test split, converted to tensors and normalised per channel.
# NOTE(review): SVHN is normalised with different channel statistics than the
# CIFAR10 pipeline above, so i.d. and o.o.d. images are not preprocessed
# identically before reaching the network - confirm this is intended.
SVHN_transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))])
SVHN_test_dataset = SVHN(root='data/svhn', download=True, split='test', transform=SVHN_transforms)
# Divide the SVHN test split 13016/13016 into validation and test subsets,
# again with a seeded generator for reproducibility.
SVHN_val_dataset, SVHN_test_dataset = torch.utils.data.random_split(
    SVHN_test_dataset, [13016, 13016],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: data/svhn/test_32x32.mat


#### 2. Load a ResNet20 model that has been pretrained on CIFAR10

In [13]:
# Pretrained CIFAR10 ResNet20 fetched through torch.hub.
# Source: https://github.com/chenyaofo/pytorch-cifar-models?tab=readme-ov-file
model = torch.hub.load(
    repo_or_dir="chenyaofo/pytorch-cifar-models",
    model="cifar10_resnet20",
    pretrained=True,
)
# Put the network in evaluation mode; as the last expression of the cell this
# also displays the module structure.
model.eval()

Using cache found in /home/kitbransby/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

#### 3. Fit on the i.d. training set to compute the neural mean feature (NMF) vector

In [14]:
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('connected to device: {}'.format(device))

connected to device: cuda


In [15]:
# Retrieve the names of the convolutional layers in the ResNet model using the
# project helper from utils. These names are handed to the
# Neural_Mean_Discrepancy constructor in the next step.
layer_names = get_conv_layer_names(model)
# Print the names so the reader can see which layers were selected.
print(layer_names)

['conv1', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 'layer1.2.conv1', 'layer1.2.conv2', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.downsample.0', 'layer2.1.conv1', 'layer2.1.conv2', 'layer2.2.conv1', 'layer2.2.conv2', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.downsample.0', 'layer3.1.conv1', 'layer3.1.conv2', 'layer3.2.conv1', 'layer3.2.conv2']


In [24]:
# Build the NMD wrapper around the pretrained network and the selected layers,
# then move it onto the chosen device.
nmd_model = Neural_Mean_Discrepancy(
    model=model,
    layer_names=layer_names,
    device=device,
)
nmd_model = nmd_model.to(device)

# Fit on the i.d. training set to calculate the neural mean feature of the dataset.
nmd_model.fit_in_distribution_dataset(CIFAR10_train_dataset)

In [70]:
# Shape of the neural mean feature (nmf) held by the fitted model, i.e. the
# i.d. nmf obtained from the CIFAR10 training set in the fit step above.
nmd_model.nmf.shape

torch.Size([784])

#### 4. Compute nmd vectors for i.d. and o.o.d. validation sets. 

In [50]:
# NMD of the i.d. (CIFAR10) validation subset. The call returns two tensors:
# the first is presumably the dataset-level nmd vector and the second the
# per-sample nmd vectors (one row per image) - verify against the
# Neural_Mean_Discrepancy implementation.
cifar10_nmd_vector, cifar10_nmd_per_sample = nmd_model.predict_nmd_unk_distribtion(CIFAR10_val_dataset)
# Show both shapes as a sanity check.
print(cifar10_nmd_vector.shape, cifar10_nmd_per_sample.shape)

Predicting nmd of unknown distribution dataset..


100%|██████████| 5000/5000 [00:37<00:00, 133.75it/s]


torch.Size([784]) torch.Size([5000, 784])


In [51]:
# NMD of the o.o.d. (SVHN) validation subset. The call returns two tensors:
# the first is presumably the dataset-level nmd vector and the second the
# per-sample nmd vectors (one row per image) - verify against the
# Neural_Mean_Discrepancy implementation.
svhn_nmd_vector, svhn_nmd_per_sample = nmd_model.predict_nmd_unk_distribtion(SVHN_val_dataset)
# Show both shapes as a sanity check.
print(svhn_nmd_vector.shape, svhn_nmd_per_sample.shape)

Predicting nmd of unknown distribution dataset..


100%|██████████| 13016/13016 [01:34<00:00, 137.87it/s]


torch.Size([784]) torch.Size([13016, 784])


In [54]:
# Mean of the nmd vector for the i.d. (CIFAR10) validation set versus the
# o.o.d. (SVHN) validation set; the i.d. value is expected to be much smaller
# in magnitude than the o.o.d. value.
cifar10_nmd_vector.mean(), svhn_nmd_vector.mean()

(tensor(-2.9543e-06, device='cuda:0'), tensor(0.0079, device='cuda:0'))

#### 5. Train a LogisticRegression o.o.d. detector using the per-sample neural mean discrepancy vectors as training samples

In [60]:
# Assemble the detector's training set from the validation NMD vectors:
# i.d. rows first (label 0, negative), then o.o.d. rows (label 1, positive).
val_examples = np.vstack((
    cifar10_nmd_per_sample.cpu().numpy(),
    svhn_nmd_per_sample.cpu().numpy(),
))
val_labels = np.hstack((
    np.zeros(cifar10_nmd_per_sample.shape[0]),
    np.ones(svhn_nmd_per_sample.shape[0]),
))
print(val_examples.shape, val_labels.shape)

(18016, 784) (18016,)


In [61]:
# Fit a logistic-regression o.o.d. detector on the validation NMD vectors
# (fit returns the estimator itself, so construction and fitting are chained).
lr = LogisticRegression(max_iter=500).fit(val_examples, val_labels)
# Display the fitted estimator as the cell output.
lr

#### 6. Evaluate o.o.d. detector on test set

In [62]:
# Create nmd vectors for the held-out i.d. (CIFAR10) and o.o.d. (SVHN) test subsets.
# Only the second return value (used below as the per-sample features) is kept;
# the first return value is discarded.
_, cifar10_nmd_test = nmd_model.predict_nmd_unk_distribtion(CIFAR10_test_dataset)
print(cifar10_nmd_test.shape)
_, svhn_nmd_test = nmd_model.predict_nmd_unk_distribtion(SVHN_test_dataset)
print(svhn_nmd_test.shape)

Predicting nmd of unknown distribution dataset..


100%|██████████| 5000/5000 [00:40<00:00, 121.96it/s]


torch.Size([5000, 784])
Predicting nmd of unknown distribution dataset..


100%|██████████| 13016/13016 [01:43<00:00, 125.44it/s]


torch.Size([13016, 784])


In [63]:
# Assemble the evaluation set from the test NMD vectors:
# i.d. rows first (label 0, negative), then o.o.d. rows (label 1, positive).
test_examples = np.vstack((
    cifar10_nmd_test.cpu().numpy(),
    svhn_nmd_test.cpu().numpy(),
))
test_labels = np.hstack((
    np.zeros(cifar10_nmd_test.shape[0]),
    np.ones(svhn_nmd_test.shape[0]),
))
print(test_examples.shape, test_labels.shape)

(18016, 784) (18016,)


In [68]:
# Classify every test sample as i.d. (0) or o.o.d. (1) with the fitted detector.
test_predictions = lr.predict(test_examples)

# Per-class precision / recall / F1 as a text report, printed to four decimals.
clf_report = classification_report(
    y_true=test_labels,
    y_pred=test_predictions,
    labels=[0, 1],
    target_names=['id', 'ood'],
    digits=4,
    output_dict=False,
    zero_division='warn',
)
print(clf_report)

              precision    recall  f1-score   support

          id     0.9938    0.9910    0.9924      5000
         ood     0.9965    0.9976    0.9971     13016

    accuracy                         0.9958     18016
   macro avg     0.9952    0.9943    0.9947     18016
weighted avg     0.9958    0.9958    0.9958     18016
