In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.neural_network import MLPClassifier as NN

from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from sklearn.gaussian_process.kernels import RBF
from tqdm import tqdm
from scipy.sparse import csr_matrix, csgraph

import numba
from numba import jit
from multiprocessing import Pool
from joblib import Parallel
import math
import dill

In [2]:
# Load the MNIST training set: column 0 holds the label, the remaining
# columns hold the pixel intensities of one image per row.
# Source: https://www.python-course.eu/neural_network_mnist.php
image_size = 28  # width and height of each image, in pixels
no_of_different_labels = 10  # digits 0, 1, 2, ..., 9
image_pixels = image_size * image_size
data_path = "mnist_data/mnist_train.csv"

train_data = pd.read_csv(data_path, delimiter=",", header=None, dtype=np.uint8)

# Rescale pixel intensities from [0, 255] to [0, 1] and store them as float32.
train_imgs = (train_data.iloc[:, 1:] / 255).astype("float32")
train_labels = train_data[0].to_numpy()

In [3]:
# Subsample the original dataset: keep the first 1000 instances of each digit
# (rather than the 5000-6000 available per digit).
samples_per_digit = 1000
all_indices = np.sort(np.concatenate([
    np.flatnonzero(train_labels == digit)[:samples_per_digit]
    for digit in range(no_of_different_labels)
]))
print(len(all_indices))

# Row positions are sorted explicitly, so X and y have a well-defined order.
# Iterating a Python set instead makes the order an implementation detail.
# NOTE: positional indices into X / y differ from sessions saved with the
# old set-based ordering.
X = train_imgs.iloc[all_indices, :]
y = train_labels[all_indices]

10000


In [4]:
@jit(nopython=True)
def calculate_weighted_distance(v0, v1, penalty_term=2):
    """Euclidean norm of ``v0 - v1`` with negative components re-weighted.

    Components where ``v0 - v1 >= 0`` keep weight 1. All other components are
    multiplied by ``-penalty_term``; the sign does not affect the norm. A
    ``penalty_term`` above 1 therefore makes those components more expensive.
    """
    delta = np.subtract(v0, v1)
    scale = np.where(delta >= 0, 1, -penalty_term)
    return np.linalg.norm(delta * scale)


def get_weights_kNN(
    X,
    n_neighbours=20,
    penalty_term=2,
    weight_func=None
):
    """Build a directed k-nearest-neighbour weight matrix.

    For each sample ``i``, the distance to every sample ``j`` is
    ``calculate_weighted_distance(X[j], X[i], penalty_term)``. Row ``i`` keeps
    only the ``n_neighbours + 1`` smallest distances (the zero-distance
    self-match included). All other entries of the row are zero.

    Args:
        X (DataFrame or array-like of shape (n_samples, n_features)): data points.
        n_neighbours (int): number of neighbours kept per row.
        penalty_term (float): factor applied to negative component differences.
        weight_func (callable or None): maps a non-zero distance to an edge
            weight. ``None`` uses the distance itself as the weight.

    Returns:
        W (ndarray of shape (n_samples, n_samples)): ``W[i, j]`` is the weight of
        the edge from ``i`` to ``j``, or 0 where no edge is kept. Zero distances
        (self / identical samples) never produce an edge.
    """
    if weight_func is None:
        def weight_func(dist):
            return dist

    X = np.asarray(X)  # accepts DataFrames as well as plain arrays
    n_samples = X.shape[0]

    W = np.zeros((n_samples, n_samples))
    # Distances live in their own buffer (not an alias of W), so neighbours are
    # ranked by distance rather than by the transformed weight.
    row_dists = np.empty(n_samples)

    for i in tqdm(range(n_samples)):
        v0 = X[i]
        for j in range(n_samples):
            # Modified distance: removing pixels incurs a larger cost.
            dist = calculate_weighted_distance(
                X[j], v0, penalty_term=penalty_term)
            row_dists[j] = dist
            if dist != 0:
                W[i, j] = weight_func(dist)

        # Keep the (n_neighbours + 1) closest samples (self included) and drop
        # every other edge of this row.
        far = np.argsort(row_dists)[n_neighbours + 1:]
        W[i, far] = 0

    return W

