## <font color='green'> <div align="center">In the name of God </div></font>

### <font color='red'> Author: Sayed Kamaledin Ghiasi-Shrirazi <a href="http://profsite.um.ac.ir/~k.ghiasi">(http://profsite.um.ac.ir/~k.ghiasi)</a> </font>

# GMM Quadratic Classifier on MNIST with reg = 0.1

##### Importing general modules.

In [1]:
import numpy as np
from sklearn import mixture
import scipy.io as sio
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm

### Preparing matplotlib to genrate high-quality images for the paper

In [2]:
%matplotlib inline
mpl.rcParams['figure.dpi']= 600

In [3]:
C  = 10
K = 6
L = C * K
width = 28
height = 28
reg = 0.1

### Load Training and Testing Data and Normalize them

In [4]:
XTrain = sio.loadmat ('../../datasets/mnist/MnistTrainX')['MnistTrainX']
XTrain = XTrain / 255
yTrain = sio.loadmat ('../../datasets/mnist/MnistTrainY')['MnistTrainY']
XTest = sio.loadmat ('../../datasets/mnist/MnistTestX')['MnistTestX']
XTest = XTest / 255
yTest = sio.loadmat ('../../datasets/mnist/MnistTestY')['MnistTestY']
N, dim = XTrain.shape

### Loading GMM data computed on Google Colab

In [5]:
clusters = np.zeros(N, dtype = int)
with open('GMM_MNIST_6.txt', 'r') as f:
    for i in range (N):
        s = f.readline()
        clusters[i] = int(s)

### Compute means and covariances

In [6]:
means = np.zeros([L, dim])
cov = np.zeros([L, dim, dim])
for i in range (L):
    idx = clusters == i
    means[i,:] = np.mean(XTrain[idx,:], axis = 0)
    cov[i,:,:] = np.cov (XTrain[idx,:].T)    

### Compute priors

In [9]:
Nck = np.bincount(clusters)
Pck = Nck / N

In [10]:
logProb = np.zeros ([N,L])
for i in tqdm(range (L)):
    eigvals, eigvecs = np.linalg.eig(cov[i,:,:])
    eigvals = eigvals + reg
    sigmaHalfInv = eigvecs @ np.diag(np.sqrt(1/eigvals)) @ eigvecs.T
    X = (XTrain - means[i,:]) @ sigmaHalfInv
    X = X ** 2
    d2 = np.sum(X, axis = 1)
    logProb[:,i] = np.log(Pck[i]) - 1 / 2 * d2
prediction = np.argmax (logProb, axis = 1) // K
score = np.sum(prediction == yTrain.squeeze())/ N
print ('Training Accuracy = ', score)


  # Remove the CWD from sys.path while we load stuff.

  2%|█▍                                                                                 | 1/60 [00:11<11:28, 11.67s/it]
  3%|██▊                                                                                | 2/60 [00:23<11:24, 11.80s/it]
  5%|████▏                                                                              | 3/60 [00:35<11:09, 11.75s/it]
  7%|█████▌                                                                             | 4/60 [00:47<11:08, 11.94s/it]
  8%|██████▉                                                                            | 5/60 [00:58<10:37, 11.60s/it]
 10%|████████▎                                                                          | 6/60 [01:10<10:32, 11.71s/it]
 12%|█████████▋                                                                         | 7/60 [01:23<10:33, 11.96s/it]
 13%|███████████                                                                        | 8/60 [01:34<10

Training Accuracy =  0.9815666666666667


In [11]:
NTest = XTest.shape[0]
logProb = np.zeros ([NTest,L])
for i in tqdm(range (L)):
    eigvals, eigvecs = np.linalg.eig(cov[i,:,:])
    eigvals = eigvals + reg
    sigmaHalfInv = eigvecs @ np.diag(np.sqrt(1/eigvals)) @ eigvecs.T
    X = (XTest - means[i,:]) @ sigmaHalfInv
    X = X ** 2
    d2 = np.sum(X, axis = 1)
    logProb[:,i] = np.log(Pck[i]) - 1 / 2 * d2
prediction = np.argmax (logProb, axis = 1) // K
score = np.sum(prediction == yTest.squeeze())/ NTest
print ('Testing Accuracy = ', score)


  # This is added back by InteractiveShellApp.init_path()

  2%|█▍                                                                                 | 1/60 [00:03<03:14,  3.29s/it]
  3%|██▊                                                                                | 2/60 [00:06<03:06,  3.22s/it]
  5%|████▏                                                                              | 3/60 [00:09<03:02,  3.20s/it]
  7%|█████▌                                                                             | 4/60 [00:12<02:57,  3.17s/it]
  8%|██████▉                                                                            | 5/60 [00:15<02:43,  2.98s/it]
 10%|████████▎                                                                          | 6/60 [00:18<02:43,  3.04s/it]
 12%|█████████▋                                                                         | 7/60 [00:21<02:36,  2.96s/it]
 13%|███████████                                                                        | 8/60 [00:2

Testing Accuracy =  0.9701


### <font color='red'> Author: Sayed Kamaledin Ghiasi-Shrirazi <a href="http://profsite.um.ac.ir/~k.ghiasi">(http://profsite.um.ac.ir/~k.ghiasi)</a> </font>