Reference: https://sebastianraschka.com/Articles/2014_python_lda.html

In [1]:
import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets

np.random.seed(1)
plt.style.use('classic')

In [2]:
feature_dict = {i:label for i,label in zip(
                range(4),
                  ('sepal length in cm',
                  'sepal width in cm',
                  'petal length in cm',
                  'petal width in cm', ))}

In [3]:
iris = datasets.load_iris()
X = iris.data
y = iris.target

In [4]:
X.shape, y.shape

((150, 4), (150,))

Promedio de cada caracteristica para cada clase

$$
\begin{align}
\mu_i = \frac{1}{n_i} \sum_{i=1}^{C} \sum_{j=0}^{n} x_n \cdot e^{-i2\pi\frac{n}{N}}
\end{align}
$$

In [5]:
np.set_printoptions(precision=4)

mean_vectors = []
for cl in range(0,3):
    mean_vectors.append(np.mean(X[y==cl], axis=0))
    print('Mean Vector class %s: %s\n' %(cl, mean_vectors[cl]))

Mean Vector class 0: [5.006 3.428 1.462 0.246]

Mean Vector class 1: [5.936 2.77  4.26  1.326]

Mean Vector class 2: [6.588 2.974 5.552 2.026]



$$
\begin{align}
S_W = \sum_{i=1}^{C} \sum_{j=0}^{n} x_n \cdot e^{-i2\pi\frac{n}{N}}
\end{align}
$$

In [6]:
S_W = np.zeros((4,4))
for cl,mv in zip(range(1,4), mean_vectors):
    class_sc_mat = np.zeros((4,4))                  # scatter matrix for every class
    for row in X[y == cl]:
        row, mv = row.reshape(4,1), mv.reshape(4,1) # make column vectors
        class_sc_mat += (row-mv).dot((row-mv).T)
    S_W += class_sc_mat                             # sum class scatter matrices
print('within-class Scatter Matrix:\n', S_W)

within-class Scatter Matrix:
 [[ 97.3682 -15.1782 196.0494  78.1788]
 [-15.1782  33.6502 -71.3282 -24.0392]
 [196.0494 -71.3282 500.6482 202.2864]
 [ 78.1788 -24.0392 202.2864  88.4324]]


In [7]:
S_W.shape

(4, 4)

In [8]:
np.linalg.inv(S_W)

array([[ 0.0879, -0.0591, -0.0654,  0.0558],
       [-0.0591,  0.0895,  0.0654, -0.073 ],
       [-0.0654,  0.0654,  0.0842, -0.1171],
       [ 0.0558, -0.073 , -0.1171,  0.21  ]])