In [12]:
import numpy as np   

class GDAClassifier:
    
    def fit(self, X, y, epsilon = 1e-10):
        self.y_classes, y_counts = np.unique(y, return_counts=True)
        self.phi_y = 1.0 * y_counts/len(y)
        self.u = np.array([ X[y==k].mean(axis=0) for k in self.y_classes])
        self.E = self.compute_sigma(X, y)
        self.E += np.ones_like(self.E) * epsilon # fix zero overflow
        self.invE = np.linalg.pinv(self.E)
        return self
    
    def compute_sigma(self,X, y):
        X_u = X.copy().astype('float64')
        for i in range(len(self.u)):
            X_u[y==self.y_classes[i]] -= self.u[i]
        return X_u.T.dot(X_u) / len(y)

    def predict(self, X):
        return np.apply_along_axis(self.get_prob, 1, X)
    
    def score(self, X, y):
        return (self.predict(X) == y).mean()
    
    def get_prob(self, x):
        p = np.exp(-0.5 * np.sum((x - self.u).dot(self.invE) * (x - self.u), axis =1)) * self.phi_y
        print(p)
        return np.argmax(p)

In [2]:
from utils import train_test_split
from sklearn.datasets import load_iris
X,y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.8)
model = GDAClassifier().fit(X_train,y_train)

In [3]:
model.score(X_test,y_test)

0.9833333333333333

In [5]:
from sklearn.datasets import load_breast_cancer
X,y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.8)
model = GDAClassifier().fit(X_train,y_train)
model.score(X_test,y_test)

0.9296703296703297

In [13]:
from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.5)
model = GDAClassifier().fit(X,y)
model.score(X_test,y_test)

[9.14233744e-25 2.57076228e-10 1.90884518e-16 1.14832779e-13
 1.46538603e-24 1.75152322e-17 4.01887541e-20 5.75896086e-20
 8.94579985e-10 2.63564377e-11]
[2.26877709e-20 1.19463211e-16 1.07073732e-19 1.85711201e-12
 3.42908808e-26 3.40130499e-20 5.03866710e-25 3.36419548e-21
 4.87212297e-15 3.44339079e-09]
[6.82376533e-35 8.62574591e-24 6.26355822e-27 1.55875650e-27
 1.51276440e-32 1.41838179e-31 6.76961914e-34 3.15201985e-31
 1.71330850e-18 2.62401322e-24]
[9.15096863e-50 2.80849510e-30 1.28992394e-27 2.66100148e-41
 9.01823410e-43 1.37806333e-54 6.93462215e-44 2.69300569e-41
 1.75240102e-35 4.85600574e-43]
[1.60367823e-26 2.16482192e-24 3.06580959e-28 1.03790480e-21
 6.56588668e-35 2.86328740e-09 4.51872807e-27 3.02817276e-22
 1.19358863e-22 1.06121454e-21]
[2.57942067e-31 1.27062223e-22 1.17519633e-30 9.63284121e-23
 1.15280979e-29 1.54582801e-20 5.63896950e-27 4.31892705e-21
 3.29214334e-17 6.55959671e-18]
[7.88794289e-18 9.23622408e-17 6.23692974e-19 6.96339741e-12
 3.62408081e-23

[3.55015581e-30 2.13296479e-22 5.83587877e-28 1.42826828e-22
 3.55952946e-26 5.46957086e-25 9.30168224e-31 6.49357082e-13
 2.35654204e-19 6.37587424e-19]
[1.75509193e-26 2.27145837e-20 4.17041124e-24 2.68607548e-18
 7.39526538e-33 1.84107302e-20 4.09080801e-21 4.59955079e-23
 7.21527907e-11 5.27603843e-17]
[3.47404845e-37 7.48925732e-25 7.75036969e-26 5.25757701e-19
 7.74244876e-40 9.64008496e-24 2.73703136e-36 3.74070181e-24
 1.18248339e-25 8.90718263e-24]
[8.33171098e-19 3.18415540e-19 9.26344355e-25 2.00325830e-22
 4.26976277e-18 7.83289386e-21 1.22632441e-07 5.79767639e-25
 2.34146430e-17 1.99591485e-22]
[5.05481581e-30 2.86319702e-20 9.10238700e-26 4.74794960e-23
 1.28957434e-32 1.26450099e-26 7.97908934e-23 1.29673532e-26
 1.15563209e-16 2.03741730e-22]
[2.25743229e-29 9.88829976e-26 8.84853784e-31 3.71784393e-25
 5.89785728e-26 3.87495393e-21 7.01783515e-28 4.74193653e-12
 9.40466438e-24 2.91723496e-21]
[7.12101482e-30 2.86193048e-20 2.21540719e-11 2.10679109e-19
 8.24171852e-35

0.965478841870824

In [9]:
model.E

array([[ 1.00000000e-10,  1.00000000e-10,  1.00000000e-10, ...,
         1.00000000e-10,  1.00000000e-10,  1.00000000e-10],
       [ 1.00000000e-10,  6.85260934e-01,  1.31558591e+00, ...,
        -2.49550674e-01, -1.72916333e-01, -1.47619408e-01],
       [ 1.00000000e-10,  1.31558591e+00,  1.24547655e+01, ...,
        -1.46806270e+00, -3.44024003e-02,  3.11114596e-01],
       ...,
       [ 1.00000000e-10, -2.49550674e-01, -1.46806270e+00, ...,
         1.83142211e+01,  7.67224211e+00,  1.55782058e+00],
       [ 1.00000000e-10, -1.72916333e-01, -3.44024003e-02, ...,
         7.67224211e+00,  1.09016392e+01,  3.41868286e+00],
       [ 1.00000000e-10, -1.47619408e-01,  3.11114596e-01, ...,
         1.55782058e+00,  3.41868286e+00,  2.99539167e+00]])