In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scio

In [2]:
# Load the digit dataset. X is used below to fit the class-conditional Gaussians,
# T is the set of samples that gets classified.
# Assumes both are 3-D arrays (features, samples per class, classes) — inferred from
# the X[:, :, i] / T[:, :, k] indexing below; confirm against the .mat file.
DATA_PATH = 'digit.mat'
data = scio.loadmat(DATA_PATH)
X = data['X']
T = data['T']
d, m, c = T.shape
# Fitting and test arrays must agree on feature dimension and class count; otherwise
# the per-class estimates would silently misalign with (or fail on) the test set.
assert X.ndim == 3 and X.shape[0] == d and X.shape[2] == c, (X.shape, T.shape)

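\begin{equation}
\hat{\mu}_{\mathrm{ML}}=\frac{1}{n} \sum_{i=1}^{n} x_{i}, \quad \hat{\Sigma}_{\mathrm{ML}}=\frac{1}{n} \sum_{i=1}^{n}\left(x_{i}-\hat{\mu}_{\mathrm{ML}}\right)\left(x_{i}-\hat{\mu}_{\mathrm{ML}}\right)^{\top}
\end{equation}

These estimates are computed separately for each class $y$ from its $n$ training samples. The classifier uses one shared covariance $\hat{\Sigma}$, which is the average of the $c$ per-class covariance estimates.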

In [3]:
# Maximum-likelihood estimates of the class means and of one covariance matrix shared
# by all classes (pooled as the average of the per-class covariance estimates).
mu = np.zeros((d, c))   # column i: mean of class i
S = np.zeros((d, d))    # shared covariance estimate
for i in range(c):
    mu[:, i] = np.mean(X[:, :, i], axis=1)
    # bias=True normalises by 1/n, matching the ML formula; np.cov's default 1/(n-1)
    # is the unbiased estimator. Every class has the same n, so this only rescales S
    # by a positive constant and leaves the argmax classification unchanged.
    S += np.cov(X[:, :, i], bias=True) / c
# h = S^{-1} mu (one column per class). Solving the linear system is more accurate
# than forming the explicit inverse and multiplying.
h = np.linalg.solve(S, mu)

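\begin{equation}
\log \hat{p}(y \mid \boldsymbol{x})=\boldsymbol{x}^{\top} \hat{\boldsymbol{\Sigma}}^{-1} \hat{\boldsymbol{\mu}}_{y}-\frac{1}{2} \hat{\boldsymbol{\mu}}_{y}^{\top} \hat{\boldsymbol{\Sigma}}^{-1} \hat{\boldsymbol{\mu}}_{y}+\log \frac{n_{y}}{n}+C^{\prime}
\end{equation}

Here $C'$ does not depend on $y$. Every class has the same number of training samples, so $\log(n_y/n)$ is also identical for all classes. Both terms can therefore be dropped without changing $\arg\max_y$, which is why the code computes only the first two terms.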

In [4]:
# Linear discriminant scores, i.e. log p(y|x) up to an additive constant (not a
# probability):
#   g_y(x) = x^T S^{-1} mu_y - 1/2 mu_y^T S^{-1} mu_y
# The log-prior log(n_y/n) is omitted: every class in X has the same number of
# samples, so that term is identical for all y and cannot change the argmax.
# p[y, j, k] = score of class y for the j-th test sample of class k.
half_quad = 0.5 * np.sum(mu * h, axis=0)   # (c,): 1/2 mu_y^T S^{-1} mu_y, loop-invariant
p = np.zeros((c, m, c))
for k in range(c):
    # half_quad[:, None] broadcasts the per-class offset across all m samples.
    p[:, :, k] = np.dot(h.T, T[:, :, k]) - half_quad[:, None]

\begin{equation}
\operatorname{sign}(x)=\left\{\begin{array}{ll}
1, & x>0 \\
0, & x=0 \\
-1, & x<0
\end{array}\right.
\end{equation}

\begin{equation}
\hat{y}=\underset{y}{\arg \max } p(y \mid x)
\end{equation}

In [5]:
# MAP decision rule: for each test sample choose the class with the largest score.
# P[j, k] is the predicted label of the j-th test sample of class k.
P = np.argmax(p, axis=0)

In [6]:
# Confusion matrix: C[i, j] counts test samples of true class i predicted as class j.
C = np.zeros((c, c))
for true_cls, preds in enumerate(P.T):
    # minlength=c keeps a slot for classes that were never predicted.
    C[true_cls] = np.bincount(preds, minlength=c)

In [7]:
# Display the confusion matrix (rows: true class, columns: predicted class).
C

array([[199.,   0.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
       [  0., 169.,   8.,   8.,   1.,   2.,   4.,   8.,   0.,   0.],
       [  0.,   0., 182.,   1.,   5.,   0.,   2.,   8.,   1.,   1.],
       [  2.,   2.,   0., 182.,   0.,   1.,   0.,   3.,  10.,   0.],
       [  0.,   0.,  21.,   4., 162.,   1.,   0.,   4.,   4.,   4.],
       [  1.,   2.,   0.,   1.,   5., 185.,   0.,   3.,   0.,   3.],
       [  2.,   0.,   1.,   5.,   1.,   0., 181.,   0.,   9.,   1.],
       [  0.,   1.,  16.,   6.,   6.,   0.,   1., 164.,   3.,   3.],
       [  1.,   0.,   0.,   8.,   0.,   0.,   7.,   2., 182.,   0.],
       [  0.,   0.,   3.,   0.,   0.,   4.,   0.,   1.,   0., 192.]])