-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance posterior viz #58
Changes from 3 commits
c85ce97
6ab9a4b
dd278dc
c5eb245
70279b6
0d89a8e
4fb82be
3031d73
0314069
e09f175
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,7 +161,7 @@ class Inferencer(object): | |
key corresponds to the name of the output variable as defined | ||
in ``model`` and value corresponds to a single dimensional | ||
array of recorded data traces. | ||
features : list | ||
features : list, optional | ||
List of callables that take the voltage trace and output | ||
summary statistics. | ||
method : str, optional | ||
|
@@ -183,7 +183,7 @@ class Inferencer(object): | |
Dictionary of state variables to be initialized with respective | ||
values. | ||
""" | ||
def __init__(self, dt, model, input, output, features, method=None, | ||
def __init__(self, dt, model, input, output, features=None, method=None, | ||
threshold=None, reset=None, refractory=False, | ||
param_init=None): | ||
# time scale | ||
|
@@ -275,13 +275,16 @@ def __init__(self, dt, model, input, output, features, method=None, | |
# placeholder for the posterior | ||
self.posterior = None | ||
# observation the focus is on | ||
x_o = [] | ||
for o in self.output: | ||
o = np.array(o) | ||
obs = [] | ||
for feature in features: | ||
obs.extend(feature(o.transpose())) | ||
x_o.append(obs) | ||
if features: | ||
x_o = [] | ||
for o in self.output: | ||
o = np.array(o) | ||
obs = [] | ||
for feature in features: | ||
obs.extend(feature(o.transpose())) | ||
x_o.append(obs) | ||
else: | ||
x_o = np.vstack(self.output) | ||
x_o = torch.tensor(x_o, dtype=torch.float32) | ||
self.x_o = x_o | ||
self.features = features | ||
|
@@ -455,11 +458,15 @@ def extract_summary_statistics(self, theta, level=1): | |
x = [] | ||
for ov in self.output_var: | ||
x_val = obs[ov].get_value() | ||
summary_statistics = [] | ||
for feature in self.features: | ||
summary_statistics.append(feature(x_val)) | ||
x.append(summary_statistics) | ||
x = np.array(x, dtype=np.float32) | ||
if self.features: | ||
summary_statistics = [] | ||
for feature in self.features: | ||
summary_statistics.append(feature(x_val)) | ||
x.append(summary_statistics) | ||
x = np.array(x, dtype=np.float32) | ||
else: | ||
x.append(x_val) | ||
x = np.vstack(x).astype(np.float32) | ||
x = x.reshape((self.n_samples, -1)) | ||
return x | ||
|
||
|
@@ -893,7 +900,7 @@ def pairplot(self, samples=None, **kwargs): | |
Samples used to build the pairplot. | ||
**kwargs : dict, optional | ||
Additional keyword arguments for the | ||
``sbi.analysis.plot.pairplot`` function. | ||
``sbi.analysis.pairplot`` function. | ||
|
||
Returns | ||
------- | ||
|
@@ -911,6 +918,82 @@ def pairplot(self, samples=None, **kwargs): | |
fig, axes = sbi.analysis.pairplot(s, **kwargs) | ||
return fig, axes | ||
|
||
def conditional_pairplot(self, condition, limits, density=None, **kwargs):
    """Plot the conditional distribution of one or two parameters,
    given fixed values for all remaining parameters.

    Check ``sbi.analysis.conditional_pairplot`` for more details.

    Parameters
    ----------
    condition : torch.tensor
        Condition that all but the one/two regarded parameters are
        fixed to.
    limits : list or torch.tensor
        Limits in between which each parameter will be evaluated.
    density : sbi.inference.NeuralPosterior, optional
        Posterior probability density. If not provided, the posterior
        trained by this object is used instead.
    **kwargs : dict, optional
        Additional keyword arguments for the
        ``sbi.analysis.conditional_pairplot`` function.

    Returns
    -------
    tuple
        Figure and axis of the conditional pairplot.

    Raises
    ------
    AttributeError
        If no density is given and no trained posterior is available.
    """
    if density is not None:
        d = density
    else:
        try:
            d = self.posterior
        except AttributeError as e:
            # Chain the original exception so the traceback keeps its
            # context instead of printing and blindly re-raising.
            raise AttributeError('Density is not available.') from e
    # NOTE: keyword arguments must be forwarded with ``**kwargs``;
    # ``*kwargs`` would unpack only the dictionary *keys* as
    # positional arguments.
    fig, axes = sbi.analysis.conditional_pairplot(density=d,
                                                  condition=condition,
                                                  limits=limits,
                                                  **kwargs)
    return fig, axes
|
||
def conditional_corrcoeff(self, condition, limits, density=None, **kwargs):
    """Compute the average conditional correlation matrix of the
    posterior, given fixed values for all non-regarded parameters.

    Check ``sbi.analysis.conditional_corrcoeff`` for more details.

    Parameters
    ----------
    condition : torch.tensor
        Condition that all but the one/two regarded parameters are
        fixed to.
    limits : list or torch.tensor
        Limits in between which each parameter will be evaluated.
    density : sbi.inference.NeuralPosterior, optional
        Posterior probability density. If not provided, the posterior
        trained by this object is used instead.
    **kwargs : dict, optional
        Additional keyword arguments for the
        ``sbi.analysis.conditional_corrcoeff`` function.

    Returns
    -------
    torch.tensor
        Average conditional correlation matrix.

    Raises
    ------
    AttributeError
        If no density is given and no trained posterior is available.
    """
    if density is not None:
        d = density
    else:
        try:
            d = self.posterior
        except AttributeError as e:
            # Chain the original exception so the traceback keeps its
            # context instead of printing and blindly re-raising.
            raise AttributeError('Density is not available.') from e
    # NOTE: keyword arguments must be forwarded with ``**kwargs``;
    # ``*kwargs`` would unpack only the dictionary *keys* as
    # positional arguments.
    cond_coeff = sbi.analysis.conditional_corrcoeff(density=d,
                                                    condition=condition,
                                                    limits=limits,
                                                    **kwargs)
    return cond_coeff
|
||
def generate_traces(self, posterior=None, output_var=None, param_init=None, | ||
level=0): | ||
"""Generates traces for a single drawn sample from the trained | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the rest of the code we tried to avoid that the user has to deal with pytorch tensors themselves. Maybe this could be something like a dictionary mapping parameter names to values (like the parameters in
generate
, for example)? But then this should be probably whatsample
returns as well...