In [1]:
from tensorflow import keras
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl
# Make every plotted line 6 points wide by default; the figures below are large (18x9).
mpl.rcParams['lines.linewidth'] = 6

from sklearn.neighbors import KernelDensity
from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

from scipy import stats

import sys 
sys.path.append("../")
from KDG import KDG

import os
# Number GPUs by PCI bus position and expose only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Using TensorFlow backend.


# Experiment Parameters

In [12]:
# Number of independent repetitions run for every training-set size in the experiment loop.
num_trials = 50

# Data Generation

# Construct & Train Network

In [14]:
def construct_network(X, y):
    """Build, compile and fit the small CNN classifier used throughout the notebook.

    Architecture (in order): Conv2D(16) -> AveragePooling2D -> Conv2D(32) ->
    AveragePooling2D -> Flatten -> Dense(16) x 3 -> Dense(num_classes, softmax).
    The layer order matters: other cells address layers by position.

    Parameters
    ----------
    X : array of training inputs; ``np.shape(X)[1:]`` is used as the input shape.
    y : integer class labels; they are one-hot encoded with
        ``keras.utils.to_categorical`` and the output width is the number of
        distinct values in ``y``.

    Returns
    -------
    The fitted ``keras.Sequential`` model.
    """
    sample_shape = np.shape(X)[1:]
    num_classes = len(np.unique(y))

    layer_stack = [
        keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu', input_shape=sample_shape),
        keras.layers.AveragePooling2D(),
        keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=sample_shape),
        keras.layers.AveragePooling2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(units=16, activation='relu'),
        keras.layers.Dense(units=16, activation='relu'),
        keras.layers.Dense(units=16, activation='relu'),
        keras.layers.Dense(units=num_classes, activation = 'softmax'),
    ]

    model = keras.Sequential()
    for layer in layer_stack:
        model.add(layer)

    model.compile(loss = 'categorical_crossentropy', optimizer = keras.optimizers.Adam(3e-3))

    # Batch size grows with the sample count: 2 ** (log_5(len(X)) + 2.2).
    samples_per_batch = int(2 ** (np.log(len(X)) / np.log(5) + 2.2))
    one_hot_labels = keras.utils.to_categorical(y)
    model.fit(
      X,
      one_hot_labels,
      epochs = 60,
      verbose = False,
      batch_size = samples_per_batch
    )

    return model

    

## Keurnal Networks

In [15]:
def train_AKN(X, y):
    """Train a network on half of the data and fit a KDG on the other half.

    Steps:
      1. Split (X, y) 50/50. The first half trains the network; the second
         half (rebound to the names ``X``/``y``) is used to fit the KDG.
      2. Build an encoder that maps inputs to the output of ``network.layers[5]``.
      3. Replay the remaining layers by hand to record, per sample, which units
         are "on"; each distinct on/off pattern becomes one polytope id.
      4. Fit ``KDG`` on the encoded held-out half, its labels and the polytope ids.

    Returns
    -------
    (kdg, encoder) : whatever ``KDG().fit(...)`` returns (presumably the fitted
        estimator -- confirm against the KDG module) and the Keras encoder model.
    """
    X_transform, X, y_transform, y = train_test_split(X, y, test_size = 0.5)
    #X_transform, y_transform = X, y
    network = construct_network(X_transform, y_transform)
    
    # NOTE(review): index 5 and the start index 6 below are tied to the layer order
    # built by construct_network (with that order, layer 5 is the first Dense layer).
    encoder = keras.models.Model(network.inputs, network.layers[5].output)
    X = encoder.predict(X)
    polytope_memberships = []
    last_activations = X
    # Manual forward pass through the layers after the encoder output.
    for layer_id in range(6, len(network.layers)):
        weights, bias = network.layers[layer_id].get_weights()
        preactivation = np.matmul(last_activations, weights) + bias
        if layer_id == len(network.layers) - 1:
            # NOTE(review): the last layer is softmax, so this value is a logit rather
            # than a probability -- confirm that 0.5 is the intended threshold.
            binary_preactivation = (preactivation > 0.5).astype('int')
        else:
            binary_preactivation = (preactivation > 0).astype('int')
        polytope_memberships.append(binary_preactivation)
        # Zeroing the "off" units reproduces ReLU for the hidden layers.
        last_activations = preactivation * binary_preactivation

    # One row per sample, one column per unit across all replayed layers.
    polytope_memberships = np.concatenate(polytope_memberships, axis = 1)
    # Read each binary row as an integer (column j has weight 2**j), then relabel the
    # distinct integers as consecutive ids via return_inverse.
    _, polytope_ids = np.unique(np.matmul(polytope_memberships, 2 ** np.arange(0, np.shape(polytope_memberships)[1])), return_inverse = True)


    kdg = KDG().fit(X, y, polytope_ids)
    return kdg, encoder

