Node output size refactoring (#127)
All outputs are now required to be arrays. Make the default output of Prior a 1d vector. No longer require that all vectors be at least 2d. Strip away unnecessary dimensions where possible, following the numpy convention.
Jarno Lintusaari committed Apr 7, 2017
1 parent ca94307 commit f8ab3a8
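
To make the new convention concrete, here is a minimal numpy sketch (illustrative only, not part of this commit): a scalar-valued prior now yields a plain 1d vector with one entry per batch member, and an unnecessary trailing dimension is squeezed away numpy-style.

import numpy as np

batch_size = 3

# A scalar-valued prior now produces a 1d vector, one value per batch member.
prior_output = np.random.uniform(size=batch_size)    # shape (3,)

# An unnecessary trailing singleton dimension would be stripped, numpy-style.
column = np.random.uniform(size=(batch_size, 1))     # shape (3, 1)
squeezed = column.squeeze(axis=1)                    # shape (3,)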
Showing 5 changed files with 25 additions and 20 deletions.
elfi/bo/gpy_regression.py (5 additions, 0 deletions)
@@ -136,6 +136,11 @@ def _make_gpy_instance(self, x, y, kernel, noise_var, mean_function):
     def update(self, x, y, optimize=False):
         """Updates the GP model with new data
         """
+
+        # Must cast these as 2d for GPy
+        x = x.reshape((-1, self.input_dim))
+        y = y.reshape((-1, 1))
+
         if self._gp is None:
             self._init_gp(x, y)
         else:
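
For context, GPy regression models expect 2d arrays with one row per observation, which is why update now casts the incoming 1d vectors before fitting. A minimal sketch of the cast, assuming input_dim == 1 (plain numpy, no GPy call):

import numpy as np

input_dim = 1
x = np.array([0.1, 0.5, 0.9])       # 1d vector of evaluation points
y = np.array([1.2, 0.7, 0.3])       # 1d vector of observed values

x2d = x.reshape((-1, input_dim))    # shape (3, 1), one row per observation
y2d = y.reshape((-1, 1))            # shape (3, 1), single output column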
elfi/methods/methods.py (5 additions, 5 deletions)
@@ -452,10 +452,10 @@ def _get_batches_total(self):
     def _init_state_samples(self, batch):
         # Initialize the outputs dict based on the received batch
         samples = {}
-        for output in self.outputs:
+        for node in self.outputs:
             shape = (self.objective['n_samples'] + self.batch_size,) \
-                    + batch[output].shape[1:]
-            samples[output] = np.ones(shape) * np.inf
+                    + batch[node].shape[1:]
+            samples[node] = np.ones(shape) * np.inf
         self.state['samples'] = samples

     def _merge_batch(self, batch):
@@ -465,8 +465,8 @@ def _merge_batch(self, batch):
         samples = self.state['samples']

         # Put the acquired samples to the end
-        for k, v in samples.items():
-            v[self.objective['n_samples']:] = batch[k]
+        for node, v in samples.items():
+            v[self.objective['n_samples']:] = batch[node]

         # Sort the smallest to the beginning
         sort_mask = np.argsort(samples[self.discrepancy], axis=0).ravel()
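
The surrounding code preallocates each sample array filled with np.inf, appends the newly acquired batch at the end, and keeps the smallest discrepancies at the front. A standalone sketch of that pattern with made-up numbers (variable names here are illustrative, not the ELFI API):

import numpy as np

n_samples, batch_size = 4, 2
discrepancies = np.ones(n_samples + batch_size) * np.inf   # unused slots start at inf
discrepancies[:n_samples] = [0.3, 0.1, 0.7, 0.5]           # previously accepted values

# Put the acquired batch at the end, then sort the smallest to the beginning.
discrepancies[n_samples:] = [0.2, 0.9]
sort_mask = np.argsort(discrepancies, axis=0).ravel()
discrepancies = discrepancies[sort_mask]                   # [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]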
elfi/model/elfi_model.py (3 additions, 3 deletions)
@@ -343,15 +343,15 @@ def observed(self):


 class ScipyLikeRV(StochasticMixin, NodeReference):
-    def __init__(self, distribution="uniform", *params, size=1, **kwargs):
+    def __init__(self, distribution="uniform", *params, size=None, **kwargs):
         """
         Parameters
         ----------
         distribution : str or scipy-like distribution object
         params : params of the distribution
-        size : int, tuple or None
-            size of a single random draw. None means a scalar.
+        size : int, tuple or None, optional
+            size of a single random draw. None (default) means a scalar.
         """

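
The new size=None default matters because scipy-style distributions return a plain scalar for size=None but an array of shape (1,) for size=1. A quick check with scipy.stats:

import scipy.stats as ss

scalar_draw = ss.uniform.rvs(size=None)     # a plain scalar
array_draw = ss.uniform.rvs(size=1)         # numpy array of shape (1,)
matrix_draw = ss.uniform.rvs(size=(2, 3))   # numpy array of shape (2, 3)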
examples/ma2.py (11 additions, 10 deletions)
@@ -8,14 +8,15 @@
 """


-def MA2(t1, t2, n_obs=100, batch_size=1, random_state=None, w=None):
-    if w is None:
-        if random_state is None:
-            # Use the global random_state
-            random_state = np.random
-        w = random_state.randn(batch_size, n_obs+2)  # i.i.d. sequence ~ N(0,1)
-
-    w = np.atleast_2d(w)
+def MA2(t1, t2, n_obs=100, batch_size=1, random_state=None):
+    random_state = random_state or np.random
+    # i.i.d. sequence ~ N(0,1)
+    w = random_state.randn(batch_size, n_obs+2)
+
+    # Make inputs 2d arrays for broadcasting with w
+    t1 = np.atleast_2d(t1).reshape((-1, 1))
+    t2 = np.atleast_2d(t2).reshape((-1, 1))
+
     x = w[:, 2:] + t1*w[:, 1:-1] + t2*w[:, :-2]
     return x

@@ -24,12 +25,12 @@ def autocov(x, lag=1):
     """Autocovariance assuming a (weak) univariate stationary process with mean 0.
     Realizations are in rows.
     """
-    C = np.mean(x[:, lag:]*x[:, :-lag], axis=1, keepdims=True)
+    C = np.mean(x[:, lag:]*x[:, :-lag], axis=1)
     return C


 def discrepancy(x, y):
-    d = np.linalg.norm(np.array(x) - np.array(y), ord=2, axis=0)
+    d = np.linalg.norm(np.column_stack(x) - np.column_stack(y), ord=2, axis=1)
     return d

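
A quick shape check that replicates the updated simulator and summary from the diff above: with these changes, autocov returns a flat 1d vector with one value per realization instead of a column vector.

import numpy as np

def MA2(t1, t2, n_obs=100, batch_size=1, random_state=None):
    random_state = random_state or np.random
    # i.i.d. sequence ~ N(0,1)
    w = random_state.randn(batch_size, n_obs+2)

    # Make inputs 2d arrays for broadcasting with w
    t1 = np.atleast_2d(t1).reshape((-1, 1))
    t2 = np.atleast_2d(t2).reshape((-1, 1))

    return w[:, 2:] + t1*w[:, 1:-1] + t2*w[:, :-2]

def autocov(x, lag=1):
    return np.mean(x[:, lag:]*x[:, :-lag], axis=1)

x = MA2(0.6, 0.2, n_obs=100, batch_size=5)
print(x.shape)             # (5, 100)
print(autocov(x).shape)    # (5,) -- 1d, no trailing singleton dimension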
tests/unit/test_examples.py (1 addition, 2 deletions)
@@ -16,8 +16,7 @@ def test_generate():
     res = d.generate(n_gen)

     assert res.shape[0] == n_gen
-    assert res.ndim == 2
-    assert res.shape[1] == 1
+    assert res.ndim == 1


 @pytest.mark.usefixtures('with_all_clients')