In [5]:
def construct_graph(weight_matrix):
    """Wrap ``weight_matrix`` as a scipy CSR sparse matrix (zero entries are not stored)."""
    return csr_matrix(weight_matrix)


def find_shortest_path(graph, start_point_idx):
    """Run Dijkstra on the directed ``graph`` starting from ``start_point_idx``.

    Returns:
        (dist_matrix, predecessors): the pair returned by
        ``csgraph.dijkstra(..., return_predecessors=True)``.
    """
    return csgraph.dijkstra(
        csgraph=graph,
        directed=True,
        indices=start_point_idx,
        return_predecessors=True,
    )


def reconstruct_shortest_path(predecessors, start_point_idx, end_point_idx):
    """Get all the nodes along the path between the start point and the end point.

    Args:
        predecessors (array of shape (n_nodes,)): ``predecessors[k]`` is the node
            preceding ``k`` on the shortest path from the start point. Negative
            values (scipy's dijkstra uses -9999) mean "no predecessor".
        start_point_idx (int): the index of the start data point
        end_point_idx (int): the index of the end data point

    Returns:
        node_path (list): [start_point_idx, intermediate point indices..., end_point_idx].
        When both indices are equal, it is just ``[start_point_idx]``.

    Raises:
        ValueError: if ``end_point_idx`` cannot be reached from ``start_point_idx``.
    """
    reversed_path = []
    node = end_point_idx
    while node != start_point_idx:
        reversed_path.append(node)
        node = predecessors[node]
        # A negative predecessor means the chain ended before reaching the start.
        # Indexing with it would silently wrap around to the end of the array.
        if node < 0:
            raise ValueError(
                f"node {end_point_idx} is not reachable from node {start_point_idx}")
    reversed_path.append(start_point_idx)
    return reversed_path[::-1]


def build_symmetric_matrix(kernel):
    """Mirror the strict lower triangle of ``kernel`` onto its upper triangle.

    For every ``col < row``, ``kernel[col, row]`` is overwritten with
    ``kernel[row, col]``. The matrix is modified in place and returned.
    """
    size = kernel.shape[0]
    for row in range(1, size):
        for col in range(row):
            kernel[col, row] = kernel[row, col]
    return kernel


def build_asymmetric_matrix(kernel, X, weight_func, penalty_term=2):
    """Add the reverse edge ``j -> i`` for every existing edge ``i -> j`` of ``kernel``.

    For each non-zero ``kernel[i, j]``, ``kernel[j, i]`` is set to
    ``weight_func(calculate_weighted_distance(X[i], X[j], penalty_term))``.
    ``kernel`` is modified in place and also returned.

    Args:
        kernel: square matrix of edge weights (e.g. the output of get_weights_kNN).
        X (DataFrame or array-like of shape (n_samples, n_features)): data points.
        weight_func (callable): maps a distance to an edge weight.
        penalty_term (float): factor applied to negative component differences.

    Returns:
        The same ``kernel`` object, now containing the reverse edges.
    """
    X = np.asarray(X)
    # Snapshot the edge list before writing. Reverse edges added here therefore
    # never feed back into the scan, so the result does not depend on iteration
    # order. Only existing edges are visited, not all n_samples ** 2 cells.
    rows, cols = kernel.nonzero()
    for i, j in tqdm(zip(rows, cols), total=len(rows)):
        dist = calculate_weighted_distance(
            X[i], X[j], penalty_term=penalty_term)
        kernel[j, i] = weight_func(dist)
    return kernel