In [16]:
def clipped_mean(ra, low = 25, high = 75):
    """Mean of the entries of ``ra`` lying between two of its percentiles.

    The bounds are the ``low``-th and ``high``-th percentiles of ``ra`` computed
    with NaNs ignored; both bounds are inclusive. NaN entries never satisfy the
    range test and are therefore excluded from the mean.
    """
    values = np.array(ra)
    lower_val, higher_val = (np.nanpercentile(values, q) for q in (low, high))
    inside = (values >= lower_val) & (values <= higher_val)
    kept_positions = np.nonzero(inside)[0]
    return np.mean(values[kept_positions])

In [17]:
def get_ece(predicted_posterior, y):
    """Expected calibration error (ECE), averaged over the classes present in ``y``.

    For each class the predicted probabilities of that class are split into 10
    equal-width bins on [0, 1]. Within a bin the gap between the observed
    frequency of the class and the mean predicted probability is computed, and
    the gaps are averaged with weights equal to the bin counts. The per-class
    values are then averaged.

    Fixes relative to the previous version:
      * the result is returned after *all* classes are processed (the ``return``
        used to sit inside the class loop, so only the first class counted);
      * bin statistics are accumulated per class rather than shared across classes;
      * bins are half-open ``[lo, hi)`` (the last one is closed) so a value lying
        exactly on an interior edge is no longer counted in two bins.

    Parameters
    ----------
    predicted_posterior : array of shape (n_samples, n_columns); column ``c``
        holds the predicted probability of class ``c``.
    y : 1-D array of integer labels. Each label is used directly as a column
        index into ``predicted_posterior``.

    Returns
    -------
    float : mean of the per-class ECEs. A class whose probabilities fall in no
        bin (e.g. all NaN) contributes NaN.
    """
    predicted_posterior = np.asarray(predicted_posterior)
    y = np.asarray(y)
    num_bins = 10
    eces_across_y_vals = []
    for y_val in np.unique(y):
        class_posterior = predicted_posterior[:, y_val]
        weighted_gap = 0.0
        total = 0
        for i in range(num_bins):
            lower = i * 1. / num_bins
            upper = (i + 1) * 1. / num_bins
            if i == num_bins - 1:
                # Close the last bin on the right so a probability of 1.0 is kept.
                in_bin = (class_posterior >= lower) & (class_posterior <= upper)
            else:
                in_bin = (class_posterior >= lower) & (class_posterior < upper)
            count = int(np.count_nonzero(in_bin))
            if count == 0:
                # Empty bins carry zero weight and cannot affect the result.
                continue
            observed = np.mean(y[in_bin] == y_val)
            predicted = np.mean(class_posterior[in_bin])
            weighted_gap += abs(observed - predicted) * count
            total += count
        eces_across_y_vals.append(weighted_gap / total if total > 0 else np.nan)
    return np.mean(eces_across_y_vals)

In [18]:
def get_brier(predicted_posterior, y):
    """Brier score averaged over the classes present in ``y``.

    For each distinct label the squared difference between that class's
    predicted probability (column ``label`` of ``predicted_posterior``) and the
    0/1 indicator ``y == label`` is averaged over samples, ignoring NaNs. The
    per-class values are then averaged.
    """
    per_class_scores = [
        np.nanmean((predicted_posterior[:, label] - (y == label).astype('int')) ** 2)
        for label in np.unique(y)
    ]
    return np.mean(per_class_scores)

In [19]:

# Load Fashion-MNIST and keep only the samples whose label is below 5 (classes 0-4).
# A trailing channel axis is added so the images match the Conv2D input layout.
(X, y), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
train_indices, test_indices = np.where(y < 5)[0], np.where(y_test < 5)[0]
X, y = np.expand_dims(X[train_indices], axis = -1), y[train_indices]
X_test, y_test = np.expand_dims(X_test[test_indices], axis = -1), y_test[test_indices]

# Evaluate on 1000 *distinct* test images. np.random.choice samples with replacement
# by default, which would put duplicate images in the evaluation set.
test_indices = np.random.choice(len(X_test), 1000, replace = False)
X_test, y_test = X_test[test_indices], y_test[test_indices]

