In [4]:
from sklearn.datasets import load_wine

%load_ext autoreload
%autoreload 2

X, y = load_wine(return_X_y=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import numpy as np
from scipy.spatial import KDTree
import cvxpy as cp
from sklearn.base import BaseEstimator, ClassifierMixin

class DKNN(BaseEstimator, ClassifierMixin):
    def __init__(self, k, alpha=1, beta=1):
        super().__init__()
        self.k = k          # 'k' neighbors
        self.A = None       # PSD matrix objective
        self.pi = None      # technically log(pi)
        self.trees = []     # search tree for NN
        
        # importance weights for each class (k,)
        if type(alpha) in {float, int}:
            self.alpha = np.full(k, alpha)
        else:
            assert k == len(alpha)
            self.alpha = alpha
        self.beta  = beta   # regularization term

    # Mahalanobis distance
    def dist(self, x, mu, c):
        delta = x - mu
        return np.sum(np.multiply(delta @ self.A, delta), axis=-1) - self.pi[c]

    def fit(self, X, y):
        self.trees = [] # reset for each fit -- when using CV need this
        self.X = X
        self.C = np.unique(y)
        self.classes_ = self.C
        self.c_idx = []  # indices belonging to 'c' w.r.t full training X
        for ci in self.C:
            self.c_idx.append(np.where(y == ci))
        n, d = X.shape

        centroids = []

        # Find centroids of class C[i]
        for idx in self.c_idx:
            # Get k nearest neighbors of class C[i] for all training data X
            tree = KDTree(X[idx])
            _, n_idx = tree.query(X, self.k)
            self.trees.append(tree)

            # Compute centroids
            neighbors = X[idx][n_idx] # X[of class 'c'][its nearest neighbors w.r.t X[c]]
            if self.k == 1:
                centroid_c = neighbors
            else:
                centroid_c = np.mean(neighbors, axis=1)
            centroids.append(centroid_c)
        
        centroids = np.stack(centroids, axis=0)

        # Convex problem formulation
        self.pi = np.array([len(idx[0]) / n for idx in self.c_idx])
        self.A = cp.Variable((d, d))

        delta = X - centroids

        # should work
        # f_mult = np.sum(np.multiply(delta @ self.A, delta), axis=2) - self.pi[:, np.newaxis]
        # print(f_mult[0, 0])

        constraints = []
        epsilon = cp.Variable(n)
        constraints.append(epsilon >= 0)

        for i in range(n):
            for c in self.C:
                if c == y[i]:
                    continue
                constraints += [
                    delta[y[i], i] @ self.A @ delta[y[i], i].T - cp.log(self.pi[y[i]]) + 1 - epsilon[i] 
                    <= delta[c, i] @ self.A @ delta[c   , i].T - cp.log(self.pi[c])
                ]
            constraints += [
                epsilon[i] >= 0
            ]
        
        alpha_vec = np.array([self.alpha[y_i] for y_i in y])  # corresponding class importance weight
        objective = cp.Minimize(cp.sum(cp.multiply(alpha_vec, epsilon)) + self.beta * cp.norm(self.A))

        prob = cp.Problem(objective, constraints)
        prob.solve()

        self.A = self.A.value

    def predict(self, X_new):
        if X_new.ndim == 1:
            n = 1
        else:
            n = X_new.shape[0]

        dist_c = np.empty((n, len(self.trees)))
        for c, t in enumerate(self.trees):
            # each tree 't' is already a subset of X conditioned on y=c
            _, n_idx = t.query(X_new, self.k)

            # Compute centroids
            neighbors = self.X[self.c_idx[c][0][n_idx]]
            centroid = np.mean(neighbors, axis=-2)
            cur_dist = self.dist(X_new, centroid, c)
            dist_c[:, c] = cur_dist
        
        predictions = np.argmin(dist_c, axis=1)
        return predictions
    
    def score(self, X_test, y_test):
        y_pred = self.predict(X_test)
        return np.average(y_pred == y_test)
    
    def get_params(self, deep=False):
        return {
            'k': self.k,
            'alpha': self.alpha,
            'beta': self.beta,
        }
    
    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

In [8]:
data, labels = ds.get_UCI_dataset("wine")
dknn_clf = DKNN(k=3, alpha=[1.0, 0.5, 1.0], beta=0.01)
ds.accuracy_splits(data, labels, dknn_clf)

(0.8314814814814815,
 0.054590381381421006,
 [0.8333333333333334,
  0.8888888888888888,
  0.8148148148148148,
  0.7777777777777778,
  0.9629629629629629,
  0.8148148148148148,
  0.7777777777777778,
  0.8333333333333334,
  0.8333333333333334,
  0.7777777777777778])

In [9]:
import dataset_helpers as ds
from sklearn.model_selection import GridSearchCV, ShuffleSplit

param_grid = {
    'k': [3, 5, 7, 10],
    'alpha': [[0.1, 0.5, 1.0], [0.1, 0.5, 2.0]],  # Example values for alpha
    'beta': [0.01, 1.0],    # Example values for beta
}
data, labels = ds.get_UCI_dataset("wine")
dknn_clf = DKNN(k=3)
grid_clf = GridSearchCV(dknn_clf, param_grid, scoring="accuracy",)
rs = ShuffleSplit(n_splits=10, test_size=0.3)
for i, (train_index, test_index) in enumerate(rs.split(X)):
    grid_clf.fit(data[train_index], labels[train_index])
    print(grid_clf.best_params_)
    acc = grid_clf.score(data[test_index], labels[test_index])
    print(acc)
    print(grid_clf.cv_results_)

    # grid_clf.append(acc)
# mean_acc, std_acc, acc_record = ds.accuracy(data, labels, dknn_clf)

{'alpha': [0.1, 0.5, 1.0], 'beta': 0.01, 'k': 5}
0.8148148148148148
{'mean_fit_time': array([0.34925036, 0.28249536, 0.30054102, 0.28602157, 1.87036734,
       1.18856335, 2.58721328, 0.88224564, 0.34230032, 0.27098632,
       0.28218288, 0.28251481, 1.50718341, 0.87025876, 1.68085308,
       0.69888706]), 'std_fit_time': array([0.08014264, 0.03122216, 0.04510174, 0.02965751, 0.60739583,
       0.72693089, 2.36343378, 0.33980864, 0.07044298, 0.0176769 ,
       0.02663231, 0.01664443, 0.59180549, 0.58243253, 1.64723396,
       0.29901592]), 'mean_score_time': array([0.00045352, 0.00044508, 0.00045681, 0.00045686, 0.00045066,
       0.00045271, 0.00047116, 0.00046177, 0.00043316, 0.00044508,
       0.00043502, 0.0004498 , 0.00044403, 0.00044279, 0.00045152,
       0.0004796 ]), 'std_score_time': array([2.08952633e-05, 2.50516103e-05, 1.86698646e-05, 2.34584497e-05,
       7.80674324e-06, 1.15630799e-05, 1.47121071e-05, 3.39325845e-06,
       6.98757634e-06, 1.15276308e-05, 1.66132420e-05



{'alpha': [0.1, 0.5, 1.0], 'beta': 0.01, 'k': 10}
0.8703703703703703
{'mean_fit_time': array([3.37666245, 0.48657761, 0.32411036, 0.32753935, 4.64140768,
       2.2729887 , 1.71666398, 1.40433183, 0.68212976, 0.33133478,
       0.28448114, 0.28423777, 4.0323525 , 1.31621814, 1.20215578,
       0.92798576]), 'std_fit_time': array([5.65443218, 0.3058039 , 0.05199331, 0.02217558, 4.53725607,
       1.08433028, 0.80749519, 0.35077267, 0.63126131, 0.07358484,
       0.03822279, 0.02108363, 4.78131917, 0.5736563 , 0.57396379,
       0.5709583 ]), 'mean_score_time': array([0.00051069, 0.00043454, 0.00045009, 0.00047812, 0.00050902,
       0.00045691, 0.00046139, 0.00047464, 0.00043521, 0.00044136,
       0.00043659, 0.00045042, 0.00048628, 0.00044899, 0.00047102,
       0.00047565]), 'std_score_time': array([1.46901061e-04, 1.28486159e-05, 6.72052877e-06, 5.24361760e-05,
       1.22459172e-04, 6.81993835e-06, 8.43125767e-06, 1.30529763e-05,
       1.64723602e-05, 1.46803905e-05, 1.00582127e-0



{'alpha': [0.1, 0.5, 1.0], 'beta': 0.01, 'k': 10}
0.8888888888888888
{'mean_fit_time': array([1.21876655, 0.28629069, 0.30367856, 0.38908844, 3.63863883,
       0.94560299, 0.61804185, 2.21893787, 3.29544978, 0.27650437,
       0.29066896, 0.33227983, 3.59788561, 0.65218592, 0.57679143,
       1.57841153]), 'std_fit_time': array([1.64391101, 0.02677398, 0.06963847, 0.1252659 , 5.08438736,
       0.26109547, 0.28500992, 1.27439643, 5.91374283, 0.0214895 ,
       0.05881478, 0.06343589, 5.22810873, 0.18266914, 0.27489919,
       0.70796279]), 'mean_score_time': array([0.00043478, 0.00043368, 0.00046082, 0.00049734, 0.00052643,
       0.00046158, 0.00045118, 0.00048742, 0.00049005, 0.00042863,
       0.00043311, 0.00045071, 0.00049362, 0.00044103, 0.00044861,
       0.00047817]), 'std_score_time': array([1.03503324e-05, 8.93481379e-06, 2.08745780e-05, 5.05152652e-05,
       1.22625701e-04, 1.04164638e-05, 1.05988791e-05, 2.74687176e-05,
       7.17315252e-05, 1.95917992e-05, 1.06459702e-0



KeyboardInterrupt: 