Skip to content

Commit

Permalink
Merge pull request #60 from brian-team/add_spike_support
Browse files Browse the repository at this point in the history
Add spike support
  • Loading branch information
akapet00 committed Aug 13, 2021
2 parents a9da86c + c10397c commit 15de949
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 129 deletions.
145 changes: 103 additions & 42 deletions brian2modelfitting/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
from brian2.equations.equations import Equations
from brian2.groups.neurongroup import NeuronGroup
from brian2.input.timedarray import TimedArray
from brian2.monitors.spikemonitor import SpikeMonitor
from brian2.monitors.statemonitor import StateMonitor
from brian2.units.fundamentalunits import (DIMENSIONLESS,
fail_for_dimension_mismatch,
get_dimensions,
Quantity)
from brian2.utils.logger import get_logger
from brian2modelfitting.fitter import get_spikes
import numpy as np
from sbi.utils.get_nn_models import (posterior_nn,
likelihood_nn,
Expand All @@ -24,6 +27,10 @@
import torch

from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .utils import tqdm


logger = get_logger(__name__)


def configure_simulator():
Expand Down Expand Up @@ -219,7 +226,7 @@ def __init__(self, dt, model, input, output, features=None, method=None,
output_var = list(output.keys())
output = list(output.values())
for o_var in output_var:
if o_var not in model.names:
if o_var != 'spikes' and o_var not in model.names:
raise NameError(f'{o_var} is not a model variable')
self.output_var = output_var
self.output = output
Expand All @@ -234,11 +241,14 @@ def __init__(self, dt, model, input, output, features=None, method=None,
# handle multiple output variables
self.output_dim = []
for o_var, out in zip(self.output_var, self.output):
self.output_dim.append(model[o_var].dim)
fail_for_dimension_mismatch(out, self.output_dim[-1],
'The provided target values must have'
' the same units as the variable'
f' {o_var}')
if o_var == 'spikes':
self.output_dim.append(DIMENSIONLESS)
else:
self.output_dim.append(model[o_var].dim)
fail_for_dimension_mismatch(out, self.output_dim[-1],
'The provided target values must'
' have the same units as the'
f' variable {o_var}')

# add input to equations
self.model = model
Expand All @@ -250,11 +260,12 @@ def __init__(self, dt, model, input, output, features=None, method=None,
# add output to equations
counter = 0
for o_var, o_dim in zip(self.output_var, self.output_dim):
counter += 1
output_expr = f'output_var_{counter}(t, i % n_traces)'
output_dim = ('1' if o_dim is DIMENSIONLESS else repr(o_dim))
output_eqs = f'{o_var}_target = {output_expr} : {output_dim}'
self.model += output_eqs
if o_var != 'spikes':
counter += 1
output_expr = f'output_var_{counter}(t, i % n_traces)'
output_dim = ('1' if o_dim is DIMENSIONLESS else repr(o_dim))
output_eqs = f'{o_var}_target = {output_expr} : {output_dim}'
self.model += output_eqs

# create ``TimedArray`` object for input w.r.t. a given time scale
self.input_traces = TimedArray(input.transpose(), dt=self.dt)
Expand All @@ -276,16 +287,18 @@ def __init__(self, dt, model, input, output, features=None, method=None,
self.refractory = refractory

# observation the focus is on
for ov, o in zip(self.output_var, self.output):
o = np.array(o)
if features:
obs = []
obs = []
if features:
for ov, o in zip(self.output_var, self.output):
for _o in o:
for feature in features[ov]:
obs.append(feature(_o))
x_o = np.array(obs, dtype=np.float32)
else:
x_o = o.ravel().astype(np.float32)
x_o = np.array(obs, dtype=np.float32)
else:
for o in self.output:
o = np.array(o)
obs.append(o.ravel().astype(np.float32))
x_o = np.concatenate(obs)
self.x_o = x_o
self.features = features

Expand Down Expand Up @@ -353,10 +366,11 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
'n_traces': self.n_traces},
level=level+1)
counter = 0
for out in self.output:
counter += 1
namespace[f'output_var_{counter}'] = TimedArray(out.transpose(),
dt=self.dt)
for o_var, out in zip(self.output_var, self.output):
if o_var != 'spikes':
counter += 1
namespace[f'output_var_{counter}'] = TimedArray(out.T,
dt=self.dt)

# setup neuron group
kwds = {}
Expand All @@ -373,9 +387,18 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
namespace=namespace,
name='neurons',
**kwds)

# create a network of neurons
network = Network(neurons)
network.add(StateMonitor(source=neurons, variables=output_var,
record=True, dt=self.dt, name='statemonitor'))
if isinstance(output_var, str):
output_var = [output_var]
if 'spikes' in output_var:
network.add(SpikeMonitor(neurons, name='spikemonitor'))
record_vars = [v for v in output_var if v != 'spikes']
if len(record_vars):
network.add(StateMonitor(source=neurons, variables=record_vars,
record=True, dt=self.dt,
name='statemonitor'))

# initialize the simulator
simulator.initialize(network, param_init, name=network_name)
Expand Down Expand Up @@ -459,20 +482,38 @@ def extract_summary_statistics(self, theta, level=0):
name=network_name)

# extract features for each output variable and each trace
obs = simulator.statemonitor.recorded_variables
for ov in self.output_var:
o = obs[ov].get_value()
o = o.T
if self.features:
try:
obs = simulator.statemonitor.recorded_variables
except KeyError:
logger.warn('The state monitor object is not defined.',
name_suffix='statemonitor_definition')
if 'spikes' in self.output_var:
spike_trains = list(simulator.spikemonitor.spike_trains().values())
x = []
if self.features:
for ov in tqdm(self.output_var, desc='Extracting features',
total=len(self.output), leave=True):
summary_statistics = []
if ov != 'spikes':
o = obs[ov].get_value().T
# TODO: should be vectorized
for _o in o:
for feature in self.features[ov]:
summary_statistics.append(feature(_o))
x = np.array(summary_statistics, dtype=np.float32)
x = x.reshape(self.n_samples, -1)
else:
x = o.reshape(self.n_samples, -1).astype(np.float32)
for i in range(self.n_neurons):
if ov != 'spikes':
for feature in self.features[ov]:
summary_statistics.append(feature(o[i, :]))
else:
for feature in self.features[ov]:
summary_statistics.append(feature(spike_trains[i]))
_x = np.array(summary_statistics, dtype=np.float32)
_x = _x.reshape(self.n_samples, -1)
x.append(_x)
x = np.hstack(x)
else:
for ov in tqdm(self.output_var, desc='Aranging output traces',
total=len(self.output)):
o = obs[ov].get_value().T
x.append(o.reshape(self.n_samples, -1).astype(np.float32))
x = np.hstack(x)
return x

def save_summary_statistics(self, f, theta=None, x=None):
Expand Down Expand Up @@ -700,7 +741,7 @@ def infer_step(self, proposal, inference,
def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
inference_method='SNPE', density_estimator_model='maf',
inference_kwargs={}, train_kwargs={}, posterior_kwargs={},
**params):
restart=False, **params):
"""Return the trained posterior.
If ``theta`` and ``x`` are not provided, ``n_samples`` has to
Expand Down Expand Up @@ -736,6 +777,11 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
estimator.
posterior_kwargs : dict, optional
Additional keyword arguments for builing the posterior.
restart : bool, optional
When the method is called for a second time, set to True if
amortized inference should be performed. If False,
multi-round inference with the existing posterior will be
performed.
params : dict
Bounds for each parameter. Keys should correspond to names
of parameters as defined in the model equaions, values
Expand All @@ -747,6 +793,8 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
sbi.inference.NeuralPosterior
Approximated posterior distribution over parameters.
"""
if restart:
self.posterior = None
if self.posterior is None:
# handle the number of rounds
if not isinstance(n_rounds, int):
Expand Down Expand Up @@ -809,8 +857,11 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
proposal = prior

