In [268]:
from sklearn.datasets import fetch_openml
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [269]:
def unison_shuffled_copies(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]


def find_nearest_distances(D, y):
    n = D.shape[0]
    
    # Initialize arrays to store the results
    same_label_distances = np.full(n, np.inf)
    different_label_distances = np.full(n, np.inf)

    for i in range(n):
        # Create a mask for the same and different labels
        same_label_mask = (y == y[i])
        different_label_mask = (y != y[i])

        # Ignore the distance to itself by setting it to np.inf
        same_label_mask[i] = False

        # Extract distances for the same label
        if np.any(same_label_mask):
            same_label_distances[i] = np.min(D[i, same_label_mask])
        
        # Extract distances for the different label
        if np.any(different_label_mask):
            different_label_distances[i] = np.min(D[i, different_label_mask])
    
    return same_label_distances, different_label_distances

In [270]:
# dataset = fetch_openml(name="California-Housing-Classification", parser="auto")
# low_class_name = 'False'
# high_class_name = 'True'

dataset = fetch_openml(name="spambase", parser="auto")
low_class_name = '0'
high_class_name = '1'

# Really, we should shuffle the data to be on the safe side
X = dataset.data.values.astype(float)
Y = dataset.target.to_numpy()
Y = (Y == high_class_name).astype(int) - (Y == low_class_name).astype(int)

X, Y = unison_shuffled_copies(X, Y)

# N = 500
# X = X[:N]
# Y = Y[:N]

In [271]:
from classifiers import ConformalNearestNeighbours

cp = ConformalNearestNeighbours()

N = 300
cp.learn_initial_training_set(X[:N], Y[:N])

In [272]:
d = cp.distance_func(cp.X, X[N])
D = cp.update_distance_matrix(cp.D, d)

In [273]:
label_nc = 1
y_nc = np.append(cp.y, label_nc)

In [274]:
same_label_distances, different_label_distances = find_nearest_distances(D, y_nc)

In [275]:
D

array([[  0.        ,  66.58297354,  63.1210707 , ...,  78.1353908 ,
         53.2539726 , 659.22581347],
       [ 66.58297354,   0.        ,  22.78754484, ..., 142.48328916,
         88.15114851, 722.81098488],
       [ 63.1210707 ,  22.78754484,   0.        , ..., 140.43064227,
         85.27468665, 721.61675955],
       ...,
       [ 78.1353908 , 142.48328916, 140.43064227, ...,   0.        ,
         97.84017494, 581.5992864 ],
       [ 53.2539726 ,  88.15114851,  85.27468665, ...,  97.84017494,
          0.        , 660.20436731],
       [659.22581347, 722.81098488, 721.61675955, ..., 581.5992864 ,
        660.20436731,   0.        ]])

In [276]:
y_nc

array([-1, -1, -1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1,
       -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1,  1, -1,
       -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
        1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,
       -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1,
       -1, -1,  1,  1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1,  1, -1,
       -1, -1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,
       -1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1,  1,
       -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,  1,  1,
       -1, -1, -1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1,
        1, -1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1, -1, -1, -1,  1,
       -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,
       -1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1, -1, -1,  1,  1,
        1,  1, -1,  1,  1

In [277]:
print("Nearest distances to an object with the same label:")
print(same_label_distances)
print("\nNearest distances to an object with a different label:")
print(different_label_distances)


Nearest distances to an object with the same label:
[5.30377017e+00 0.00000000e+00 8.78043849e+00 3.23867488e+00
 1.41965990e+01 1.50141395e+01 3.42768581e+00 1.21278507e+01
 6.12741112e+00 4.03599232e+00 1.03908662e+01 1.77504158e+01
 4.08581559e+00 3.02636754e+01 1.19621910e+01 3.83118493e+00
 3.22189352e+02 3.17879238e+00 1.17751849e+01 1.29832853e+01
 6.36953264e+00 9.08616272e+00 7.75144019e+00 2.84282885e+01
 4.15778787e+00 0.00000000e+00 3.51311443e+00 3.64742226e+01
 1.22938952e+01 1.00000000e+00 1.04489205e+01 1.11801618e+01
 2.15227581e+01 5.78050223e+00 6.14418986e+00 3.67038758e+00
 4.76099506e+00 6.62015273e+01 8.71417615e+00 1.72038196e+02
 2.64530652e+01 2.21777837e+01 4.21573161e+00 3.13797646e+01
 3.89082112e+00 5.60401071e+00 4.02196797e+01 0.00000000e+00
 5.94899302e+00 3.48922377e+02 1.41843821e+01 2.42266836e+00
 5.10347627e+00 1.92303302e+01 3.75323017e+01 1.43722545e+00
 4.00058071e+00 2.77406504e+01 8.37119830e+00 5.02382344e+00
 0.00000000e+00 6.07927627e+00 2.

In [278]:
same_label_distances / different_label_distances

array([1.02515491e+00, 0.00000000e+00, 1.05918414e+00, 5.55395507e-01,
       4.11125646e-01, 2.87426691e+00, 4.90603574e-01, 2.55000818e+00,
       5.89282734e-01, 4.26391445e-01, 1.45944602e+00, 1.07474893e+00,
       4.35644245e-01, 6.78445245e-01, 5.81857172e-01, 5.57876584e-01,
       1.05685454e+00, 4.97316751e-01, 4.77126397e-01, 1.31150922e+00,
       4.03173748e-01, 1.33639810e+00, 7.43069180e-01, 2.80820646e+00,
       1.20245650e+00, 0.00000000e+00, 5.76685701e-01, 3.29469648e+00,
       1.59761065e+00, 2.10084034e-01, 3.86141753e-01, 1.36449897e+00,
       5.75292230e-01, 9.12870931e-01, 5.99677407e-01, 9.20175591e-01,
       7.06762963e-01, 6.45859796e-01, 1.29205025e+00, 2.45524928e-01,
       9.54535845e-01, 2.00330701e+00, 7.24111315e-01, 4.99740520e+00,
       4.23435064e-01, 1.14954724e+00, 4.71974955e+00, 0.00000000e+00,
       1.26400329e+00, 1.15132533e+00, 1.75720522e+00, 5.60310422e-02,
       3.26733302e-01, 9.91191983e-01, 3.71210915e+00, 2.12988184e-01,
      

In [282]:
np.nan_to_num(np.array([0]) / np.array([0]), nan=np.inf)

  np.nan_to_num(np.array([0]) / np.array([0]), nan=np.inf)


array([inf])