Skip to content

Commit

Permalink
Plotting parameters vs. node output (#286)
Browse files Browse the repository at this point in the history
* Add support for giving seed to generate

* Implement elfi.plot_some for plotting parameters vs. node output

* Fix giving seed to generate

* Support giving a function to apply to node output

* Remove empty subplot axes

* Fix linting

* Rename function to 'plot_params_vs_node'
  • Loading branch information
vuolleko authored and hpesonen committed Aug 30, 2018
1 parent 98867a8 commit 33a7a88
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ dev
- Fix crashing summary and plots for samples with multivariate priors
- Add progress bar for inference methods
- Add method save to Sample objects
- Add support for giving seed to `generate`
- Implement elfi.plot_params_vs_node for plotting parameters vs. node output

0.7.2 (2018-06-20)
------------------
Expand Down
1 change: 1 addition & 0 deletions elfi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from elfi.model.extensions import ScipyLikeDistribution as Distribution
from elfi.store import OutputPool, ArrayPool
from elfi.visualization.visualization import nx_draw as draw
from elfi.visualization.visualization import plot_params_vs_node
from elfi.methods.bo.gpy_regression import GPyRegression

__author__ = 'ELFI authors'
Expand Down
11 changes: 8 additions & 3 deletions elfi/model/elfi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ def observed(self, observed):
"name as the key")
self.source_net.graph['observed'] = observed

def generate(self, batch_size=1, outputs=None, with_values=None):
"""Generate a batch of outputs using the global numpy seed.
def generate(self, batch_size=1, outputs=None, with_values=None, seed=None):
"""Generate a batch of outputs.
This method is useful for testing that the ELFI graph works.
Expand All @@ -271,6 +271,8 @@ def generate(self, batch_size=1, outputs=None, with_values=None):
outputs : list, optional
with_values : dict, optional
You can specify values for nodes to use when generating data
seed : int, optional
Defaults to global numpy seed.
"""
if outputs is None:
Expand All @@ -280,11 +282,14 @@ def generate(self, batch_size=1, outputs=None, with_values=None):
if not isinstance(outputs, list):
raise ValueError('Outputs must be a list of node names')

if seed is None:
seed = 'global'

pool = None
if with_values is not None:
pool = OutputPool(with_values.keys())
pool.add_batch(with_values, 0)
context = ComputationContext(batch_size, seed='global', pool=pool)
context = ComputationContext(batch_size, seed=seed, pool=pool)

client = elfi.client.get_client()
compiled_net = client.compile(self.source_net, outputs)
Expand Down
81 changes: 81 additions & 0 deletions elfi/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,84 @@ def progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100,
print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='\r')
if iteration == total:
print()


def plot_params_vs_node(node, n_samples=100, func=None, seed=None, axes=None, **kwargs):
"""Plot some realizations of parameters vs. `node`.
Useful e.g. for exploring how a summary statistic varies with parameters.
Currently only nodes with scalar output are supported, though a function `func` can
be given to reduce node output. This allows giving the simulator as the `node` and
applying a summarizing function without incorporating it into the ELFI graph.
If `node` is one of the model parameters, its histogram is plotted.
Parameters
----------
node : elfi.NodeReference
The node which to evaluate. Its output must be scalar (shape=(batch_size,1)).
n_samples : int, optional
How many samples to plot.
func : callable, optional
A function to apply to node output.
seed : int, optional
axes : one or an iterable of plt.Axes, optional
Returns
-------
axes : np.array of plt.Axes
"""
model = node.model
parameters = model.parameter_names
node_name = node.name

if node_name in parameters:
outputs = [node_name]
shape = (1, 1)
bins = kwargs.pop('bins', 20)

else:
outputs = parameters + [node_name]
n_params = len(parameters)
ncols = n_params if n_params < 5 else 5
ncols = kwargs.pop('ncols', ncols)
edgecolor = kwargs.pop('edgecolor', 'none')
dot_size = kwargs.pop('s', 20)
shape = (1 + n_params // (ncols+1), ncols)

data = model.generate(batch_size=n_samples, outputs=outputs, seed=seed)

if func is not None:
if hasattr(func, '__name__'):
node_name = func.__name__
else:
node_name = 'func'
data[node_name] = func(data[node.name]) # leaves rest of the code unmodified

if data[node_name].shape != (n_samples,):
raise NotImplementedError("The plotted quantity must have shape ({},), was {}."
.format(n_samples, data[node_name].shape))

axes, kwargs = _create_axes(axes, shape, sharey=True, **kwargs)
axes = axes.ravel()

if len(outputs) == 1:
axes[0].hist(data[node_name], bins=bins, normed=True)
axes[0].set_xlabel(node_name)

else:
for idx, key in enumerate(parameters):
axes[idx].scatter(data[key],
data[node_name],
s=dot_size,
edgecolor=edgecolor,
**kwargs)

axes[idx].set_xlabel(key)
axes[0].set_ylabel(node_name)

for idx in range(len(parameters), len(axes)):
axes[idx].set_axis_off()

return axes

0 comments on commit 33a7a88

Please sign in to comment.