
Merge pull request #61 from brian-team/add_gpu_support
enable gpu support
akapet00 committed Aug 15, 2021
2 parents 0e936cb + 50c6197 commit bfb1305
Showing 1 changed file with 72 additions and 27 deletions.
99 changes: 72 additions & 27 deletions brian2modelfitting/inferencer.py
@@ -1,5 +1,6 @@
from numbers import Number
from typing import Mapping
import warnings

from brian2.core.functions import Function
from brian2.core.namespace import get_local_namespace
@@ -129,7 +130,7 @@ def calc_prior(param_names, **params):
Return
------
sbi.utils.torchutils.BoxUniform
``sbi`` compatible object that contains a uniform prior
``sbi``-compatible object that contains a uniform prior
distribution over a given set of parameters.
"""
for param_name in param_names:
@@ -151,9 +152,7 @@ class Inferencer(object):
It offers an interface similar to that of the `.Fitter` class but
instead of fitting, neural density estimator is trained using a
generative model which ultimately provides the posterior
distribution over free parameters.
This class serves as a wrapper for ``sbi`` library for inferencing
posterior over unknown parameters of a given model.
distribution over the unknown parameters.
Parameters
----------
@@ -309,6 +308,7 @@ def __init__(self, dt, model, input, output, features=None, method=None,
self.posterior = None
self.theta = None
self.x = None
self.sbi_device = 'cpu'

@property
def n_neurons(self):
@@ -574,7 +574,7 @@ def load_summary_statistics(self, f):
return (theta, x)

def init_inference(self, inference_method, density_estimator_model, prior,
**inference_kwargs):
sbi_device='cpu', **inference_kwargs):
"""Return instantiated inference object.
Parameters
@@ -587,13 +587,28 @@ def init_inference(self, inference_method, density_estimator_model, prior,
``linear``, ``mlp``, ``resnet`` for SNRE.
prior : sbi.utils.BoxUniform
Uniformly distributed prior over given parameters.
sbi_device : str, optional
Device on which ``sbi`` will operate. By default this is
set to CPU and it is advisable to keep it so in most cases.
If the user provides a custom embedding network through the
``inference_kwargs`` argument, which is trained more
efficiently on a GPU, the device should be set accordingly.
inference_kwargs : dict, optional
Additional keyword arguments for
``sbi.utils.get_nn_models.posterior_nn`` method. The user
is free to provide own embedding network to learn features
from potentially high-dimensional simulation outputs. By
default multi-layer perceptron is used if no user defined
embedding network is provided.
Additional keyword arguments for the different density
estimator builder functions:
``sbi.utils.get_nn_models.posterior_nn`` for SNPE,
``sbi.utils.get_nn_models.classifier_nn`` for SNRE, and
``sbi.utils.get_nn_models.likelihood_nn`` for SNLE. For
details, check the ``sbi`` documentation. The most important
additional keyword argument the user can pass is a custom
embedding network to learn features from potentially
high-dimensional simulation outputs. By default, a
multi-layer perceptron is used if no custom embedding
network is provided. For SNPE and SNLE, the user can pass
an embedding network for simulation outputs, while for
SNRE, the user may pass two embedding networks, one for
parameters and one for simulation outputs.
Returns
-------
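As an illustration of the new keyword, a minimal sketch of calling ``init_inference`` with a GPU and a custom embedding network (the ``Inferencer`` instance ``inferencer``, the ``prior`` object, and the 20-dimensional feature size are assumptions made for this example, not part of the commit):

    import torch.nn as nn

    # hypothetical embedding network for 20 summary features
    embedding_net = nn.Sequential(
        nn.Linear(20, 32),
        nn.ReLU(),
        nn.Linear(32, 10),
    )

    # request the GPU; if torch reports that CUDA is unavailable,
    # init_inference logs a warning and falls back to the CPU
    inference = inferencer.init_inference(
        inference_method='SNPE',
        density_estimator_model='maf',
        prior=prior,
        sbi_device='cuda',
        embedding_net=embedding_net,  # forwarded via inference_kwargs
    )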
@@ -616,9 +631,26 @@ def init_inference(self, inference_method, density_estimator_model, prior,
else:
density_estimator_builder = classifier_nn(
model=density_estimator_model, **inference_kwargs)
inference = inference_method_fun(prior, density_estimator_builder,
device='cpu',
show_progress_bars=True)

sbi_device = str.lower(sbi_device)
if sbi_device in ['cuda', 'gpu']:
if torch.cuda.is_available():
sbi_device = 'gpu'
self.sbi_device = 'cuda'
else:
logger.warn(f'Device {sbi_device} is not available.'
' Falling back to CPU.')
sbi_device = 'cpu'
self.sbi_device = sbi_device
else:
sbi_device = 'cpu'
self.sbi_device = sbi_device

with warnings.catch_warnings():
warnings.filterwarnings('ignore')
inference = inference_method_fun(prior, density_estimator_builder,
device=sbi_device,
show_progress_bars=True)
return inference

def train(self, inference, theta, x, *args, **train_kwargs):
@@ -646,7 +678,14 @@ def train(self, inference, theta, x, *args, **train_kwargs):
True inside the `train_kwargs`.
train_kwargs : dict, optional
Additional keyword arguments for ``train`` method of
``sbi.inference.NeuralInference`` object.
``sbi.inference.NeuralInference`` object. The user can gain
full control over the training process by tuning
hyperparameters, e.g., the batch size (via the
``training_batch_size`` argument), learning rate
(``learning_rate``), validation fraction
(``validation_fraction``), number of training epochs
(``max_num_epochs``), etc. For details, check the ``sbi``
documentation.
Returns
-------
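A hedged example of passing such hyperparameters through (the ``inferencer``, ``inference``, ``theta`` and ``x`` objects are assumed to come from the preceding steps; the numeric values are illustrative only):

    # keyword arguments below are forwarded to sbi's NeuralInference.train
    inferencer.train(inference, theta, x,
                     training_batch_size=100,
                     learning_rate=5e-4,
                     validation_fraction=0.1,
                     max_num_epochs=500)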
@@ -667,7 +706,8 @@ def build_posterior(self, inference, **posterior_kwargs):
simulation outputs prepared for training.
posterior_kwargs : dict, optional
Additional keyword arguments for ``build_posterior`` method
of ``sbi.inference.NeuralInference`` object.
in ``sbi.inference.NeuralInference`` classes. For details,
check the ``sbi`` documentation.
Returns
-------
@@ -710,7 +750,7 @@ def infer_step(self, proposal, inference,
sbi.inference.NeuralPosterior
Trained posterior.
"""
# extract the training data and make adjustments for ``sbi``
# extract the training data and make adjustments for the ``sbi``
if theta is None:
if n_samples is None:
raise ValueError('Either provide `theta` or `n_samples`.')
Expand All @@ -719,7 +759,7 @@ def infer_step(self, proposal, inference,
self.theta = theta
theta = torch.tensor(theta, dtype=torch.float32)

# extract the summary statistics and make adjustments for ``sbi``
# extract the summary statistics and make adjustments for the ``sbi``
if x is None:
if n_samples is None:
raise ValueError('Either provide `x` or `n_samples`.')
@@ -741,7 +781,7 @@
def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
inference_method='SNPE', density_estimator_model='maf',
inference_kwargs={}, train_kwargs={}, posterior_kwargs={},
restart=False, **params):
restart=False, device='cpu', **params):
"""Return the trained posterior.
If ``theta`` and ``x`` are not provided, ``n_samples`` has to
@@ -770,18 +810,23 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
``mdn``, ``made``, ``maf``, ``nsf`` for SNPE and SNLE, or
``linear``, ``mlp``, ``resnet`` for SNRE.
inference_kwargs : dict, optional
Additional keyword arguments for the inferencer method
definition.
Additional keyword arguments for the `.init_inference`.
train_kwargs : dict, optional
Additional keyword arguments for training the posterior
estimator.
Additional keyword arguments for `.train`.
posterior_kwargs : dict, optional
Additional keyword arguments for builing the posterior.
Additional keyword arguments for `.build_posterior`.
restart : bool, optional
When the method is called for a second time, set to True if
amortized inference should be performed. If False,
multi-round inference with the existing posterior will be
performed.
sbi_device : str, optional
Device on which ``sbi`` will operate. By default this is
set to CPU and it is advisable to keep it so in most cases.
If the user provides a custom embedding network through the
``inference_kwargs`` argument, which is trained more
efficiently on a GPU, the device should be set accordingly.
params : dict
Bounds for each parameter. Keys should correspond to names
of parameters as defined in the model equations, values
@@ -812,15 +857,15 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
# initialize prior
prior = self.init_prior(**params)

# extract the training data and make adjustments for ``sbi``
# extract the training data
if theta is None:
if n_samples is None:
raise ValueError('Either provide `theta` or `n_samples`.')
else:
theta = self.generate_training_data(n_samples, prior)
self.theta = theta

# extract the summary statistics and make adjustments for ``sbi``
# extract the summary statistics
if x is None:
if n_samples is None:
raise ValueError('Either provide `x` or `n_samples`.')
Expand All @@ -831,7 +876,7 @@ def infer(self, n_samples=None, theta=None, x=None, n_rounds=1,
# initialize inference object
self.inference = self.init_inference(inference_method,
density_estimator_model,
prior,
prior, device,
**inference_kwargs)

# args for SNPE in `.train`
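Putting the change together, an end-to-end sketch of the updated ``infer`` call (the ``Inferencer`` instance, the parameter names ``gl`` and ``C``, and all numeric bounds are assumptions for illustration; note that the signature above spells the new argument ``device`` while the docstring refers to it as ``sbi_device``):

    from brian2 import nF, nS

    # multi-round SNPE with the density estimator trained on the GPU
    posterior = inferencer.infer(
        n_samples=1000,
        n_rounds=2,
        inference_method='SNPE',
        density_estimator_model='maf',
        device='cuda',                      # keyword as spelled in the signature
        train_kwargs={'max_num_epochs': 300},
        gl=[10*nS, 100*nS],                 # hypothetical parameter bounds
        C=[0.1*nF, 1*nF],
    )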
