Skip to content

Commit

Permalink
Some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
glouppe committed Dec 16, 2015
1 parent 86c7b64 commit 1501dcf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
14 changes: 2 additions & 12 deletions carl/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
from scipy.optimize import minimize

from sklearn.base import BaseEstimator
from sklearn.utils import check_random_state as check_random_state_sklearn
from sklearn.utils import check_random_state

from theano.gof import graph
from theano.tensor.shared_randomstreams import RandomStreams
from theano.tensor.sharedvar import SharedVariable

# ???: define the bounds of the parameters
Expand Down Expand Up @@ -47,15 +46,6 @@ def check_parameter(name, value):
return value, parameters, constants, observeds


def check_random_state_theano(random_state):
if isinstance(random_state, RandomStreams):
return random_state
elif isinstance(random_state, np.random.RandomState):
random_state = random_state.randint(np.iinfo(np.int32).max)

return RandomStreams(seed=random_state)


def bound(expression, out, *predicates):
guard = 1
for p in predicates:
Expand Down Expand Up @@ -111,7 +101,7 @@ def make_(self, expression, name, args=None, kwargs=None):
setattr(self, name, func)

def rvs(self, n_samples, **kwargs):
rng = check_random_state_sklearn(self.random_state)
rng = check_random_state(self.random_state)
p = rng.rand(n_samples, 1)
return self.ppf(p, **kwargs)

Expand Down
13 changes: 11 additions & 2 deletions carl/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@
import theano.tensor as T

from theano.gof import graph
from theano.tensor.shared_randomstreams import RandomStreams

from . import DistributionMixin
from .base import check_random_state_theano
from .base import check_parameter
from .base import bound


def check_random_state(random_state):
if isinstance(random_state, RandomStreams):
return random_state
elif isinstance(random_state, np.random.RandomState):
random_state = random_state.randint(np.iinfo(np.int32).max)

return RandomStreams(seed=random_state)


class Mixture(DistributionMixin):
def __init__(self, components, weights=None,
random_state=None, optimizer=None):
Expand Down Expand Up @@ -86,7 +95,7 @@ def __init__(self, components, weights=None,

# randc
n_samples = T.iscalar()
rng = check_random_state_theano(self.random_state)
rng = check_random_state(self.random_state)
self.randc_ = rng.multinomial(size=(n_samples,), pvals=self.weights)
self.make_(self.randc_, "randc", args=[n_samples])

Expand Down

0 comments on commit 1501dcf

Please sign in to comment.