In [None]:
# Graph-construction hyper-parameters.
n_neighbours = 20   # neighbours kept per node in the kNN graph
penalty_term = 1.1  # factor applied to negative differences in calculate_weighted_distance
n_samples, n_features = X.shape


def get_volume_of_sphere(d):
    """Volume of the unit ball in ``d`` dimensions: pi**(d/2) / Gamma(d/2 + 1)."""
    half_d = d / 2
    return math.pi ** half_d / math.gamma(half_d + 1)


# Scale parameter used by weight_func: r = n_neighbours / (n_samples * V_d).
# NOTE(review): V_d is evaluated for d = 1, not for n_features; confirm intended.
volume_sphere = get_volume_of_sphere(1)
r = n_neighbours / (n_samples * volume_sphere)

# Construct the global weighted graph.
# The kernel is asymmetric when its weights come from kNN; OG keeps only the lower-left (lower-triangular) half of the matrix.


def weight_func(x):
    """Edge weight for a non-zero distance ``x``: ``-x * log(r / x)`` (alternative: x**alpha)."""
    return -x * np.log(r / x)


# Directed kNN weight matrix over the subsampled images.
kernel = get_weights_kNN(
    X,
    n_neighbours=int(n_neighbours),
    penalty_term=penalty_term,
    weight_func=weight_func,
)

# sym_kernel = build_symmetric_matrix(kernel)
# asym_kernel = build_asymmetric_matrix(kernel, X, weight_func)

# TODO: replace the sequential FOR loop with a parallel FOR loop (multiprocessing).
# On a local laptop, set the number of processes to 2.
# On a server, set the number of processes to (number of available CPUs - 1).

 40%|███▉      | 3976/10000 [02:36<03:57, 25.37it/s]


KeyboardInterrupt: 

In [None]:
# Add the reverse of every kNN edge (in place: asym_kernel is the same object
# as kernel), then save the interpreter session with dill.
asym_kernel = build_asymmetric_matrix(
    kernel, X, weight_func=weight_func, penalty_term=penalty_term)
dill.dump_session("mnist_graph_construction.db")

In [None]:
def get_minimum_dist(dist_matrix):
    """Return the smallest non-zero distance and the index where it occurs.

    Zero entries are masked out before searching.

    Args:
        dist_matrix (array): shape: 1 x n_nodes

    Returns:
        min_dist: minimum non-zero distance in the distance matrix
        min_dist_idx: index of the data point with that distance
    """
    nonzero_dist = np.ma.masked_where(dist_matrix == 0, dist_matrix, copy=False)
    return np.min(nonzero_dist), np.argmin(nonzero_dist)


def get_closest_cf_point(dist_matrix, predictions, y, target_class, class_labels, num_paths=1, pred_threshold=0.55):
    """Return the indices of the closest points that qualify as counterfactuals.

    Nodes are visited in increasing ``dist_matrix`` order. A node is accepted
    when ``y[idx] == target_class`` and
    ``predictions[idx, class_labels.index(target_class)] >= pred_threshold``.
    Zero-distance nodes (the start point) and non-finite distances (nodes the
    shortest-path search could not reach) are never returned.

    Args:
        dist_matrix (array of shape (n_nodes,)): shortest-path distances.
        predictions (2-D array): per-node class scores, indexed by the
            position of a class in ``class_labels``.
        y (array of shape (n_nodes,)): node labels.
        target_class: the class a counterfactual must belong to.
        class_labels (list): class order used by the columns of ``predictions``.
        num_paths (int): maximum number of indices to return (must be > 0).
        pred_threshold (float): minimum score for ``target_class``.

    Returns:
        list: up to ``num_paths`` node indices, closest first (possibly empty).

    Raises:
        AssertionError: if ``num_paths`` is not a positive int.
        ValueError: if ``target_class`` is not in ``class_labels``.
    """
    # isinstance first, so a non-int never reaches the comparison (TypeError).
    assert isinstance(num_paths, int) and num_paths > 0, "only positive integers"
    target_col = class_labels.index(target_class)  # loop invariant

    end_point_idx = []
    # Masked (zero) entries are sorted last; non-finite entries sit just before them.
    for idx in np.argsort(np.ma.masked_where(dist_matrix == 0, dist_matrix)):
        dist = dist_matrix[idx]
        if dist == 0 or not np.isfinite(dist):
            break  # everything from here on is unreachable or the start point
        if y[idx] == target_class and predictions[idx, target_col] >= pred_threshold:
            end_point_idx.append(idx)
            if len(end_point_idx) >= num_paths:
                break
    return end_point_idx


