Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
54 changed files
with
13,038 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# __all__ = something | ||
import pyhsmm | ||
import pyhsmm.models | ||
import pyhsmm.basic | ||
import pyhsmm.basic.distributions as distributions # shortcut | ||
import pyhsmm.plugins | ||
import pyhsmm.util | ||
|
||
import os
# Absolute path to the Eigen3 headers bundled under deps/Eigen3 —
# presumably consumed when building the C++ extensions; TODO confirm.
EIGEN_INCLUDE_DIR = os.path.join(os.path.dirname(__file__), 'deps/Eigen3')
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import models | ||
import distributions | ||
import abstractions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from __future__ import division | ||
import abc | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
from pybasicbayes.abstractions import * | ||
from ..util.stats import flattendata, sample_discrete, sample_discrete_from_log, combinedata | ||
from ..util.general import rcumsum | ||
|
||
class DurationDistribution(Distribution):
    '''
    Abstract base class for discrete duration distributions with support on
    {1,2,3,...}.  In addition to the methods required by Distribution, we also
    require a log_sf implementation; everything else here has a default
    implementation in terms of log_likelihood and log_sf.
    '''
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def log_sf(self,x):
        r'''
        log survival function, defined by log_sf(x) = log(P[X \gt x]) =
        log(1-cdf(x)) where cdf(x) = P[X \leq x]

        (raw string: \g and \l are not valid Python escape sequences)
        '''
        pass

    def log_pmf(self,x):
        # durations are discrete, so the likelihood is a pmf
        return self.log_likelihood(x)

    def expected_log_pmf(self,x):
        return self.expected_log_likelihood(x)

    # default implementations below

    def pmf(self,x):
        return np.exp(self.log_pmf(x))

    def rvs_given_greater_than(self,x):
        '''
        Sample one value from the conditional distribution P[X | X > x].
        '''
        tail = self.log_sf(x)

        # if numerical underflow, return anything sensible
        if np.isinf(tail):
            return x+1

        # if big tail, rejection sample
        elif np.exp(tail) > 0.1:
            y = self.rvs(25)
            while not np.any(y > x):
                y = self.rvs(25)
            return y[y > x][0]

        # otherwise, sample directly by inverting the conditional CDF,
        # walking up the support and subtracting renormalized pmf mass
        else:
            u = np.random.rand()
            y = x
            while u > 0:
                u -= np.exp(self.log_pmf(y) - tail)
                y += 1
            return y

    def rvs_given_less_than(self,x,num):
        # draw num samples from P[X | X < x]; sample_discrete handles the
        # normalization of the truncated pmf over {1,...,x-1}
        pmf = self.pmf(np.arange(1,x))
        return sample_discrete(pmf,num)+1

    def expected_log_sf(self,x):
        # E[log sf] via a strict reverse cumulative sum of the expected log
        # pmf over a finite grid standing in for the infinite support
        x = np.atleast_1d(x).astype('int32')
        assert x.ndim == 1
        inf = max(2*x.max(),2*1000) # approximately infinity, we hope
        return rcumsum(self.expected_log_pmf(np.arange(1,inf)),strict=True)[x]

    def resample_with_censoring(self,data=[],censored_data=[]):
        '''
        censored_data is full of observations that were censored, meaning a
        value of x really could have been anything >= x, so this method samples
        them out to be at least that large
        '''
        filled_in = self._uncensor_data(censored_data)
        return self.resample(data=combinedata((data,filled_in)))

    def _uncensor_data(self,censored_data):
        # replace each censored observation x with a draw from P[X | X >= x]
        # TODO numpy-vectorize this!
        if len(censored_data) > 0:
            if not isinstance(censored_data,list):
                filled_in = np.asarray([self.rvs_given_greater_than(x-1)
                    for x in censored_data])
            else:
                filled_in = np.asarray([self.rvs_given_greater_than(x-1)
                    for xx in censored_data for x in xx])
        else:
            filled_in = []
        return filled_in

    def resample_with_censoring_and_truncation(self,data=[],censored_data=[],left_truncation_level=None):
        '''
        Like resample_with_censoring, but additionally corrects for left
        truncation at left_truncation_level by imputing the (geometrically
        distributed number of) observations that would have been rejected
        below that level.
        '''
        filled_in = self._uncensor_data(censored_data)

        if left_truncation_level is not None and left_truncation_level > 1:
            # probability mass below the truncation level
            norm = self.pmf(np.arange(1,left_truncation_level)).sum()
            num_rejected = np.random.geometric(1-norm)-1
            rejected_observations = self.rvs_given_less_than(left_truncation_level,num_rejected) \
                    if num_rejected > 0 else []
        else:
            rejected_observations = []

        # return the resample result for consistency with
        # resample_with_censoring above
        return self.resample(data=combinedata((data,filled_in,rejected_observations)))

    @property
    def mean(self):
        # TODO this is dumb, why is this here?
        # numerically truncate the support once the tail mass is negligible
        # (log_sf < -20), then compute sum_t t * pmf(t)
        trunc = 500
        while self.log_sf(trunc) > -20:
            trunc *= 1.5
        return np.arange(1,trunc+1).dot(self.pmf(np.arange(1,trunc+1)))

    def plot(self,data=None,color='b',**kwargs):
        data = flattendata(data) if data is not None else None

        # choose a plotting horizon: where the sf drops below 1e-3, or a
        # sample-based fallback if it never does within [1,1000)
        try:
            tmax = np.where(np.exp(self.log_sf(np.arange(1,1000))) < 1e-3)[0][0]
        except IndexError:
            tmax = 2*self.rvs(1000).mean()
        tmax = max(tmax,data.max()) if data is not None else tmax

        t = np.arange(1,tmax+1)
        plt.plot(t,self.pmf(t),color=color)

        if data is not None:
            if len(data) > 1:
                # NOTE(review): 'normed' was removed from matplotlib's hist in
                # 3.x; switch to density= if upgrading matplotlib
                plt.hist(data,bins=t-0.5,color=color,normed=len(set(data)) > 1)
            else:
                plt.hist(data,bins=t-0.5,color=color)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from __future__ import division | ||
