# In [1]:
from sklearn.datasets import load_wine

# Load the scikit-learn wine dataset as a (features, labels) pair of arrays.
X, y = load_wine(return_X_y=True)

# In [None]:
import numpy as np
from scipy.spatial import KDTree
import cvxpy as cp

class DKNN:
    """Local-centroid nearest-neighbour classifier with a learned quadratic metric.

    For a sample ``x`` and a class ``c``, the *local centroid* ``mu`` is the mean
    of the ``k`` training points of class ``c`` nearest to ``x`` (KD-tree search,
    Euclidean by scipy's default).  The class score is
    ``(x - mu)^T A (x - mu) - log(pi[c])``.

    ``fit`` learns a positive semidefinite matrix ``A`` by convex optimisation so
    that every training sample's own-class score beats each other class's score
    by a margin of 1, up to non-negative slacks.  ``predict`` returns the label
    with the lowest score.
    """

    def __init__(self):
        self.A = None          # learned (d, d) PSD metric matrix; set by fit
        self.pi = None         # class priors (fraction of training samples), classes_ order
        self.trees = []        # one KDTree per class over that class's training points
        self.classes_ = None   # sorted unique training labels; position c <-> pi[c], trees[c]

    @staticmethod
    def _local_centroids(tree, points, k):
        """Return, for each row of ``points``, the mean of its ``k`` nearest points in ``tree``.

        Raises:
            ValueError: if ``k`` is not between 1 and the number of points in ``tree``.
        """
        if not 1 <= k <= tree.n:
            raise ValueError(
                f"k must be between 1 and the class size ({tree.n}), got {k}"
            )
        _, nn = tree.query(points, k=k)
        # KDTree.query squeezes the neighbour axis when k == 1; restore it so the
        # mean below is always taken over neighbours, never over features.
        nn = np.reshape(nn, (len(points), k))
        return tree.data[nn].mean(axis=1)

    def dist(self, x, mu, c):
        """Score ``x`` against local centroid ``mu`` for the class at position ``c``.

        Returns ``(x - mu)^T A (x - mu) - log(pi[c])``.  ``x`` and ``mu`` may be
        1-D (one sample -> scalar score) or 2-D with one sample per row
        (-> one score per row).  ``c`` is a position into ``classes_``, not a raw
        label.
        """
        delta = np.asarray(x, dtype=float) - np.asarray(mu, dtype=float)
        # Row-wise quadratic form; for 1-D input this equals delta @ A @ delta.
        quad = np.einsum("...i,ij,...j->...", delta, self.A, delta)
        return quad - np.log(self.pi[c])

    def fit(self, X, y, k, alpha=1, beta=1):
        """Learn the metric ``A`` and store per-class neighbour structures.

        Args:
            X: training features, shape (n, d).
            y: training labels, length n; any sortable label values are accepted.
            k: number of same-class neighbours averaged into each local centroid.
            alpha: weight on the sum of margin slacks.
            beta: weight on ``cp.norm(A)`` (spectral norm of ``A``).

        Raises:
            ValueError: if ``k`` exceeds the size of the smallest class (or is < 1).
            RuntimeError: if the convex problem is not solved to optimality.
        """
        X = np.asarray(X, dtype=float)
        n, d = X.shape
        # Map labels to positions 0..C-1 so they can index pi and per-class
        # arrays; the raw labels need not be 0..C-1.
        classes, y_idx = np.unique(np.asarray(y), return_inverse=True)
        y_idx = y_idx.ravel()
        n_classes = len(classes)

        trees = [KDTree(X[y_idx == c]) for c in range(n_classes)]
        # centroids[c, i] = mean of the k nearest class-c training points to X[i].
        # NOTE(review): for c == y[i] the neighbours include X[i] itself (distance
        # 0), which pulls the own-class centroid toward the sample; unseen points
        # in predict() have no such neighbour. Kept as in the original - confirm
        # whether a leave-one-out query was intended.
        centroids = np.stack([self._local_centroids(tree, X, k) for tree in trees])

        delta = X[np.newaxis] - centroids              # (C, n, d)
        own_delta = delta[y_idx, np.arange(n)]         # (n, d): offset to own-class centroid
        pi = np.bincount(y_idx, minlength=n_classes) / n
        log_pi = np.log(pi)

        # A is PSD so that delta^T A delta is a genuine (squared) distance.
        A = cp.Variable((d, d), PSD=True)
        epsilon = cp.Variable(n)
        constraints = [epsilon >= 0]

        # For every sample i and every class c != y[i]:
        #   own_score_i + 1 - epsilon_i <= score_i(c)
        # delta^T A delta is linear in A, so each constraint is affine.
        for c in range(n_classes):
            rows = np.flatnonzero(y_idx != c)
            if rows.size == 0:
                continue
            own = own_delta[rows]
            other = delta[c, rows]
            own_score = cp.sum(cp.multiply(own @ A, own), axis=1) - log_pi[y_idx[rows]]
            other_score = cp.sum(cp.multiply(other @ A, other), axis=1) - log_pi[c]
            constraints.append(own_score + 1 - epsilon[rows] <= other_score)

        objective = cp.Minimize(alpha * cp.sum(epsilon) + beta * cp.norm(A))
        prob = cp.Problem(objective, constraints)
        prob.solve()
        if A.value is None or prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
            raise RuntimeError(
                f"metric learning problem was not solved (status: {prob.status})"
            )

        # Replace (not append to) any state from a previous fit.
        self.classes_ = classes
        self.pi = pi
        self.trees = trees
        self.A = A.value

    def predict(self, x, k):
        """Return the predicted label for each sample in ``x``.

        Args:
            x: samples, shape (m, d), or a single sample of shape (d,).
            k: neighbours per local centroid; normally the value passed to ``fit``.

        Returns:
            1-D array of m labels drawn from ``classes_`` (length 1 for a single sample).

        Raises:
            RuntimeError: if called before ``fit``.
            ValueError: if ``k`` exceeds the size of some class.
        """
        if self.A is None or not self.trees:
            raise RuntimeError("DKNN.predict called before fit")
        x = np.atleast_2d(np.asarray(x, dtype=float))
        scores = np.stack([
            self.dist(x, self._local_centroids(tree, x, k), c)
            for c, tree in enumerate(self.trees)
        ])                                             # (C, m)
        return self.classes_[np.argmin(scores, axis=0)]
    
    
# Fit the classifier on the full wine dataset, averaging k = 5 same-class
# neighbours into each local centroid.
dknn = DKNN()
dknn.fit(X, y, k=5)