<a href="https://colab.research.google.com/github/camlab-bioml/2021_IMC_Jett/blob/main/GMM_class_torch_EM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
class GMM(object):
  
  def __init__(self, X, k=3):
    self.X = np.asarray(X)
    
    # dimension
    self.n, self.p = X.shape
    #print(n)
    #print(p)
    # number of mixtures
    self.k = k

  def _init(self):
    # init mixture means/sigmas
    self.nj = torch.zeros(self.k)
    self.wj = torch.ones(self.k) / self.k
    self.muj = torch.ones(self.k, self.p)
    self.coj = torch.zeros(self.k, self.p, self.p)
    self.rij = torch.ones(self.k, self.n)
    self.mvnj = torch.zeros(self.k, self.n) 
    self.pred = torch.zeros(self.n)
    #muj, coj, rij, mvnj

  def train(self, itermax=100):
    
    self._init() 
    self.gmmKmeansInitial()
    
    llv = [0.0]
    for iter in range(itermax):
      lli = self.estep()
      self.mstep()
      
      print('Iteration', iter + 1, 'Likelihood: ', lli)
        
      if abs(llv[-1] - lli) < 1e-4:
        break 
      
      llv.append(lli)
    return llv[1:]

  def estep(self):
    #mvnj = torch.zeros(k, X.shape[0]) 
    for j in range(self.k):
      self.mvnj[j] = (self.wj[j] * stats.multivariate_normal(self.muj[j], self.coj[j]).pdf(self.X))
      
    bot = torch.sum(self.mvnj, 0)
    self.rij = self.mvnj / bot
    logl = torch.sum(torch.log(bot))
      
    self.pred = torch.max(self.rij, 0).indices
    
    return logl

  def mstep(self):

    d = torch.from_numpy(np.array(self.X))
    
    self.nj = torch.sum(self.rij, 1)
    self.wj = self.nj / sum(self.nj) ## same as nj / n
  
    for j in range(self.k):
      self.muj[j] = torch.sum(d.T * self.rij[j], 1) / self.nj[j]

    for j in range(self.k):
      ker = (d - self.muj[j]).T * self.rij[j]
      self.coj[j] = torch.matmul(ker, ker.T)/self.nj[j]

  def gmmKmeansInitial(self):
    
    kmeans = KMeans(n_clusters=self.k, random_state=0).fit(self.X)
    self.muj = torch.tensor(kmeans.cluster_centers_)
  
    for j in range(self.k):
      indx = kmeans.labels_ == j
      d = self.X[indx]
      self.nj[j] = sum(indx)
      #muj.append(np.array(d.mean(axis=0)))
      self.coj[j] = torch.from_numpy(np.cov(d.T))
    
    self.wj = self.nj / torch.sum(self.nj)

In [2]:
import torch
import numpy as np
import scipy.stats as stats ## for mvn
from sklearn.cluster import KMeans
import seaborn as sns

In [5]:
## try iris data
k = 3
iris = sns.load_dataset("iris")
X = iris.iloc[:,:4]
y = iris.iloc[:,4]

gmm = GMM(X)
lls = gmm.train()

Iteration 1 Likelihood:  tensor(-197.3202)
Iteration 2 Likelihood:  tensor(-193.0827)
Iteration 3 Likelihood:  tensor(-191.4538)
Iteration 4 Likelihood:  tensor(-189.7725)
Iteration 5 Likelihood:  tensor(-188.4172)
Iteration 6 Likelihood:  tensor(-186.6134)
Iteration 7 Likelihood:  tensor(-184.5180)
Iteration 8 Likelihood:  tensor(-183.0643)
Iteration 9 Likelihood:  tensor(-182.3152)
Iteration 10 Likelihood:  tensor(-181.4987)
Iteration 11 Likelihood:  tensor(-180.6420)
Iteration 12 Likelihood:  tensor(-180.4911)
Iteration 13 Likelihood:  tensor(-180.4788)
Iteration 14 Likelihood:  tensor(-180.4806)
Iteration 15 Likelihood:  tensor(-180.4832)
Iteration 16 Likelihood:  tensor(-180.4847)
Iteration 17 Likelihood:  tensor(-180.4854)
Iteration 18 Likelihood:  tensor(-180.4857)
Iteration 19 Likelihood:  tensor(-180.4858)
Iteration 20 Likelihood:  tensor(-180.4859)


In [6]:
gmm.pred

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2,
        0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])

In [8]:
gmm.rij.shape

torch.Size([3, 150])

In [8]:
from sklearn.metrics import confusion_matrix
targets = torch.cat((torch.tensor([1]).repeat(50), torch.tensor([2]).repeat(50), torch.tensor([0]).repeat(50)))
confusion_matrix(targets, gmm.pred)

array([[50,  0,  0],
       [ 0, 50,  0],
       [ 5,  0, 45]])

In [87]:
import sklearn.datasets

## generate data
n = 1000; p = 15; k = 5 #np.random.randint(2, 7)
#XX = torch.tensor(sklearn.datasets.make_spd_matrix(n, p))

In [88]:
## generate random data
ww = torch.rand(k); wj = ww/sum(ww)
muj = torch.zeros((k, p))
coj = torch.zeros((k, p, p))

In [90]:
for j in range(k):
  ## means
  muj[j] = 10 * torch.randn(p)

  ## covariance matrix
  mat = torch.rand(p, p)
  mat = torch.mm(mat, mat.t())
  coj[j].add_(torch.eye(p))

In [91]:
samples = np.zeros((n, p+1))
u = np.random.uniform(size=n)

In [92]:
for i in range(n):
  for j in range(k):
    if u[i] < sum(wj[:(j+1)]):
      samples[i] = np.append(np.random.multivariate_normal(muj[j], coj[j], 1), [j])
      break

In [93]:
tgmm = GMM(samples[:,:p], k)

1000
15


In [94]:
tlls = tgmm.train()

Iteration 1 Likelihood:  tensor(-27722.5117)
Iteration 2 Likelihood:  tensor(-27722.3984)
Iteration 3 Likelihood:  tensor(-27722.3984)


In [95]:
tlls

[tensor(-27722.5117), tensor(-27722.3984)]

In [96]:
np.unique(tgmm.pred)

array([0, 1, 2, 3, 4])