# main inference loop
for round in range(n_rounds):
print(f'Round {round + 1}/{n_rounds}.')
if self.posterior or n_rounds > 1:
tqdm_desc = f'{n_rounds}-round focused inference'
else:
tqdm_desc = 'Amortized inference'
for _ in tqdm(range(n_rounds), desc=tqdm_desc):

# inference step
posterior = self.infer_step(proposal, self.inference,
Expand All @@ -824,6 +875,7 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
if n_rounds > 1:
x_o = torch.tensor(self.x_o, dtype=torch.float32)
proposal = posterior.set_default_x(x_o)

self.posterior = posterior
return posterior

Expand Down Expand Up @@ -1175,7 +1227,7 @@ def generate_traces(self, posterior=None, output_var=None, param_init=None,
If a single output variable is observed, 2-D array of
traces generated by using a set of parameters sampled from
the trained posterior distribution of shape
(``n.traces``, number of time steps). Otherwise, a
(``self.n_traces``, number of time steps). Otherwise, a
dictionary with keys set to names of output variables, and
values to generated traces of respective output variables.
"""
Expand Down Expand Up @@ -1218,8 +1270,17 @@ def generate_traces(self, posterior=None, output_var=None, param_init=None,
# create dictionary of traces for multiple observed output variables
if len(output_var) > 1:
for ov in output_var:
trace = getattr(simulator.statemonitor, ov)[:]
traces = {ov: trace}
if ov == 'spikes':
trace = get_spikes(simulator.spikemonitor, 1,
self.n_traces)[0]
traces = {ov: trace}
else:
try:
trace = getattr(simulator.statemonitor, ov)[:]
traces = {ov: trace}
except KeyError:
logger.warn('No state monitor object found.'
' Call again with specified `output_var`.')
else:
traces = getattr(simulator.statemonitor, output_var[0])[:]
return traces

0 comments on commit 15de949

Please sign in to comment.