In [20]:
# Disabled alternative data source: the same preparation for CIFAR-10 restricted to
# labels below 2. It is wrapped in a string literal, so none of it executes; the cell
# only echoes the string.
'''
(X, y), (X_test, y_test) = keras.datasets.cifar10.load_data()
train_indices, test_indices = np.where(y < 2)[0], np.where(y_test < 2)[0]
X, y = X[train_indices], y[train_indices, 0]
X_test, y_test = X_test[test_indices], y_test[test_indices, 0]

test_indices = np.random.choice(range(len(X_test)), 1000)
X_test, y_test = X_test[test_indices], y_test[test_indices]
'''

'\n(X, y), (X_test, y_test) = keras.datasets.cifar10.load_data()\ntrain_indices, test_indices = np.where(y < 2)[0], np.where(y_test < 2)[0]\nX, y = X[train_indices], y[train_indices, 0]\nX_test, y_test = X_test[test_indices], y_test[test_indices, 0]\n\ntest_indices = np.random.choice(range(len(X_test)), 1000)\nX_test, y_test = X_test[test_indices], y_test[test_indices]\n'

In [21]:
def get_network_y_proba(X, y, n, X_test):
    """Train the CNN on ``n`` resampled training points and predict on ``X_test``.

    Parameters
    ----------
    X, y : full training inputs and integer labels.
    n : number of training points to draw (cast to int). Points are drawn with
        replacement, so ``n`` may exceed ``len(X)``.
    X_test : inputs to score.

    Returns
    -------
    The network's softmax output for ``X_test`` (one row per test sample).
    """
    random_indices = np.random.choice(len(X), int(n))
    X, y = X[random_indices], y[random_indices]
    network = construct_network(X, y)
    # `Sequential.predict_proba` was deprecated and later removed from tf.keras.
    # The final layer is a softmax, so `predict` already returns class probabilities.
    return network.predict(X_test)
    
    

def get_KDE_y_proba(X, y, n, X_test):
    """Train the KDG pipeline on ``n`` resampled training points and predict on ``X_test``.

    The subsample is now drawn *before* training, mirroring
    ``get_network_y_proba``. Previously the indices were drawn but ``train_AKN``
    was still fitted on the full ``X, y``, so ``n`` had no effect on the model;
    the subsample was only pushed through the encoder and then discarded.

    Parameters
    ----------
    X, y : full training inputs and integer labels.
    n : number of training points to draw (cast to int), with replacement.
    X_test : raw (un-encoded) inputs to score.

    Returns
    -------
    Array with one row per test sample; each row is ``kdg.predict_proba`` output
    divided by its row sum. Rows whose sum is zero yield NaN/inf, as before.
    """
    random_indices = np.random.choice(len(X), int(n))
    X, y = X[random_indices], y[random_indices]
    kdg, encoder = train_AKN(X, y)

    # The KDG operates on encoder outputs, so the test inputs must be encoded too.
    y_proba_test = np.asarray(kdg.predict_proba(encoder.predict(X_test)))

    # Rescale every row to sum to one.
    return y_proba_test / np.sum(y_proba_test, axis = 1, keepdims = True)

In [None]:
# Sweep over training-set sizes. For every size, run `num_trials` repetitions of both
# methods, summarise accuracy and ECE across the trials, and redraw the curves so far.
KDE_acc_means = []
network_acc_means = []

# NOTE: the *_stds lists are initialised but never filled in this cell.
KDE_acc_stds = []
network_acc_stds = []

KDE_ece_means = []
network_ece_means = []

