# Train an RBF network via gradient descent

In this notebook, we show how to instantiate and train an RBF network (as well as an MLP network for comparison). We then test the out-of-distribution (OOD) detection capabilities of the trained deterministic discriminative models by using the entropy of their softmax outputs as an uncertainty score on some chosen OOD datasets.

In [1]:
import os
import sys

# Name of the directory from which this notebook is executed.
curr_dir = os.path.basename(os.path.abspath(os.curdir))
# When running from within "tutorials", put the parent directory at the front
# of the module search path (only once). See __init__.py in folder
# "toy_example" for an explanation.
if '..' not in sys.path and curr_dir == 'tutorials':
    sys.path.insert(0, '..')

from hypnettorch.data.mnist_data import MNISTData
from hypnettorch.data.fashion_mnist import FashionMNISTData
from hypnettorch.mnets import MLP
from hypnettorch.utils import misc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score
from time import time
import torch
from torch import nn

from finite_width.rbf_net import StackedRBFNet

from IPython.display import display, Markdown, Latex
#display(Markdown('*some markdown* $\phi$'))

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Run on the GPU when one is available and fall back to the CPU otherwise, so
# that the notebook also executes on machines without CUDA support.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Both dataset handlers are pointed at the current directory and are asked to
# provide one-hot encoded targets.
dataset_kwargs = dict(use_one_hot=True)
mnist = MNISTData('.', **dataset_kwargs)
fmnist = FashionMNISTData('.', **dataset_kwargs)

Reading MNIST dataset ...
Elapsed time to read dataset: 0.165256 sec


In [3]:
def test_net(net, data, use_test=True):
    with torch.no_grad():
        test_in = data.input_to_torch_tensor( \
            data.get_test_inputs() if use_test else data.get_val_inputs(), \
            device, mode='inference')
        test_out = data.input_to_torch_tensor( \
            data.get_test_outputs() if use_test else data.get_val_outputs(),
            device, mode='inference')
        test_lbls = test_out.max(dim=1)[1]

        logits = net(test_in)
        pred_lbls = logits.max(dim=1)[1]

        acc = torch.sum(test_lbls == pred_lbls) / test_lbls.numel() * 100.
    return acc

