Commit

added example
ldeecke committed Jul 30, 2018
1 parent cdffb99 commit 351874d
Showing 3 changed files with 61 additions and 2 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -1,5 +1,11 @@
 This repository contains an implementation of a simple **Gaussian mixture model** (GMM) fitted with Expectation-Maximization in [pytorch](http://www.pytorch.org). The interface closely follows that of [sklearn](http://scikit-learn.org).
 
-A new model is instantiated by calling `m = gmm.GaussianMixture(n_components, d)`. Once instantiated, the model expects tensors in a flattened shape `(n, d)`. Predicting class memberships is straightforward, first fit the model via `m.fit(data)`, then predict with `m.predict(data)`.
+![Example of a fit via a Gaussian Mixture model.](example.png)
+
+---
+
+A new model is instantiated by calling `gmm.GaussianMixture(..)` and providing as arguments the number of components, as well as the tensor dimension. Note that once instantiated, the model expects tensors in a flattened shape `(n, d)`.
 
-Some sanity checks can be executed by calling `python test.py`. To handle data on GPUs, ensure that `m.cuda()` is called.
+The first step would usually be to fit the model via `model.fit(data)`, then predict with `model.predict(data)`. To reproduce the above figure, just run the provided `example.py`.
+
+Some sanity checks can be executed by calling `python test.py`. To fit data on GPUs, ensure that you first call `model.cuda()`.
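Taken together, the new README describes an instantiate-fit-predict workflow. Below is a minimal sketch of that workflow; the random `data` tensor is only a stand-in for any flattened `(n, d)` input, and the commented-out `model.cuda()` call is only needed when fitting data on the GPU.

```python
import torch

from gmm import GaussianMixture

n, d = 400, 2
data = torch.Tensor(n, d).normal_()   # any tensor of shape (n, d) works here

model = GaussianMixture(2, d)         # 2 mixture components in d dimensions
# model.cuda()                        # call this first when fitting data on the GPU
model.fit(data)                       # run Expectation-Maximization
labels = model.predict(data)          # one component assignment per row of `data`
```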
Binary file added example.png
53 changes: 53 additions & 0 deletions example.py
@@ -0,0 +1,53 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="white", font="Arial")
colors = sns.color_palette("Paired", n_colors=12).as_hex()

import numpy as np
import torch

from gmm import GaussianMixture
from math import sqrt


def main():
    n, d = 400, 2

    # generate some data points from a standard Gaussian ..
    data = torch.Tensor(n, d).normal_()
    # .. then shift and scale each half so the two halves come from two different Gaussians
    data[:n//2] -= 1
    data[:n//2] *= sqrt(3)
    data[n//2:] += 1
    data[n//2:] *= sqrt(2)

    # a Gaussian Mixture Model is instantiated and ..
    n_components = 2
    model = GaussianMixture(n_components, d)
    model.fit(data)
    # .. used to predict a component assignment for every data point
    y = model.predict(data)

    fig, ax = plt.subplots(1, 1, figsize=(1.61803398875*4, 4))
    ax.set_facecolor('#bbbbbb')
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    # plot the locations of all data points ..
    ax.scatter(*data[:n//2].data.t(), color="#000000", s=3, alpha=.75, label="Ground-truth 1")
    ax.scatter(*data[n//2:].data.t(), color="#ffffff", s=3, alpha=.75, label="Ground-truth 2")

    # .. and circle them according to their classification
    ax.scatter(*data[np.where(y == 0)].data.t(), zorder=0, color="#dbe9ff", alpha=.6, edgecolors=colors[1], label="Predicted 1")
    ax.scatter(*data[np.where(y == 1)].data.t(), zorder=0, color="#ffdbdb", alpha=.6, edgecolors=colors[5], label="Predicted 2")

    ax.legend(loc="best")
    plt.tight_layout()
    # save the figure referenced in the README
    plt.savefig("example.png")


if __name__ == "__main__":
    main()
