Add pyright dependency
michaeldeistler committed Feb 10, 2022
1 parent c2ec23c commit dfae899
Showing 16 changed files with 77 additions and 62 deletions.
13 changes: 13 additions & 0 deletions pyrightconfig.json
@@ -0,0 +1,13 @@
{
"include": [
"sbi"
],
"exclude": [
"**/__pycache__",
"**/node_modules",
".git",
],
"reportUnsupportedDunderAll": false,
"pythonVersion": "3.8",
"stubPath": ""
}
4 changes: 1 addition & 3 deletions sbi/inference/posteriors/base_posterior.py
@@ -41,11 +41,9 @@ def __init__(
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
"""
if device is None:
device = potential_fn.device

# Ensure device string.
self._device = process_device(device)
self._device = process_device(potential_fn.device if device is None else device)

self.potential_fn = potential_fn

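
The refactor above folds the None check into a single conditional expression, so the value handed to `process_device` is never an unnarrowed Optional. A minimal sketch of the pattern, with an illustrative helper name that is not part of sbi:

from typing import Optional

def resolve_device(requested: Optional[str], fallback: str = "cpu") -> str:
    # The conditional expression evaluates to a plain `str`, so a checker
    # such as pyright needs no extra narrowing of the Optional parameter.
    return fallback if requested is None else requested

assert resolve_device(None) == "cpu"
assert resolve_device("cuda:0") == "cuda:0"
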
10 changes: 7 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
@@ -4,6 +4,7 @@

import torch
from torch import Tensor, log, nn
from pyknos.nflows import flows

from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
@@ -32,8 +33,8 @@ class DirectPosterior(NeuralPosterior):

def __init__(
self,
posterior_estimator: nn.Module,
prior: Callable,
posterior_estimator: flows.Flow,
prior: Any,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
x_shape: Optional[torch.Size] = None,
@@ -232,9 +233,12 @@ def acceptance_at(x: Tensor) -> Tensor:
if is_new_x: # Calculate at x; don't save.
return acceptance_at(x)
elif not_saved_at_default_x or force_update: # Calculate at default_x; save.
assert self.default_x is not None
self._leakage_density_correction_factor = acceptance_at(self.default_x)
else:
raise ValueError

return self._leakage_density_correction_factor # type:ignore
return self._leakage_density_correction_factor

def map(
self,
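
The added `assert self.default_x is not None` is what allows the `# type:ignore` on the return to be dropped: the assert narrows the Optional attribute for pyright. A rough standalone sketch of that narrowing pattern; the class and attribute names are hypothetical:

from typing import Optional

import torch
from torch import Tensor

class CorrectionCache:
    def __init__(self) -> None:
        self.default_x: Optional[Tensor] = None
        self._correction: Optional[Tensor] = None

    def correction_at_default(self) -> Tensor:
        # After the assert, pyright treats `default_x` as a Tensor, so the
        # computed value is a Tensor and no `# type: ignore` is needed.
        assert self.default_x is not None
        correction = torch.ones(1) * self.default_x.mean()
        self._correction = correction
        return correction

cache = CorrectionCache()
cache.default_x = torch.randn(10)
print(cache.correction_at_default())
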
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
@@ -403,7 +403,7 @@ def _pyro_mcmc(
Returns: Tensor of shape (num_samples, shape_of_single_theta).
"""
num_chains = mp.cpu_count - 1 if num_chains is None else num_chains
num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains

kernels = dict(slice=Slice, hmc=HMC, nuts=NUTS)

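
The one-character change above is a real bug that type checking surfaces: `mp.cpu_count` without parentheses is the function object, so subtracting 1 from it raises a TypeError. A tiny reproduction of the corrected fallback, wrapped in an illustrative helper:

import multiprocessing as mp
from typing import Optional

def pick_num_chains(num_chains: Optional[int] = None) -> int:
    # cpu_count() returns an int; the uncalled attribute would make `- 1`
    # fail at runtime and be reported as an error by pyright.
    return mp.cpu_count() - 1 if num_chains is None else num_chains

print(pick_num_chains())   # e.g. 7 on an 8-core machine
print(pick_num_chains(4))  # 4
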
3 changes: 2 additions & 1 deletion sbi/inference/potentials/posterior_based_potential.py
@@ -6,6 +6,7 @@
import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor, nn
from pyknos.nflows import flows

from sbi.inference.potentials.base_potential import BasePotential
from sbi.utils import mcmc_transform
@@ -52,7 +53,7 @@ class PosteriorBasedPotential(BasePotential):

def __init__(
self,
posterior_estimator: nn.Module,
posterior_estimator: flows.Flow,
prior: Any,
x_o: Optional[Tensor],
device: str = "cpu",
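
Annotating the estimator as `flows.Flow` rather than a bare `nn.Module` exposes the flow-specific interface to pyright. A small sketch, assuming pyknos/nflows is installed; the helper below is illustrative and not sbi code:

from pyknos.nflows import flows
from torch import Tensor

def estimator_log_prob(estimator: flows.Flow, theta: Tensor, x: Tensor) -> Tensor:
    # With the narrower annotation, `log_prob` is a known attribute; under a
    # plain nn.Module annotation pyright would flag it as unresolved.
    return estimator.log_prob(theta, context=x)
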
10 changes: 6 additions & 4 deletions sbi/inference/snle/snle_a.py
@@ -4,6 +4,8 @@

from typing import Any, Callable, Dict, Optional, Union

from pyknos.nflows import flows

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.snle.snle_base import LikelihoodEstimator
from sbi.types import TensorboardSummaryWriter
@@ -60,15 +62,15 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
exclude_invalid_x: bool = True,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> NeuralPosterior:
) -> flows.Flow:
r"""Return density estimator that approximates the distribution $p(x|\theta)$.
Args:
@@ -78,8 +80,8 @@
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
11 changes: 5 additions & 6 deletions sbi/inference/snle/snle_base.py
@@ -8,9 +8,10 @@

import torch
from torch import Tensor, nn, optim
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.writer import SummaryWriter
from pyknos.nflows import flows

from sbi import utils as utils
from sbi.inference import NeuralInference
@@ -119,15 +120,15 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
exclude_invalid_x: bool = True,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> nn.Module:
) -> flows.Flow:
r"""Train the density estimator to learn the distribution $p(x|\theta)$.
Args:
@@ -151,8 +152,6 @@
Density estimator that has learned the distribution $p(x|\theta)$.
"""

max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)
# Load data from most recent round.
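
Replacing `Optional[int] = None` with a large int sentinel removes both the Optional type and the in-body resolution line deleted above. A side-by-side sketch of the two styles, with illustrative function names:

from typing import Optional

def train_before(max_num_epochs: Optional[int] = None) -> int:
    # Old style: the body must resolve None before epochs can be compared.
    max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
    return max_num_epochs

def train_after(max_num_epochs: int = 2 ** 31 - 1) -> int:
    # New style: the sentinel is the default, so no Optional remains.
    return max_num_epochs

assert train_before() == train_after() == 2 ** 31 - 1
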
14 changes: 7 additions & 7 deletions sbi/inference/snpe/snpe_a.py
@@ -106,7 +106,7 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
@@ -135,8 +135,8 @@
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
calibration_kernel: A function to calibrate the loss with respect to the
@@ -217,7 +217,7 @@ def train(
def correct_for_proposal(
self,
density_estimator: Optional[TorchModule] = None,
) -> TorchModule:
) -> "SNPE_A_MDN":
r"""Build mixture of Gaussians that approximates the posterior.
Returns a `SNPE_A_MDN` object, which applies the posthoc-correction required in
@@ -242,7 +242,7 @@ def correct_for_proposal(
device = self._device
else:
# Otherwise, infer it from the device of the net parameters.
device = next(density_estimator.parameters()).device
device = str(next(density_estimator.parameters()).device)

# Set proposal of the density estimator.
# This also evokes the z-scoring correction if necessary.
@@ -377,7 +377,7 @@ class SNPE_A_MDN(nn.Module):
def __init__(
self,
flow: flows.Flow,
proposal: Union["utils.BoxUniform", "DirectPosterior", "SNPE_A_MDN"],
proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"],
prior: Any,
device: str,
):
@@ -429,7 +429,7 @@ def log_prob(self, inputs, context=None):
theta = self._maybe_z_score_theta(inputs)

# Compute the log_prob of theta under the product.
log_prob_proposal_posterior = utils.sbiutils.mog_log_prob(
log_prob_proposal_posterior = utils.mog_log_prob(
theta, logits_pp, m_pp, prec_pp
)
utils.assert_all_finite(
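
Two smaller tweaks in this file also serve type consistency: the device is wrapped in `str(...)` because a parameter's `.device` is a `torch.device` rather than the `str` used elsewhere, and `mog_log_prob` is reached via the public `utils` namespace. A short sketch of the device conversion:

import torch
from torch import nn

net = nn.Linear(2, 2)
# `.device` is a torch.device; str() yields the "cpu"/"cuda:0"-style string
# that matches the `device: str` annotations used throughout the diff.
device: str = str(next(net.parameters()).device)
print(device)  # "cpu" on a CPU-only machine
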
12 changes: 5 additions & 7 deletions sbi/inference/snpe/snpe_base.py
@@ -8,9 +8,9 @@

import torch
from torch import Tensor, nn, ones, optim
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.writer import SummaryWriter

from sbi import utils as utils
from sbi.inference import NeuralInference, check_if_proposal_has_default_x
@@ -158,7 +158,7 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
@@ -177,8 +177,8 @@
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
calibration_kernel: A function to calibrate the loss with respect to the
Expand Down Expand Up @@ -207,8 +207,6 @@ def train(
if calibration_kernel is None:
calibration_kernel = lambda x: ones([len(x)], device=self._device)

max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)

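
As in `snle_base.py`, the imports now come from the defining submodules (`torch.nn.utils.clip_grad`, `torch.utils.tensorboard.writer`), which pyright resolves more reliably than the package-level re-exports. A minimal usage sketch of the gradient-clipping import:

import torch
from torch import nn
from torch.nn.utils.clip_grad import clip_grad_norm_

net = nn.Linear(3, 1)
loss = net(torch.randn(8, 3)).sum()
loss.backward()
# Clip the total gradient norm at 5.0, mirroring `clip_max_norm` above.
clip_grad_norm_(net.parameters(), max_norm=5.0)
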
15 changes: 8 additions & 7 deletions sbi/inference/snpe/snpe_c.py
@@ -11,6 +11,7 @@
from torch.distributions import MultivariateNormal, Uniform

from sbi import utils as utils
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.snpe.snpe_base import PosteriorEstimator
from sbi.types import TensorboardSummaryWriter
from sbi.utils import (
@@ -93,7 +94,7 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
@@ -114,8 +115,8 @@
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
calibration_kernel: A function to calibrate the loss with respect to the
@@ -257,7 +258,7 @@ def _log_prob_proposal_posterior(
theta: Tensor,
x: Tensor,
masks: Tensor,
proposal: Optional[Any],
proposal: DirectPosterior,
) -> Tensor:
"""Return the log-probability of the proposal posterior.
@@ -305,8 +306,8 @@ def _log_prob_proposal_posterior_atomic(

batch_size = theta.shape[0]

num_atoms = clamp_and_warn(
"num_atoms", self._num_atoms, min_val=2, max_val=batch_size
num_atoms = int(
clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size)
)

# Each set of parameter atoms is evaluated using the same x,
@@ -357,7 +358,7 @@ def _log_prob_proposal_posterior_atomic(
return log_prob_proposal_posterior

def _log_prob_proposal_posterior_mog(
self, theta: Tensor, x: Tensor, proposal: "DirectPosterior"
self, theta: Tensor, x: Tensor, proposal: DirectPosterior
) -> Tensor:
"""Return log-probability of the proposal posterior for MoG proposal.
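
Wrapping `clamp_and_warn` in `int(...)` pins `num_atoms` to a concrete int before it is used for batching and reshaping. A rough sketch of the idea with a stand-in helper, not sbi's actual `clamp_and_warn`:

from typing import Union

def clamp_and_warn_stub(name: str, value: int, min_val: int, max_val: int) -> Union[int, float]:
    # Stand-in with a deliberately loose return annotation, like a helper
    # whose signature does not promise an int.
    clamped = max(min_val, min(value, max_val))
    if clamped != value:
        print(f"{name} clamped from {value} to {clamped}")
    return clamped

# int(...) gives pyright (and readers) a concrete int to work with downstream.
num_atoms = int(clamp_and_warn_stub("num_atoms", 10, min_val=2, max_val=8))
print(num_atoms)  # 8
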
9 changes: 4 additions & 5 deletions sbi/inference/snre/snre_a.py
@@ -3,7 +3,6 @@
import torch
from torch import Tensor, nn, ones

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.snre.snre_base import RatioEstimator
from sbi.types import TensorboardSummaryWriter
from sbi.utils import del_entries
@@ -57,15 +56,15 @@ def train(
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: Optional[int] = None,
max_num_epochs: int = 2 ** 31 - 1,
clip_max_norm: Optional[float] = 5.0,
exclude_invalid_x: bool = True,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> NeuralPosterior:
) -> nn.Module:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Args:
@@ -75,8 +74,8 @@
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
