population: move duplicate properties to mixin
rpauszek committed Jan 24, 2024 · 1 parent 189d513 · commit 0762c09
Showing 5 changed files with 81 additions and 88 deletions.
18 changes: 8 additions & 10 deletions lumicks/pylake/population/detail/hmm.py
@@ -12,6 +12,7 @@

import numpy as np

from .mixin import LatentVariableModel
from ..mixture import GaussianMixtureModel
from .fit_info import PopulationFitInfo
from ...channel import Slice, Continuous
@@ -23,28 +24,25 @@ def normalize_rows(matrix):


@dataclass(frozen=True)
class ClassicHmm:
class ClassicHmm(LatentVariableModel):
"""Model parameters for classic Hidden Markov Model.
Parameters
----------
K : int
number of states
pi : np.ndarray
initial state probabilities, shape [K, ]
A : np.ndarray
state transition probability matrix, shape [K, K]
mu : np.ndarray
state means, shape [K, ]
tau : np.ndarray
state precision (1 / variance), shape [K, ]
pi : np.ndarray
initial state probabilities, shape [K, ]
A : np.ndarray
state transition probability matrix, shape [K, K]
"""

K: int
pi: np.ndarray
A: np.ndarray
mu: np.ndarray
tau: np.ndarray

@classmethod
def guess(cls, data, n_states, gmm=None):
@@ -76,7 +74,7 @@ def guess(cls, data, n_states, gmm=None):
)
pi = np.ones(n_states) / n_states

return cls(n_states, pi, A, gmm.means, 1 / gmm.variances)
return cls(n_states, gmm.means, 1 / gmm.variances, pi, A)

def state_log_likelihood(self, x):
"""Calculate the state likelihood of the observation data `x`. Work in log space to avoid
@@ -103,7 +101,7 @@ def update(self, data, gamma, xi):
x_bar = np.sum(gamma * col(data), axis=0) / gamma.sum(axis=0) # Eq 53
variance = np.sum(gamma * (col(data) - row(x_bar)) ** 2, axis=0) / gamma.sum(axis=0) # Eq 54

return ClassicHmm(self.K, pi, A, x_bar, 1 / variance)
return ClassicHmm(self.K, x_bar, 1 / variance, pi, A)


def baum_welch(data, model, tol, max_iter):
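The call-site reordering above (`cls(n_states, gmm.means, 1 / gmm.variances, pi, A)` where the old code passed `pi` and `A` right after `n_states`) falls out of dataclass inheritance: fields of the `LatentVariableModel` base precede the subclass's own fields in the generated `__init__`. A minimal sketch of just that mechanic, standard library only, reusing the field names from this hunk:

from dataclasses import dataclass, fields

import numpy as np


@dataclass(frozen=True)
class LatentVariableModel:
    K: int
    mu: np.ndarray
    tau: np.ndarray


@dataclass(frozen=True)
class ClassicHmm(LatentVariableModel):
    pi: np.ndarray
    A: np.ndarray


# Base-class fields come first in the generated __init__, so positional
# construction must now pass (K, mu, tau, pi, A) in that order:
print([f.name for f in fields(ClassicHmm)])  # ['K', 'mu', 'tau', 'pi', 'A']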
30 changes: 30 additions & 0 deletions lumicks/pylake/population/detail/mixin.py
@@ -1,10 +1,33 @@
from dataclasses import dataclass

import numpy as np

from .fit_info import PopulationFitInfo
from ...channel import Slice
from ..dwelltime import _dwellcounts_from_statepath


class TimeSeriesMixin:
@property
def fit_info(self) -> PopulationFitInfo:
"""Information about the model training exit conditions."""
return self._fit_info

@property
def means(self) -> np.ndarray:
"""Model state means."""
return self._model.mu

@property
def variances(self) -> np.ndarray:
"""Model state variances."""
return 1 / self._model.tau

@property
def std(self) -> np.ndarray:
"""Model state standard deviations."""
return np.sqrt(self.variances)

def extract_dwell_times(self, trace, *, exclude_ambiguous_dwells=True):
"""Calculate lists of dwelltimes for each state in a time-ordered state path array.
@@ -151,3 +174,10 @@ def plot_path(self, trace, *, trace_kwargs=None, path_kwargs=None):
path_kwargs = {"c": "tab:blue", "lw": 2, **(path_kwargs or {})}
emission_path = self.emission_path(trace)
emission_path.plot(**path_kwargs)


@dataclass(frozen=True)
class LatentVariableModel:
K: int
mu: np.ndarray
tau: np.ndarray
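The mixin reads everything through two attributes the host class must provide: `_model` (something with `mu` and `tau` attributes, such as a `LatentVariableModel`) and `_fit_info`. A hypothetical toy host, purely to illustrate that contract (`ToyModel` is not part of pylake):

import numpy as np

from lumicks.pylake.population.detail.mixin import LatentVariableModel, TimeSeriesMixin


class ToyModel(TimeSeriesMixin):
    def __init__(self):
        # The mixin's properties read state parameters from `self._model` ...
        self._model = LatentVariableModel(
            K=2, mu=np.array([0.0, 1.0]), tau=np.array([4.0, 4.0])
        )
        # ... and training diagnostics from `self._fit_info`.
        self._fit_info = None  # a PopulationFitInfo in the real classes


toy = ToyModel()
print(toy.means)      # [0. 1.]
print(toy.variances)  # [0.25 0.25]  (variance = 1 / tau)
print(toy.std)        # [0.5 0.5]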
21 changes: 0 additions & 21 deletions lumicks/pylake/population/hmm.py
@@ -4,7 +4,6 @@
from ..channel import Slice
from .detail.hmm import ClassicHmm, viterbi, baum_welch
from .detail.mixin import TimeSeriesMixin
from .detail.fit_info import PopulationFitInfo
from .detail.validators import col


@@ -66,11 +65,6 @@ def __init__(self, data, n_states, *, tol=1e-3, max_iter=250, initial_guess=None

self._model, self._fit_info = baum_welch(data, initial_guess, tol=tol, max_iter=max_iter)

@property
def fit_info(self) -> PopulationFitInfo:
"""Information about the model training exit conditions."""
return self._fit_info

@property
def initial_state_probability(self) -> np.ndarray:
"""Model initial state probability."""
@@ -85,21 +79,6 @@ def transition_matrix(self) -> np.ndarray:
"""
return self._model.A

@property
def means(self) -> np.ndarray:
"""Model state means."""
return self._model.mu

@property
def variances(self) -> np.ndarray:
"""Model state variances."""
return 1 / self._model.tau

@property
def std(self) -> np.ndarray:
"""Model state standard deviations."""
return np.sqrt(self.variances)

def _calculate_state_path(self, trace):
return viterbi(trace.data, self._model)

90 changes: 38 additions & 52 deletions lumicks/pylake/population/mixture.py
@@ -1,26 +1,31 @@
from dataclasses import dataclass

import numpy as np
import scipy
from deprecated.sphinx import deprecated

from ..channel import Slice
from .detail.mixin import TimeSeriesMixin
from .detail.mixin import TimeSeriesMixin, LatentVariableModel
from .detail.fit_info import PopulationFitInfo


def as_sorted(fcn):
"""Decorator to return results sorted according to mapping array.
To be used as a method decorator in a class that supplies an index
mapping array via the `._map` attribute.
"""

def wrapper(self, *args, **kwargs) -> np.ndarray:
result = fcn(self, *args, **kwargs)
return result[self._map]

wrapper.__doc__ = fcn.__doc__

return wrapper

@dataclass(frozen=True)
class ClassicGmm(LatentVariableModel):
"""Model parameters for classic Gaussian Mixture Model.
Parameters
----------
K : int
number of states
mu : np.ndarray
state means, shape [K, ]
tau : np.ndarray
state precision (1 / variance), shape [K, ]
weights : np.ndarray
state fractional weights
"""

weights: np.ndarray


class GaussianMixtureModel(TimeSeriesMixin):
@@ -62,22 +67,33 @@ def __init__(self, data, n_states, init_method="kmeans", n_init=1, tol=1e-3, max
data = data.data

self.n_states = n_states
self._model = GaussianMixture(
model = GaussianMixture(
n_components=n_states,
init_params=init_method,
n_init=n_init,
tol=tol,
max_iter=max_iter,
)
data = np.reshape(data, (-1, 1))
self._model.fit(data)
model.fit(data)

# todo: remove when exit_flag is removed
self._deprecated_lower_bound = model.lower_bound_

idx = np.argsort(model.means_.squeeze())
self._model = ClassicGmm(
K=n_states,
mu=model.means_.squeeze()[idx],
tau=1 / model.covariances_.squeeze()[idx],
weights=model.weights_[idx],
)

self._fit_info = PopulationFitInfo(
self._model.converged_,
self._model.n_iter_,
self._model.bic(data),
self._model.aic(data),
np.sum(self._model.score_samples(data)),
converged=model.converged_,
n_iter=model.n_iter_,
bic=model.bic(data),
aic=model.aic(data),
log_likelihood=np.sum(model.score_samples(data)),
)

@classmethod
@@ -109,41 +125,13 @@ def exit_flag(self) -> dict:
return {
"converged": self.fit_info.converged,
"n_iter": self.fit_info.n_iter,
"lower_bound": self._model.lower_bound_,
"lower_bound": self._deprecated_lower_bound,
}

@property
def fit_info(self) -> PopulationFitInfo:
"""Information about the model training exit conditions."""
return self._fit_info

@property
def _map(self) -> np.ndarray:
"""Indices of sorted means."""
return np.argsort(self._model.means_.squeeze())

@property
@as_sorted
def weights(self):
"""Model state weights."""
return self._model.weights_

@property
@as_sorted
def means(self):
"""Model state means."""
return self._model.means_.squeeze()

@property
@as_sorted
def variances(self):
"""Model state variances."""
return self._model.covariances_.squeeze()

@property
def std(self) -> np.ndarray:
"""Model state standard deviations."""
return np.sqrt(self.variances)
return self._model.weights

@deprecated(
reason=(
@@ -164,9 +152,7 @@ def label(self, trace):
return self.state_path(trace).data

def _calculate_state_path(self, trace):
labels = self._model.predict(trace.data.reshape((-1, 1))) # wrapped model labels
output_states = np.argsort(self._map) # output model state labels in wrapped model order
return output_states[labels] # output model labels
return np.argmax(self.pdf(trace.data), axis=0)

@property
@deprecated(
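Two parts of this file's change, condensed into one self-contained sketch: state parameters are now sorted by mean once at construction (replacing the per-property remapping that `as_sorted` and `_map` used to do), and the state path becomes a per-sample argmax over per-state densities. The scikit-learn calls mirror the diff; `weighted_state_pdf` is a hypothetical stand-in for what `self.pdf` appears to return, i.e. per-state densities with shape [K, n]:

import numpy as np
import scipy.stats
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(2.0, 0.1, 500), rng.normal(0.0, 0.1, 500)])

model = GaussianMixture(n_components=2).fit(data.reshape(-1, 1))

# Sort every parameter array by state mean once, at construction time,
# so downstream code can use state indices 0..K-1 directly with no remapping.
idx = np.argsort(model.means_.squeeze())
mu = model.means_.squeeze()[idx]
tau = 1 / model.covariances_.squeeze()[idx]
weights = model.weights_[idx]


def weighted_state_pdf(x):
    # Weighted normal density per state, shape [K, len(x)].
    return weights[:, None] * scipy.stats.norm.pdf(
        x[None, :], mu[:, None], np.sqrt(1 / tau)[:, None]
    )


# MAP state assignment: argmax over the state axis, as in the new
# `_calculate_state_path`.
state_path = np.argmax(weighted_state_pdf(data), axis=0)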
10 changes: 5 additions & 5 deletions lumicks/pylake/population/tests/test_hmm_algos.py
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from lumicks.pylake import HiddenMarkovModel, GaussianMixtureModel
from lumicks.pylake import GaussianMixtureModel
from lumicks.pylake.population.detail.hmm import (
ClassicHmm,
viterbi,
@@ -118,20 +118,20 @@ def test_viterbi_with_zeros(trace_simple):
data, _, params = trace_simple
model = ClassicHmm(
params["n_states"],
[1, 0],
normalize_rows(params["transition_prob"]),
params["means"],
params["st_devs"],
[1, 0],
normalize_rows(params["transition_prob"]),
)

viterbi(data, model)

model = ClassicHmm(
params["n_states"],
params["initial_state_prob"],
[[1, 0], [0.1, 0.9]],
params["means"],
params["st_devs"],
params["initial_state_prob"],
[[1, 0], [0.1, 0.9]],
)

viterbi(data, model)
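These test edits exist only because `ClassicHmm` is constructed positionally; keyword construction would be immune to the base-vs-subclass field order. A sketch with arbitrary toy parameters, assuming `normalize_rows` is importable from the module as the test's own usage suggests:

import numpy as np

from lumicks.pylake.population.detail.hmm import ClassicHmm, normalize_rows

model = ClassicHmm(
    K=2,
    mu=np.array([0.0, 1.0]),
    tau=np.array([100.0, 100.0]),
    pi=np.array([1.0, 0.0]),
    A=normalize_rows(np.array([[0.9, 0.1], [0.1, 0.9]])),
)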
