Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
61 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
This repository contains an implementation of a simple **Gaussian mixture model** (GMM) fitted with Expectation-Maximization in [pytorch](http://www.pytorch.org). The interface closely follows that of [sklearn](http://scikit-learn.org). | ||
|
||
A new model is instantiated by calling `m = gmm.GaussianMixture(n_components, d)`. Once instantiated, the model expects tensors in a flattened shape `(n, d)`. Predicting class memberships is straightforward, first fit the model via `m.fit(data)`, then predict with `m.predict(data)`. | ||
![Example of a fit via a Gaussian Mixture model.](example.png) | ||
|
||
Some sanity checks can be executed by calling `python test.py`. To handle data on GPUs, ensure that `m.cuda()` is called. | ||
--- | ||
|
||
A new model is instantiated by calling `gmm.GaussianMixture(..)` and providing as arguments the number of components, as well as the tensor dimension. Note that once instantiated, the model expects tensors in a flattened shape `(n, d)`. | ||
|
||
The first step would usually be to fit the model via `model.fit(data)`, then predict with `model.predict(data)`. To reproduce the above figure, just run the provided `example.py`. | ||
|
||
Some sanity checks can be executed by calling `python test.py`. To fit data on GPUs, ensure that you first call `model.cuda()`. |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
sns.set(style="white", font="Arial") | ||
colors = sns.color_palette("Paired", n_colors=12).as_hex() | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from gmm import GaussianMixture | ||
from math import sqrt | ||
|
||
|
||
def main(): | ||
n, d = 400, 2 | ||
|
||
# generate some data points .. | ||
data = torch.Tensor(n, d).normal_() | ||
# .. as well as a random partition that .. | ||
ids = np.random.choice(n, n//2, replace=False) | ||
# .. is permuted to come from a non-standard Gaussian N(7, 16) | ||
data[:n//2] -= 1 | ||
data[:n//2] *= sqrt(3) | ||
data[n//2:] += 1 | ||
data[n//2:] *= sqrt(2) | ||
|
||
# a Gaussian Mixture Model is instantiated and .. | ||
n_components = 2 | ||
model = GaussianMixture(n_components, d) | ||
model.fit(data) | ||
# .. used to predict the data points that where shifted | ||
y = model.predict(data) | ||
c = np.isin(np.where(y > 0), ids) | ||
|
||
fig, ax = plt.subplots(1, 1, figsize=(1.61803398875*4, 4)) | ||
ax.set_facecolor('#bbbbbb') | ||
ax.set_xlabel("Dimension 1") | ||
ax.set_ylabel("Dimension 2") | ||
# plot the locations of all data points .. | ||
ax.scatter(*data[:n//2].data.t(), color="#000000", s=3, alpha=.75, label="Ground-truth 1") | ||
ax.scatter(*data[n//2:].data.t(), color="#ffffff", s=3, alpha=.75, label="Ground-truth 2") | ||
|
||
# .. and circle them according to their classification | ||
ax.scatter(*data[np.where(y == 0)].data.t(), zorder=0, color="#dbe9ff", alpha=.6, edgecolors=colors[1], label="Predicted 1") | ||
ax.scatter(*data[np.where(y == 1)].data.t(), zorder=0, color="#ffdbdb", alpha=.6, edgecolors=colors[5], label="Predicted 2") | ||
|
||
ax.legend(loc="best") | ||
plt.tight_layout() | ||
plt.savefig("example.pdf") | ||
|
||
if __name__ == "__main__": | ||
main() |