import numpy as np | ||
import scipy.stats as stats | ||
import scipy.special as special | ||
|
||
from pybasicbayes.distributions import * | ||
from pybasicbayes.models import MixtureDistribution | ||
from abstractions import DurationDistribution | ||
|
||
##############################################
#  Mixins for making duration distributions  #
##############################################
|
||
class _StartAtOneMixin(object):
    '''
    Shifts a count distribution with support {0,1,2,...} so that it has
    support {1,2,3,...}: observations are shifted down by one before being
    handed to the superclass, and samples are shifted back up by one.
    Intended to be mixed in ahead of a pybasicbayes distribution class.
    '''
    def log_likelihood(self,x,*args,**kwargs):
        return super(_StartAtOneMixin,self).log_likelihood(x-1,*args,**kwargs)

    def log_sf(self,x,*args,**kwargs):
        return super(_StartAtOneMixin,self).log_sf(x-1,*args,**kwargs)

    def expected_log_likelihood(self,x,*args,**kwargs):
        return super(_StartAtOneMixin,self).expected_log_likelihood(x-1,*args,**kwargs)

    def rvs(self,size=None):
        return super(_StartAtOneMixin,self).rvs(size)+1

    def rvs_given_greater_than(self,x):
        # NOTE(review): DurationDistribution.rvs_given_greater_than already
        # samples through self.log_sf/self.log_pmf/self.rvs, which are the
        # shifted versions above, so the extra +1 here looks like it may
        # shift twice — verify against callers before changing
        return super(_StartAtOneMixin,self).rvs_given_greater_than(x)+1

    def resample(self,data=[],*args,**kwargs):
        # shift observations down by one before resampling parameters
        if isinstance(data,np.ndarray):
            return super(_StartAtOneMixin,self).resample(data-1,*args,**kwargs)
        else:
            return super(_StartAtOneMixin,self).resample([d-1 for d in data],*args,**kwargs)

    def max_likelihood(self,data,weights=None,*args,**kwargs):
        # same down-shift for (possibly weighted) MLE fitting
        if isinstance(data,np.ndarray):
            return super(_StartAtOneMixin,self).max_likelihood(
                    data-1,weights=weights,*args,**kwargs)
        else:
            return super(_StartAtOneMixin,self).max_likelihood(
                    [d-1 for d in data],weights=weights,*args,**kwargs)

    def meanfieldupdate(self,data,weights,*args,**kwargs):
        # same down-shift for mean-field variational updates
        if isinstance(data,np.ndarray):
            return super(_StartAtOneMixin,self).meanfieldupdate(
                    data-1,weights=weights,*args,**kwargs)
        else:
            return super(_StartAtOneMixin,self).meanfieldupdate(
                    [d-1 for d in data],weights=weights,*args,**kwargs)

    def meanfield_sgdstep(self,data,weights,minibatchfrac,stepsize):
        # same down-shift for stochastic variational (SVI) steps
        if isinstance(data,np.ndarray):
            return super(_StartAtOneMixin,self).meanfield_sgdstep(
                    data-1,weights=weights,
                    minibatchfrac=minibatchfrac,stepsize=stepsize)
        else:
            return super(_StartAtOneMixin,self).meanfield_sgdstep(
                    [d-1 for d in data],weights=weights,
                    minibatchfrac=minibatchfrac,stepsize=stepsize)