def train_net(net, data, lr=1e-3, nepochs=10, batch_size=32, log_every=500):
    """Train a classifier with Adam on the cross-entropy loss.

    The validation accuracy is printed periodically during training and the
    test accuracy on ``data`` is printed once training has finished.

    Args:
        net: Network to be trained. It is called as ``net(inputs)`` and its
            attribute ``internal_params`` is handed to the optimizer.
        data: Dataset handler. ``data.train_iterator(batch_size)`` has to
            yield ``(batch_size, x, y)`` triples. Class labels are obtained
            as the argmax of the converted targets along dimension 1.
        lr (float): Learning rate of the Adam optimizer.
        nepochs (int): Number of passes over the training iterator.
        batch_size (int): Mini-batch size requested from the iterator.
        log_every (int): Number of mini-batches between two progress
            messages within an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.internal_params, lr=lr)

    for epoch in range(nepochs):
        # ``step`` is the 1-based number of mini-batches seen in this epoch.
        batches = data.train_iterator(batch_size)
        for step, (_, x, y) in enumerate(batches, start=1):
            x_t = data.input_to_torch_tensor(x, device, mode='train')
            y_t = data.output_to_torch_tensor(y, device, mode='train')

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            p_t = net(x_t)
            loss = criterion(p_t, y_t.max(dim=1)[1])
            loss.backward()
            optimizer.step()

            if step % log_every == 0:
                # Report the actual number of processed mini-batches.
                print('[%d, %5d] loss: %.3f, val-acc: %.2f%%' %
                      (epoch + 1, step, loss.item(),
                       test_net(net, data, use_test=False)))

    # Evaluate on the dataset that was used for training, not on a global one.
    print('Training finished with test-acc: %.2f%%' % (test_net(net, data)))

In [4]:
# Build the RBF network on the flattened input dimensionality of the dataset
# and move it to the chosen device before training.
n_rbf_inputs = np.prod(mnist.in_shape)
rbf_net = StackedRBFNet(n_in=n_rbf_inputs, n_nonlin_units=(100,),
                        n_lin_units=(10,), use_bias=True, bandwidth=5000)
rbf_net = rbf_net.to(device)

train_net(rbf_net, mnist, lr=1e-2, nepochs=30)

Creating a "1-layer RBF network" with 79410 weights
[1,   501] loss: 2.587, val-acc: 11.00%
[1,  1001] loss: 2.187, val-acc: 28.56%
[1,  1501] loss: 2.159, val-acc: 24.38%
[2,   501] loss: 1.987, val-acc: 39.70%
[2,  1001] loss: 1.888, val-acc: 47.42%
[2,  1501] loss: 1.718, val-acc: 39.80%
[3,   501] loss: 1.456, val-acc: 51.36%
[3,  1001] loss: 1.399, val-acc: 53.96%
[3,  1501] loss: 1.538, val-acc: 51.54%
[4,   501] loss: 1.544, val-acc: 57.52%
[4,  1001] loss: 1.398, val-acc: 59.42%
[4,  1501] loss: 1.339, val-acc: 63.46%
[5,   501] loss: 1.316, val-acc: 48.32%
[5,  1001] loss: 1.395, val-acc: 58.98%
[5,  1501] loss: 1.275, val-acc: 68.08%
[6,   501] loss: 1.446, val-acc: 65.60%
[6,  1001] loss: 1.090, val-acc: 58.60%
[6,  1501] loss: 0.989, val-acc: 71.60%
[7,   501] loss: 1.137, val-acc: 68.28%
[7,  1001] loss: 1.019, val-acc: 72.46%
[7,  1501] loss: 1.086, val-acc: 67.06%
[8,   501] loss: 1.128, val-acc: 71.94%
[8,  1001] loss: 1.096, val-acc: 72.88%
[8,  1501] loss: 0.954, val-

In [5]:
# Build an MLP with two hidden layers of 400 units each on the flattened
# input dimensionality of the dataset and move it to the chosen device.
n_mlp_inputs = np.prod(mnist.in_shape)
mlp_net = MLP(n_in=n_mlp_inputs, n_out=10, hidden_layers=(400,400))
mlp_net = mlp_net.to(device)

train_net(mlp_net, mnist, lr=1e-3)

Creating an MLP with 478410 weights.
[1,   501] loss: 0.027, val-acc: 94.70%
[1,  1001] loss: 0.169, val-acc: 95.18%
[1,  1501] loss: 0.046, val-acc: 96.66%
[2,   501] loss: 0.017, val-acc: 96.38%
[2,  1001] loss: 0.204, val-acc: 96.76%
[2,  1501] loss: 0.110, val-acc: 97.58%
[3,   501] loss: 0.132, val-acc: 97.70%
[3,  1001] loss: 0.155, val-acc: 97.64%
[3,  1501] loss: 0.002, val-acc: 97.40%
[4,   501] loss: 0.120, val-acc: 98.08%
[4,  1001] loss: 0.009, val-acc: 97.66%
[4,  1501] loss: 0.002, val-acc: 97.50%
[5,   501] loss: 0.001, val-acc: 97.28%
[5,  1001] loss: 0.002, val-acc: 97.74%
[5,  1501] loss: 0.049, val-acc: 97.74%
[6,   501] loss: 0.001, val-acc: 98.12%
[6,  1001] loss: 0.004, val-acc: 97.16%
[6,  1501] loss: 0.016, val-acc: 97.84%
[7,   501] loss: 0.117, val-acc: 98.16%
[7,  1001] loss: 0.070, val-acc: 97.64%
[7,  1501] loss: 0.009, val-acc: 97.84%
[8,   501] loss: 0.004, val-acc: 98.08%
[8,  1001] loss: 0.023, val-acc: 98.06%
[8,  1501] loss: 0.000, val-acc: 97.94%
[9,

In [6]:
def _softmax_entropies(net, data):
    """Return the softmax entropy of ``net`` for each test input of ``data``.

    Args:
        net: Network that is called as ``net(inputs)`` and returns logits.
        data: Dataset handler providing ``get_test_inputs`` and
            ``input_to_torch_tensor``.

    Returns:
        (numpy.ndarray): One entropy value per test input.
    """
    inps = data.input_to_torch_tensor(data.get_test_inputs(), device,
                                      mode='inference')
    probs = nn.functional.softmax(net(inps), dim=1).cpu().detach().numpy()
    # Probabilities are clamped from below before taking the logarithm to
    # avoid evaluating log(0).
    return - np.sum(probs * np.log(np.maximum(probs, 1e-5)), axis=1)


def calc_auroc(net, ind_data, ood_data):
    """Compute the AUROC for separating in-distribution from OOD test inputs.

    The entropy of the softmax output is used as score, where OOD inputs form
    the positive class (label 1) and in-distribution inputs the negative
    class (label 0).

    Args:
        net: Network that is called as ``net(inputs)`` and returns logits.
        ind_data: Dataset handler of the in-distribution data.
        ood_data: Dataset handler of the out-of-distribution data.

    Returns:
        The value returned by ``roc_auc_score`` for the described labels and
        scores.
    """
    with torch.no_grad():
        ind_entropies = _softmax_entropies(net, ind_data)
        ood_entropies = _softmax_entropies(net, ood_data)

    y_true = [0] * len(ind_entropies) + [1] * len(ood_entropies)
    y_score = ind_entropies.tolist() + ood_entropies.tolist()

    return roc_auc_score(y_true, y_score)

# Report how well each trained network separates MNIST test inputs from
# FashionMNIST test inputs based on its softmax entropy.
mlp_auroc = calc_auroc(mlp_net, mnist, fmnist)
print('MLP AUROC: %.3f' % mlp_auroc)
rbf_auroc = calc_auroc(rbf_net, mnist, fmnist)
print('RBF Net AUROC: %.3f' % rbf_auroc)

MLP AUROC: 0.876
RBF Net AUROC: 0.792
