Skip to content

Commit

Permalink
Fix bug in plot_discrepancy for more than 6 parameters (#300)
Browse files Browse the repository at this point in the history
* Draft version of fixing bug in plot_discrepancy

* Fix bug with plot_discrepancy

* Add information to changelog

* Make number of plots in plot_discrepancy function more compact

* Improve plot_discrepancy for more than 11-d Gaussian models

* Fix adding y label in case if we have different number of columns

* Fix documentation
  • Loading branch information
b5y authored and hpesonen committed Dec 18, 2018
1 parent 04c0c9e commit 9b4b135
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

dev
---
- Fix bug in plot_discrepancy for more than 6 parameters

0.7.3 (2018-08-30)
------------------
- Fix bug in plot_pairs which crashes in case of 1 parameter
Expand Down
22 changes: 8 additions & 14 deletions elfi/methods/parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,22 +1115,16 @@ def acq(x):
def plot_discrepancy(self, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
TODO: refactor
"""
n_plots = self.target_model.input_dim
ncols = kwargs.pop('ncols', 5)
kwargs['sharey'] = kwargs.get('sharey', True)
shape = (max(1, n_plots // ncols), min(n_plots, ncols))
axes, kwargs = vis._create_axes(axes, shape, **kwargs)
axes = axes.ravel()

for ii in range(n_plots):
axes[ii].scatter(self.target_model._gp.X[:, ii], self.target_model._gp.Y[:, 0])
axes[ii].set_xlabel(self.parameter_names[ii])
Parameters
----------
axes : plt.Axes or arraylike of plt.Axes
axes[0].set_ylabel('Discrepancy')
Return
------
axes : np.array of plt.Axes
return axes
"""
return vis.plot_discrepancy(self.target_model, self.parameter_names, axes=axes, **kwargs)


class BOLFI(BayesianOptimization):
Expand Down
38 changes: 38 additions & 0 deletions elfi/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,41 @@ def plot_params_vs_node(node, n_samples=100, func=None, seed=None, axes=None, **
axes[idx].set_axis_off()

return axes


def plot_discrepancy(gp, parameter_names, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
Parameters
----------
axes : plt.Axes or arraylike of plt.Axes
gp : GPyRegression target model, required
parameter_names : dict, required
Parameter names from model.parameters dict('parameter_name':(lower, upper), ... )`
Returns
-------
axes : np.array of plt.Axes
"""
n_plots = gp.input_dim
ncols = len(gp.bounds) if len(gp.bounds) < 5 else 5
ncols = kwargs.pop('ncols', ncols)
kwargs['sharey'] = kwargs.get('sharey', True)
if n_plots > 10:
shape = (1 + (1 + n_plots) // (ncols + 1), ncols)
else:
shape = (1 + n_plots // (ncols + 1), ncols)
axes, kwargs = _create_axes(axes, shape, **kwargs)
axes = axes.ravel()

for ii in range(n_plots):
axes[ii].scatter(gp.X[:, ii], gp.Y[:, 0], **kwargs)
axes[ii].set_xlabel(parameter_names[ii])
if ii % ncols == 0:
axes[ii].set_ylabel('Discrepancy')

for idx in range(len(parameter_names), len(axes)):
axes[idx].set_axis_off()

return axes

0 comments on commit 9b4b135

Please sign in to comment.