|
||
########################## | ||
# Distribution classes # | ||
########################## | ||
|
||
class GeometricDuration(
        Geometric,
        DurationDistribution):
    # no _StartAtOneMixin here — presumably pybasicbayes's Geometric already
    # has support starting at 1; verify in pybasicbayes
    pass
|
||
class PoissonDuration(
        _StartAtOneMixin,
        Poisson,
        DurationDistribution):
    # Poisson (support {0,1,...}) shifted to durations on {1,2,...}
    pass
|
||
class NegativeBinomialDuration(
        _StartAtOneMixin,
        NegativeBinomial,
        DurationDistribution):
    # negative binomial (support {0,1,...}) shifted to durations on {1,2,...}
    pass
|
||
class NegativeBinomialFixedRDuration(
        _StartAtOneMixin,
        NegativeBinomialFixedR,
        DurationDistribution):
    # fixed-r negative binomial shifted to durations on {1,2,...}
    pass
|
||
class NegativeBinomialIntegerRDuration(
        _StartAtOneMixin,
        NegativeBinomialIntegerR,
        DurationDistribution):
    # integer-r negative binomial shifted to durations on {1,2,...}
    pass
|
||
class NegativeBinomialIntegerR2Duration(
        _StartAtOneMixin,
        NegativeBinomialIntegerR2,
        DurationDistribution):
    # alternate integer-r negative binomial shifted to durations on {1,2,...}
    pass
|
||
class NegativeBinomialFixedRVariantDuration(
        NegativeBinomialFixedRVariant,
        DurationDistribution):
    # no shift mixin — presumably the 'variant' parameterization already has
    # support starting at 1; verify in pybasicbayes
    pass
|
||
class NegativeBinomialIntegerRVariantDuration(
        NegativeBinomialIntegerRVariant,
        DurationDistribution):
    # no shift mixin — presumably the 'variant' parameterization already has
    # support starting at 1; verify in pybasicbayes
    pass
|
||
################# | ||
# Model stuff # | ||
################# | ||
|
||
# this is extending the MixtureDistribution from basic/pybasicbayes/models.py | ||
# and then clobbering the name | ||
# this is extending the MixtureDistribution from basic/pybasicbayes/models.py
# and then clobbering the name
class MixtureDistribution(MixtureDistribution, DurationDistribution):
    # TODO test this
    def log_sf(self,x):
        '''
        Log survival function of the mixture: for each component, log weight
        plus component log_sf, combined with logsumexp across components.
        Always returns a 1D array of shape (len(x),).
        '''
        # bugfix: np.asarray of a scalar is 0-d, so x.shape[0] raised
        # IndexError for scalar input; atleast_1d makes scalars work too
        x = np.atleast_1d(np.asarray(x,dtype=np.float64))
        K = len(self.components)
        vals = np.empty((x.shape[0],K))
        for idx, c in enumerate(self.components):
            vals[:,idx] = c.log_sf(x)
        # add log mixture weights (broadcast across the rows)
        vals += self.weights.log_likelihood(np.arange(K))
        return np.logaddexp.reduce(vals,axis=1)
