In [1]:
import numpy as np
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [None]:
def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d

In [None]:
np.random.seed(0)
cov1 = np.diag((1, 2))
N1 = 500
N2 = 300
N = N1 + N2
x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
m = np.array(((1, 1), (1, 3)))
x1 = x1.dot(m)
x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
x = np.vstack((x1, x2))
y = np.array([0]*N1 + [1]*N2)
n_components = 3
colors = '#A0FFA0', '#2090E0', '#FF8080'
cm = mpl.colors.ListedColormap(colors)
x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
x1_min, x1_max = expand(x1_min, x1_max)
x2_min, x2_max = expand(x2_min, x2_max)
x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
grid_test = np.stack((x1.flat, x2.flat), axis=1)


In [37]:
plt.figure(figsize=(8, 6), facecolor='w')
plt.suptitle('GMM/DPGMM compare', fontsize=28)
ax = plt.subplot(211)
gmm = GaussianMixture(n_components=n_components,
                      covariance_type='full',
                      random_state=0)
gmm.fit(x)
centers = gmm.means_
covs = gmm.covariances_

print('GMM means : \n', centers)
print('GMM covs : \n', covs)

y_hat = gmm.predict(x)
grid_hat = gmm.predict(grid_test)
grid_hat = grid_hat.reshape(x1.shape)

clrs = list('rgbmy')
colors = '#A0FFA0', '#2090E0', '#FF8080'
cm = mpl.colors.ListedColormap(colors)
plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o',edgecolors='#202020')

for i, cc in enumerate(zip(centers, covs)):
    centers, cov = cc
    value, vector = sp.linalg.eigh(cov)
    width, height = value[0], value[1]
    print('width:', width, ' height: ', height)
    print('value : ', value)
    print('vector : ', vector)
    v = vector[0] / sp.linalg.norm(vector[0])
    angle = 180 * np.arctan(v[1] / v[0]) / np.pi
    e = Ellipse(xy=centers,
                width=width,
                height=height,
                angle=angle,
                color=clrs[i],
                alpha=0.5,
                clip_box=ax.bbox)
    ax.add_artist(e)
ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
plt.xlim((x1_min, x1_max))
plt.ylim((x2_min, x2_max))
plt.title(u'GMM', fontsize=20)

############################## DPGMM
dpgmm = BayesianGaussianMixture(
    n_components=n_components,
    covariance_type='full',
    max_iter=1000,
    n_init=5,
    weight_concentration_prior_type='dirichlet_process',
    weight_concentration_prior=10)

dpgmm.fit(x)
centers = dpgmm.means_
covs = dpgmm.covariances_
print('DPGMM means: \n', centers)
print('DPGMM covs: \n', covs)

y_hat = dpgmm.predict(x)

ax = plt.subplot(212)
grid_hat = dpgmm.predict(grid_test)
grid_hat = grid_hat.reshape(x1.shape)

plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
plt.scatter(x[:, 0], x[:, 1], c=y, s=30, cmap=cm, marker='o',edgecolors='#202020')

for i, cc in enumerate(zip(centers, covs)):
    if i not in y_hat:
        continue
    center, cov = cc
    value, vector = sp.linalg.eigh(cov)
    width, height = value[0], value[1]
    v = vector[0] / sp.linalg.norm(vector[0])
    angle = 180 * np.arctan(v[1] / v[0]) / np.pi
    print('width:', width, ' height: ', height)
    e = Ellipse(xy=center,
                width=width,
                height=height,
                angle=angle,
                alpha=0.5,
                color='m',
                clip_box=ax.bbox)

    ax.add_artist(e)

plt.xlim((x1_min, x1_max))
plt.ylim((x2_min, x2_max))
plt.title('DPGMM', fontsize=15)
plt.tight_layout()
plt.subplots_adjust(top=0.9)

plt.show()

GMM means : 
 [[ -0.98543679  10.0756839 ]
 [  6.0239399   11.61448122]
 [  3.77430768   5.86579463]]
GMM covs : 
 [[[  0.89079177  -0.02572518]
  [ -0.02572518   1.95106592]]

 [[  1.6667472    3.58655076]
  [  3.58655076  10.40673433]]

 [[  1.5383593    3.21210121]
  [  3.21210121   9.04107582]]]
width: 0.890167971365  height:  1.95168972127
value :  [ 0.89016797  1.95168972]
vector :  [[-0.99970613 -0.02424137]
 [-0.02424137  0.99970613]]
width: 0.383406337028  height:  11.6900751881
value :  [  0.38340634  11.69007519]
vector :  [[-0.9415397   0.33690207]
 [ 0.33690207  0.9415397 ]]
width: 0.351065620111  height:  10.2283695008
value :  [  0.35106562  10.2283695 ]
vector :  [[-0.93797429  0.34670481]
 [ 0.34670481  0.93797429]]
DPGMM means: 
 [[  4.87811644   8.69857678]
 [ -0.97330325  10.07291035]
 [  2.67200409   9.19129532]]
DPGMM covs: 
 [[[  2.88512056   6.60504981]
  [  6.60504981  17.92828469]]

 [[  0.96312252  -0.02851517]
  [ -0.02851517   1.98150469]]

 [[  5.10545816 