KDE_ece_stds = []
network_ece_stds = []
# Ten training-set sizes, log-spaced between 10**2.5 and 10**7.
n_ra = np.logspace(2.5, 7, num = 10, base = 10)
# NOTE: `ticks` and `ticks_ra` are computed here but not used by the plots in this cell.
ticks = np.arange(np.min(n_ra), np.max(n_ra), step = int((np.max(n_ra) - np.min(n_ra)) // 4))
ticks_ra = np.array([int(str(tick)[:1]) * 10 ** int(np.log10(tick)) for tick in ticks])
for n in tqdm(n_ra):
    # Predicted test posteriors, one array per trial.
    KDE_y_test_proba_across_trials = np.array([get_KDE_y_proba(X, y, n, X_test) for _ in tqdm(range(num_trials))])
    network_y_test_proba_across_trials = np.array([get_network_y_proba(X, y, n, X_test) for _ in tqdm(range(num_trials))])

    KDE_accs_across_trials = []
    network_accs_across_trials = []
    KDE_eces_across_trials = []
    network_eces_across_trials = []
    for trial_idx in range(num_trials):
        # Accuracy: fraction of test samples whose arg-max class equals the label.
        KDE_accs_across_trials.append(np.nanmean(np.argmax(KDE_y_test_proba_across_trials[trial_idx], axis = 1) == y_test))
        network_accs_across_trials.append(np.nanmean(np.argmax(network_y_test_proba_across_trials[trial_idx], axis = 1) == y_test))

        KDE_eces_across_trials.append(get_ece(KDE_y_test_proba_across_trials[trial_idx], y_test))
        network_eces_across_trials.append(get_ece(network_y_test_proba_across_trials[trial_idx], y_test))

    # Summaries use only the better half of the trials: accuracies between the 50th
    # and 100th percentile, ECEs between the 0th and 50th percentile.
    KDE_acc_means.append(clipped_mean(KDE_accs_across_trials, 50, 100))
    network_acc_means.append(clipped_mean(network_accs_across_trials, 50, 100))

    KDE_ece_means.append(clipped_mean(KDE_eces_across_trials, 0, 50))
    network_ece_means.append(clipped_mean(network_eces_across_trials, 0, 50))

    # Progress plot: only the sizes completed so far are drawn.
    figs, ax = plt.subplots(1, 2, figsize = (18, 9))
    figs.set_facecolor("white")

    ax[0].tick_params(axis='both', which='major', labelsize=27)
    ax[0].plot(n_ra[:len(KDE_acc_means)] , KDE_acc_means, label = "Ours", c = "red")
    ax[0].plot(n_ra[:len(KDE_acc_means)], network_acc_means, label = "Deep Network", c = "blue")
    ax[0].hlines(1.0, 0, n_ra[len(KDE_acc_means) - 1], linestyle = "dashed", label = "Perfect Class Prediction", color = "black")
    ax[0].legend(fontsize = 18)
    ax[0].set_xlabel("Number of Training Samples (logscale)", fontsize = 27)
    ax[0].set_ylabel("Test Accuracy", fontsize = 27)
    ax[0].set_xscale("log")

    ax[1].tick_params(axis='both', which='major', labelsize=27)
    ax[1].plot(n_ra[:len(KDE_acc_means)] , KDE_ece_means, label = "Ours", c = "red")
    ax[1].plot(n_ra[:len(KDE_acc_means)], network_ece_means, label = "Deep Network", c = "blue")
    ax[1].hlines(0.0, 0, n_ra[len(KDE_acc_means) - 1], linestyle = "dashed", label = "Perfect Calibration", color = "black")
    ax[1].legend(fontsize = 18)
    ax[1].set_xlabel("Number of Training Samples (logscale)", fontsize = 27)
    ax[1].set_ylabel("Test mECE", fontsize = 27)
    ax[1].set_xscale("log")

    figs.tight_layout()

    # The data loaded in this notebook is Fashion-MNIST, not MNIST.
    figs.suptitle("Posterior Estimation Comparison\nDataset: Fashion-MNIST", fontsize=27, y = 1.15)

    plt.show()

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

In [None]:
# Final summary figure over all training-set sizes: accuracy (left) and mean ECE (right).
figs, ax = plt.subplots(1, 2, figsize = (18, 9))
figs.set_facecolor("white")

ax[0].tick_params(axis='both', which='major', labelsize=27)
ax[0].plot(n_ra , KDE_acc_means, label = "Ours", c = "red")
ax[0].plot(n_ra, network_acc_means, label = "Deep Network", c = "blue")
ax[0].hlines(1.0, 0, n_ra[-1], linestyle = "dashed", label = "Perfect Class Prediction", color = "black")
ax[0].legend(fontsize = 18)
ax[0].set_xlabel("Number of Training Samples (logscale)", fontsize = 27)
ax[0].set_ylabel("Test Accuracy", fontsize = 27)
ax[0].set_xscale("log")

ax[1].tick_params(axis='both', which='major', labelsize=27)
# Plot the ECE summaries computed by the experiment loop. The previously referenced
# `KDE_brier_means` / `network_brier_means` are never defined in this notebook, and
# the axis is labelled mECE.
ax[1].plot(n_ra , KDE_ece_means, label = "Ours", c = "red")
ax[1].plot(n_ra, network_ece_means, label = "Deep Network", c = "blue")
ax[1].hlines(0.0, 0, n_ra[-1], linestyle = "dashed", label = "Perfect Calibration", color = "black")
ax[1].legend(fontsize = 18)
ax[1].set_xlabel("Number of Training Samples (logscale)", fontsize = 27)
ax[1].set_ylabel("Test mECE", fontsize = 27)
ax[1].set_xscale("log")


figs.tight_layout()

# The data loaded in this notebook is Fashion-MNIST, not MNIST.
figs.suptitle("Posterior Estimation Comparison\nDataset: Fashion-MNIST", fontsize=27, y = 1.15)

plt.show()