|
||
########## | ||
# Meta # | ||
########## | ||
|
||
# this class is for delaying instances of duration distributions | ||
class Delay(DurationDistribution):
    '''
    Wraps a duration distribution and shifts its support up by a fixed
    delay: if X ~ dur_distn then this object models X + delay.
    '''
    def __init__(self,dur_distn,delay):
        # dur_distn: the underlying duration distribution instance
        # delay: amount added to every sample / subtracted from every datum
        self.dur_distn = dur_distn
        self.delay = delay

    def log_sf(self,x):
        return self.dur_distn.log_sf(x-self.delay)

    def log_likelihood(self,x):
        return self.dur_distn.log_likelihood(x-self.delay)

    def rvs(self,size=None):
        return self.dur_distn.rvs(size) + self.delay

    def resample(self,data=[],*args,**kwargs):
        # un-delay the observations before handing them to the wrapped distn
        if isinstance(data,np.ndarray):
            return self.dur_distn.resample(data-self.delay,*args,**kwargs)
        else:
            return self.dur_distn.resample([d-self.delay for d in data],*args,**kwargs)

    def max_likelihood(self,*args,**kwargs):
        raise NotImplementedError
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# These classes make aliases of class members and properties so as to make | ||
# pybasicbayes mixture models look more like pyhsmm models. When comparing | ||
# H(S)MM model fits to pybasicbayes mixture model fits, it's easier to write one | ||
# code path by using these models. | ||
|
||
from copy import deepcopy | ||
|
||
import pybasicbayes | ||
from ..util.general import rle | ||
|
||
class _Labels(pybasicbayes.internals.labels.Labels):
    '''
    Exposes pyhsmm-style state-sequence attribute names on top of
    pybasicbayes Labels objects (which store the assignments in .z and the
    count in .N).
    '''
    @property
    def T(self):
        return self.N

    @property
    def stateseq(self):
        return self.z

    @stateseq.setter
    def stateseq(self,stateseq):
        self.z = stateseq

    @property
    def stateseq_norep(self):
        # bugfix: singular alias — _MixturePropertiesMixin.stateseqs_norep
        # accesses s.stateseq_norep, which previously did not exist here
        return rle(self.z)[0]

    @property
    def stateseqs_norep(self):
        # run-length-encoded state labels (consecutive repeats collapsed)
        return rle(self.z)[0]

    @property
    def durations(self):
        # run lengths corresponding to stateseqs_norep
        return rle(self.z)[1]
|
||
class _MixturePropertiesMixin(object):
    '''
    Exposes pyhsmm-style model attribute names (states_list, stateseqs,
    durations, obs_distns, ...) on top of pybasicbayes mixture models, so
    one analysis code path can handle both model families.
    '''
    _labels_class = _Labels

    @property
    def num_states(self):
        return len(self.obs_distns)

    @property
    def states_list(self):
        return self.labels_list

    @property
    def stateseqs(self):
        return [s.stateseq for s in self.states_list]

    @property
    def stateseqs_norep(self):
        # bugfix: _Labels defines the property as 'stateseqs_norep';
        # accessing 'stateseq_norep' raised AttributeError
        return [s.stateseqs_norep for s in self.states_list]

    @property
    def durations(self):
        return [s.durations for s in self.states_list]

    @property
    def obs_distns(self):
        return self.components

    @obs_distns.setter
    def obs_distns(self,distns):
        self.components = distns

    def predict(self,seed_data,timesteps,**kwargs):
        # NOTE: seed_data doesn't matter! a mixture has no temporal
        # dependence, so prediction is just unconditional generation
        return self.generate(timesteps,keep=False)

    @classmethod
    def from_pbb_mixture(cls,mixture):
        '''
        Alternate constructor: build an instance from a plain pybasicbayes
        mixture by deep-copying its parameters and re-adding its data.
        '''
        self = cls(
                weights_obj=deepcopy(mixture.weights),
                components=deepcopy(mixture.components))
        for l in mixture.labels_list:
            self.add_data(l.data,z=l.z)
        return self
|
||
class Mixture(_MixturePropertiesMixin,pybasicbayes.models.Mixture):
    # pybasicbayes Mixture with pyhsmm-style attribute aliases mixed in
    pass
|
||
class MixtureDistribution(_MixturePropertiesMixin,pybasicbayes.models.MixtureDistribution):
    # pybasicbayes MixtureDistribution with pyhsmm-style attribute aliases
    pass
|
Oops, something went wrong.