def get_user_agency(sp_graph, start_point_idx, alternative_classes, predictions, y, class_labels, pred_threshold=0.55):
    """Find, for every alternative class, the closest counterfactual and its path.

    Args:
        sp_graph: graph passed to ``find_shortest_path``.
        start_point_idx (int): index of the factual (start) data point.
        alternative_classes (iterable): classes to search counterfactuals for.
        predictions, y, class_labels, pred_threshold: forwarded to
            ``get_closest_cf_point``.

    Returns:
        alt_class_dict (dict): ``{alt_class: {end_idx: shortest-path distance}}``.
        alt_path_dict (dict): ``{alt_class: {end_idx: [start_point_idx, ..., end_idx]}}``.

    Raises:
        IndexError: if no qualifying point is found for some alternative class.
    """
    dist_matrix, predecessors = find_shortest_path(sp_graph, start_point_idx)
    alt_class_dict = {}
    alt_path_dict = {}
    for alt_class in alternative_classes:
        candidates = get_closest_cf_point(
            dist_matrix, predictions, y, alt_class, class_labels, pred_threshold=pred_threshold)
        if not candidates:
            # Same exception type as indexing an empty list, but with a useful message.
            raise IndexError(
                f"no point of class {alt_class} with prediction >= {pred_threshold} "
                f"found for start point {start_point_idx}")
        end_idx = candidates[0]
        alt_class_dict[alt_class] = {end_idx: dist_matrix[end_idx]}
        alt_path_dict[alt_class] = {
            end_idx: reconstruct_shortest_path(predecessors, start_point_idx, end_idx)}

    return alt_class_dict, alt_path_dict

In [None]:
def plot_digits(path, _id):
    """Plot the images of the data points listed in ``path``, at most 10 per row.

    Images are read from the notebook-level ``X`` (one flattened 28x28 image per
    row) and labels from ``y``. Each panel is titled with ``_id`` and the label.

    Args:
        path (sequence of int): positional indices into ``X`` / ``y``.
        _id: identifier formatted into every panel title.
    """
    n_digits = len(path)
    if n_digits == 0:
        return  # nothing to draw (and ncols would be 0 below)
    ncols = min(10, n_digits)
    nrows = math.ceil(n_digits / ncols)
    # squeeze=False always returns a 2-D array of Axes, even for a single image.
    _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 4 * nrows), squeeze=False)
    axes = axes.ravel()
    for ax in axes:
        ax.set_axis_off()  # also hides unused trailing panels in the last row
    for ax, img_idx in zip(axes, path):
        image = np.array(X.iloc[img_idx, :])
        label = y[img_idx]
        ax.imshow(image.reshape(28, 28), cmap=plt.cm.gray_r,
                  interpolation="nearest")
        ax.set_title(f"id:{_id}, t: {label}")
    # plt.savefig(f'example_{_id}_log.pdf', dpi=300)


def print_alt_paths(alt_path_dict):
    """Plot the first stored path of every alternative class.

    ``alt_path_dict`` maps each class to ``{end_idx: path}``. Each figure is
    identified as ``"<class>-0"``.
    """
    for alt_class, paths_by_end in alt_path_dict.items():
        first_path = list(paths_by_end.values())[0]
        plot_digits(first_path, f"{alt_class!s}-0")