Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance posterior viz #58

Merged
merged 10 commits into from
Aug 5, 2021
113 changes: 98 additions & 15 deletions brian2modelfitting/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class Inferencer(object):
key corresponds to the name of the output variable as defined
in ``model`` and value corresponds to a single dimensional
array of recorded data traces.
features : list
features : list, optional
List of callables that take the voltage trace and output
summary statistics.
method : str, optional
Expand All @@ -183,7 +183,7 @@ class Inferencer(object):
Dictionary of state variables to be initialized with respective
values.
"""
def __init__(self, dt, model, input, output, features, method=None,
def __init__(self, dt, model, input, output, features=None, method=None,
threshold=None, reset=None, refractory=False,
param_init=None):
# time scale
Expand Down Expand Up @@ -275,13 +275,16 @@ def __init__(self, dt, model, input, output, features, method=None,
# placeholder for the posterior
self.posterior = None
# observation the focus is on
x_o = []
for o in self.output:
o = np.array(o)
obs = []
for feature in features:
obs.extend(feature(o.transpose()))
x_o.append(obs)
if features:
x_o = []
for o in self.output:
o = np.array(o)
obs = []
for feature in features:
obs.extend(feature(o.transpose()))
x_o.append(obs)
else:
x_o = np.vstack(self.output)
x_o = torch.tensor(x_o, dtype=torch.float32)
self.x_o = x_o
self.features = features
Expand Down Expand Up @@ -455,11 +458,15 @@ def extract_summary_statistics(self, theta, level=1):
x = []
for ov in self.output_var:
x_val = obs[ov].get_value()
summary_statistics = []
for feature in self.features:
summary_statistics.append(feature(x_val))
x.append(summary_statistics)
x = np.array(x, dtype=np.float32)
if self.features:
summary_statistics = []
for feature in self.features:
summary_statistics.append(feature(x_val))
x.append(summary_statistics)
x = np.array(x, dtype=np.float32)
else:
x.append(x_val)
x = np.vstack(x).astype(np.float32)
x = x.reshape((self.n_samples, -1))
return x

Expand Down Expand Up @@ -893,7 +900,7 @@ def pairplot(self, samples=None, **kwargs):
Samples used to build the pairplot.
**kwargs : dict, optional
Additional keyword arguments for the
``sbi.analysis.plot.pairplot`` function.
``sbi.analysis.pairplot`` function.

Returns
-------
Expand All @@ -911,6 +918,82 @@ def pairplot(self, samples=None, **kwargs):
fig, axes = sbi.analysis.pairplot(s, **kwargs)
return fig, axes

def conditional_pairplot(self, condition, limits, density=None, **kwargs):
"""Plot conditional distribution given all other parameters.

Check ``sbi.analysis.plot.conditional_pairplot`` for more
details.

Parameters
----------
condition : torch.tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the rest of the code we tried to avoid that the user has to deal with pytorch tensors themselves. Maybe this could be something like a dictionary mapping parameter names to values (like the parameters in generate, for example)? But then this should be probably what sample returns as well...

Condition that all but the one/two regarded parameters are
fixed to.
limits : list or torch.tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we make this optional and reuse the bounds here?

Limits in between which each parameter will be evaulated.
density : sbi.inference.NeuralPosterior, optional
Posterior probability density.
**kwargs : dict, optional
Additional keyword arguments for the
``sbi.analysis.conditional_pairplot`` function.

Returns
-------
tuple
Figure and axis of conditional pairplot.
"""
if density is not None:
d = density
else:
try:
d = self.posterior
except AttributeError as e:
print(e, '\nDensity is not available.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should rather be something like

raise AttributeError('Density is not available.') from e

instead of print + raise

raise
fig, axes = sbi.analysis.conditional_pairplot(density=d,
condition=condition,
limits=limits,
*kwargs)
return fig, axes

def conditional_corrcoeff(self, condition, limits, density=None, **kwargs):
"""Plot conditional distribution given all other parameters.

Check ``sbi.analysis.conditional_density.conditional_corrcoeff``
for more details.

Parameters
----------
condition : torch.tensor
Condition that all but the one/two regarded parameters are
fixed to.
limits : list or torch.tensor
Limits in between which each parameter will be evaulated.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same remarks as above.

density : sbi.inference.NeuralPosterior, optional
Posterior probability density.
**kwargs : dict, optional
Additional keyword arguments for the
``sbi.analysis.conditional_corrcoeff`` function.

Returns
-------
torch.tensor
Average conditional correlation matrix.
"""
if density is not None:
d = density
else:
try:
d = self.posterior
except AttributeError as e:
print(e, '\nDensity is not available.')
raise
cond_coeff = sbi.analysis.conditional_corrcoeff(density=d,
condition=condition,
limits=limits,
*kwargs)
return cond_coeff

def generate_traces(self, posterior=None, output_var=None, param_init=None,
level=0):
"""Generates traces for a single drawn sample from the trained
Expand Down