Commit

Add GMMRegression scikit-learn RegressorMixin

mralbu committed Apr 3, 2021
1 parent d665759 commit 68cf2f1

Showing 3 changed files with 107 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gmr/__init__.py

@@ -23,7 +23,7 @@
__all__ = ["gmm", "mvn", "utils"]

from .mvn import MVN, plot_error_ellipse
-from .gmm import (GMM, plot_error_ellipses, kmeansplusplus_initialization,
+from .gmm import (GMM, GMMRegression, plot_error_ellipses, kmeansplusplus_initialization,
                  covariance_initialization)

__all__.extend(["MVN", "plot_error_ellipse", "GMM", "plot_error_ellipses",
78 changes: 78 additions & 0 deletions gmr/gmm.py

@@ -1,6 +1,7 @@
import numpy as np
from scipy.spatial.distance import cdist, pdist
from scipy.stats import chi2
+from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
from .utils import check_random_state
from .mvn import MVN

@@ -489,6 +490,83 @@ def extract_mvn(self, component_idx):
            covariance=self.covariances[component_idx], verbose=self.verbose,
            random_state=self.random_state)

class GMMRegression(MultiOutputMixin, RegressorMixin, BaseEstimator):
    """Scikit-learn compatible regressor (RegressorMixin) based on the GMM class.

    Parameters
    ----------
    n_components : int (> 0)
        Number of MVNs that compose the GMM.

    priors : array, shape (n_components,), optional
        Priors of MVNs.

    means : array, shape (n_components, n_features), optional
        Means of MVNs.

    covariances : array, shape (n_components, n_features, n_features), optional
        Covariances of MVNs.

    verbose : int, optional (default: 0)
        Verbosity level.

    random_state : int or RandomState, optional (default: global random state)
        If an integer is given, it fixes the seed. Defaults to the global numpy
        random number generator.

    R_diff : float
        Minimum allowed difference of responsibilities between successive
        EM iterations.

    n_iter : int
        Maximum number of iterations.

    init_params : str, optional (default: 'random')
        Parameter initialization strategy. If means and covariances are
        given in the constructor, this parameter will have no effect.
        'random' will sample initial means randomly from the dataset
        and set covariances to identity matrices. This is the
        computationally cheap solution.
        'kmeans++' will use k-means++ initialization for means and
        initialize covariances to diagonal matrices with variances
        set based on the average distances of samples in each dimension.
        This is computationally more expensive but often gives much
        better results.
    """

    def __init__(self, n_components, priors=None, means=None, covariances=None,
                 verbose=0, random_state=None, R_diff=1e-4, n_iter=500,
                 init_params="random"):
        self.n_components = n_components
        self.priors = priors
        self.means = means
        self.covariances = covariances
        self.verbose = verbose
        self.random_state = random_state
        self.R_diff = R_diff
        self.n_iter = n_iter
        self.init_params = init_params

    def fit(self, X, y):
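        """Fit a GMM to the joint distribution of inputs and targets.

        Parameters
        ----------
        X : array-like, shape (n_samples,) or (n_samples, n_features)
            Input data.

        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            Target values.

        Returns
        -------
        self : GMMRegression
            This object.
        """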
        self.gmm = GMM(self.n_components, priors=self.priors, means=self.means,
                       covariances=self.covariances, verbose=self.verbose,
                       random_state=self.random_state)

        if y.ndim > 2:
            raise ValueError("y must have at most two dimensions.")
        elif y.ndim == 1:
            y = np.expand_dims(y, 1)

        if X.ndim > 2:
            raise ValueError("X must have at most two dimensions.")
        elif X.ndim == 1:
            X = np.expand_dims(X, 1)

        # Column indices of the input features within the joint (X, y) data,
        # used for conditioning in predict.
        self._indices = np.arange(X.shape[1])

        # Train the GMM on the joint input-output space.
        self.gmm.from_samples(np.hstack((X, y)),
                              R_diff=self.R_diff, n_iter=self.n_iter,
                              init_params=self.init_params)
        return self

    def predict(self, X):
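        """Predict the mean of the conditional distribution p(y | X).

        Parameters
        ----------
        X : array-like, shape (n_samples,) or (n_samples, n_features)
            Input data.

        Returns
        -------
        y : array, shape (n_samples, n_targets)
            Predicted conditional means.
        """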
        if X.ndim > 2:
            raise ValueError("X must have at most two dimensions.")
        elif X.ndim == 1:
            X = np.expand_dims(X, 1)

        # Condition the learned joint GMM on X and return the conditional mean.
        return self.gmm.predict(self._indices, X)

def plot_error_ellipses(ax, gmm, colors=None, alpha=0.25, factors=np.linspace(0.25, 2.0, 8)):
"""Plot error ellipses of GMM components.
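
For orientation, a minimal usage sketch of the new estimator (not part of the
commit; the data and hyperparameters below are illustrative):

import numpy as np
from gmr import GMMRegression

# Noisy 1D regression problem (illustrative).
rng = np.random.RandomState(0)
X = np.linspace(0, 2, 200)[:, np.newaxis]
y = np.sin(2 * np.pi * X).ravel() + rng.randn(200) * 0.1

reg = GMMRegression(n_components=3, random_state=0)
reg.fit(X, y)            # learns a joint GMM over (X, y)
y_pred = reg.predict(X)  # conditional means, shape (200, 1)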
29 changes: 28 additions & 1 deletion gmr/tests/test_gmm.py

@@ -11,7 +11,7 @@
except ImportError:
    # Python 3
    from io import StringIO
-from gmr import GMM, MVN, plot_error_ellipses, kmeansplusplus_initialization, covariance_initialization
+from gmr import GMM, MVN, GMMRegression, plot_error_ellipses, kmeansplusplus_initialization, covariance_initialization
from test_mvn import AxisStub


@@ -251,6 +251,22 @@ def test_regression_with_2d_input():
    pred = gmm.predict(np.array([0, 1]), np.hstack((x, x[::-1])))
    mse = np.sum((y - pred) ** 2) / n_samples

    random_state = check_random_state(0)

    n_samples = 200
    x = np.linspace(0, 2, n_samples)[:, np.newaxis]
    y1 = 3 * x[:n_samples // 2] + 1
    y2 = -3 * x[n_samples // 2:] + 7
    noise = random_state.randn(n_samples, 1) * 0.01
    y = np.vstack((y1, y2)) + noise
    samples = np.hstack((x, x[::-1], y))

    gmm = GMMRegression(n_components=2, random_state=random_state)
    gmm.fit(np.hstack((x, x[::-1])), y)

    pred = gmm.predict(np.hstack((x, x[::-1])))
    mse = np.sum((y - pred) ** 2) / n_samples


def test_regression_without_noise():
"""Test regression without noise."""
Expand All @@ -273,6 +289,17 @@ def test_regression_without_noise():
mse = np.sum((y - pred) ** 2) / n_samples
assert_less(mse, 0.01)

    random_state = check_random_state(0)

    gmm = GMMRegression(n_components=2, random_state=random_state)
    gmm.fit(x, y)
    assert_array_almost_equal(gmm.gmm.priors, 0.5 * np.ones(2), decimal=2)
    assert_array_almost_equal(gmm.gmm.means[0], np.array([1.5, 2.5]), decimal=2)
    assert_array_almost_equal(gmm.gmm.means[1], np.array([0.5, 2.5]), decimal=1)

    pred = gmm.predict(x)
    mse = np.sum((y - pred) ** 2) / n_samples
    assert_less(mse, 0.01)

def test_plot():
"""Test plot of GMM."""
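
Because GMMRegression derives from BaseEstimator and RegressorMixin, it should
also compose with standard scikit-learn utilities. A minimal sketch, assuming
scikit-learn is installed (again illustrative, not part of the commit):

import numpy as np
from sklearn.model_selection import cross_val_score
from gmr import GMMRegression

rng = np.random.RandomState(0)
X = rng.uniform(0, 2, size=(200, 1))
y = np.sin(2 * np.pi * X).ravel() + rng.randn(200) * 0.1

reg = GMMRegression(n_components=3, random_state=0)
# BaseEstimator supplies get_params/set_params (needed for cloning);
# RegressorMixin supplies the default R^2 score used below.
scores = cross_val_score(reg, X, y, cv=5)
print(scores.mean())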
