12 changes: 7 additions & 5 deletions src/base.py
@@ -9,11 +9,12 @@ class BaseDistribution(nn.Module):
Parameters do not depend on the target variable (as is the case for a VAE encoder)
"""

def __init__(self, bounds, device="cpu"):
def __init__(self, bounds, device="cpu", dtype=torch.float64):
super().__init__()
self.dtype = dtype
# self.bounds = bounds
if isinstance(bounds, (list, np.ndarray)):
self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
else:
raise ValueError("Unsupported map specification")
self.dim = self.bounds.shape[0]
@@ -36,13 +37,14 @@ class Uniform(BaseDistribution):
Multivariate uniform distribution
"""

def __init__(self, bounds, device="cpu"):
super().__init__(bounds, device)
def __init__(self, bounds, device="cpu", dtype=torch.float64):
super().__init__(bounds, device, dtype)
self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]

def sample(self, nsamples=1, **kwargs):
u = (
torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds
torch.rand((nsamples, self.dim), device=self.device, dtype=self.dtype)
* self._rangebounds
+ self.bounds[:, 0]
)
log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples)
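For context, a minimal usage sketch of the new dtype argument (not part of this diff; it assumes src/ is importable and that sample returns the draws together with log_detJ, as the hunk above suggests):

import torch
from base import Uniform

# Draw from a 2D uniform distribution in float32 rather than the
# default float64; bounds are one (low, high) pair per dimension.
q0 = Uniform([(0.0, 1.0), (-1.0, 1.0)], device="cpu", dtype=torch.float32)
u, log_detJ = q0.sample(nsamples=4)
assert u.dtype == torch.float32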
234 changes: 132 additions & 102 deletions src/integrators.py
@@ -1,33 +1,40 @@
from typing import Callable, Union, List, Tuple, Dict
import torch
from utils import RAvg
from maps import Map, Affine, CompositeMap
from maps import Map, Linear, CompositeMap
from base import Uniform
import gvar
import numpy as np


class Integrator:
"""
Base class for all integrators.
Base class for all integrators. This class is designed to handle integration tasks
over a specified domain (bounds) using a sampling method (q0) and optional
transformation maps.
"""

def __init__(
self,
bounds: Union[List[Tuple[float, float]], np.ndarray],
bounds,
q0=None,
maps=None,
neval: int = 1000,
nbatch: int = None,
device="cpu",
adapt=False,
dtype=torch.float64,
):
self.adapt = adapt
if not isinstance(bounds, (list, np.ndarray)):
raise TypeError("bounds must be a list or a NumPy array.")
self.dtype = dtype
self.dim = len(bounds)
if not q0:
q0 = Uniform(bounds, device=device)
self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
q0 = Uniform(bounds, device=device, dtype=dtype)
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
self.q0 = q0
if maps:
if not self.dtype == maps.dtype:
raise ValueError("Float type of maps should be same as integrator.")
self.maps = maps
self.neval = neval
if nbatch is None:
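A sketch of the constructor-time validation added in this hunk (not part of this diff; the import path is assumed from this PR's src/ layout):

import torch
from integrators import Integrator

bounds = [(0.0, 1.0)] * 3

# Valid: list-of-tuples bounds, default dtype torch.float64.
integ = Integrator(bounds, neval=1000)

# Invalid: tensors are rejected by the new isinstance check.
try:
    Integrator(torch.tensor(bounds))
except TypeError as err:
    print(err)  # "bounds must be a list or a NumPy array."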
@@ -54,46 +61,41 @@ def sample(self, nsample, **kwargs):
class MonteCarlo(Integrator):
def __init__(
self,
bounds: Union[List[Tuple[float, float]], np.ndarray],
bounds,
q0=None,
maps=None,
nitn: int = 10,
neval: int = 1000,
nbatch: int = None,
device="cpu",
adapt=False,
dtype=torch.float64,
):
super().__init__(bounds, q0, maps, neval, nbatch, device, adapt)
self.nitn = nitn
super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)

def __call__(self, f: Callable, **kwargs):
x, _ = self.sample(self.nbatch)
f_values = f(x)
f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1
type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)

mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
var = torch.zeros(f_size, dtype=type_fval, device=self.device)
# type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
# mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
# var = torch.zeros(f_size, dtype=type_fval, device=self.device)
# var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device)

result = RAvg(weighted=self.adapt)
mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
result = RAvg()
epoch = self.neval // self.nbatch

for itn in range(self.nitn):
mean[:] = 0
var[:] = 0
for _ in range(epoch):
x, log_detJ = self.sample(self.nbatch)
f_values = f(x)
batch_results = self._multiply_by_jacobian(
f_values, torch.exp(log_detJ)
)
mean[:] = 0
var[:] = 0
for _ in range(epoch):
x, log_detJ = self.sample(self.nbatch)
f_values = f(x)
batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))

mean += torch.mean(batch_results, dim=-1) / epoch
var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
mean += torch.mean(batch_results, dim=-1) / epoch
var += torch.var(batch_results, dim=-1) / (self.neval * epoch)

result.sum_neval += self.neval
result.add(gvar.gvar(mean.item(), (var**0.5).item()))
result.sum_neval += self.neval
result.add(gvar.gvar(mean.item(), (var**0.5).item()))
return result

def _multiply_by_jacobian(self, values, jac):
@@ -105,29 +107,48 @@ def _multiply_by_jacobian(self, values, jac):
return values * jac
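A hedged end-to-end sketch of the refactored MonteCarlo above, which now makes a single pass over epoch = neval // nbatch batches and accumulates the result in an RAvg (not part of this diff):

from integrators import MonteCarlo

# Estimate the integral of x0^2 + x1^2 over the unit square (exact
# value 2/3); f receives a batch of points of shape (nbatch, dim).
mc = MonteCarlo([(0.0, 1.0), (0.0, 1.0)], neval=100000)
result = mc(lambda x: (x ** 2).sum(dim=1))
print(result)  # an RAvg/gvar result with mean near 0.6667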


def random_walk(dim, bounds, device, dtype, u, **kwargs):
rangebounds = bounds[:, 1] - bounds[:, 0]
step_size = kwargs.get("step_size", 0.2)
step_sizes = rangebounds * step_size
step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]
return new_u


def uniform(dim, bounds, device, dtype, u, **kwargs):
rangebounds = bounds[:, 1] - bounds[:, 0]
return torch.rand_like(u) * rangebounds + bounds[:, 0]


def gaussian(dim, bounds, device, dtype, u, **kwargs):
mean = kwargs.get("mean", torch.zeros_like(u))
std = kwargs.get("std", torch.ones_like(u))
return torch.normal(mean, std)

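With this change the proposal distribution is a plain callable with signature (dim, bounds, device, dtype, u, **kwargs) returning a batch of proposed points, as the three functions above show. A hypothetical custom proposal in the same style (not part of this diff):

import torch

def gaussian_walk(dim, bounds, device, dtype, u, **kwargs):
    # Gaussian steps scaled per dimension, wrapped back into the box
    # the same way random_walk above wraps its uniform steps.
    std = kwargs.get("std", 0.1)
    rangebounds = bounds[:, 1] - bounds[:, 0]
    step = torch.randn_like(u) * std * rangebounds
    return (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]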

class MCMC(MonteCarlo):
def __init__(
self,
bounds: Union[List[Tuple[float, float]], np.ndarray],
bounds,
q0=None,
maps=None,
nitn: int = 10,
neval=10000,
nbatch=None,
nburnin=500,
device="cpu",
adapt=False,
dtype=torch.float64,
):
super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt)
super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)
self.nburnin = nburnin
if maps is None:
self.maps = Affine([(0, 1)] * self.dim, device=device)
self.maps = Linear([(0, 1)] * self.dim, device=device)
self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]

def __call__(
self,
f: Callable,
proposal_dist="uniform",
proposal_dist: Callable = uniform,
thinning=1,
mix_rate=0.0,
**kwargs,
@@ -146,84 +167,93 @@ def __call__(
current_weight.masked_fill_(current_weight < epsilon, epsilon)
# current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon)

proposed_y = torch.empty_like(current_y)
proposed_x = torch.empty_like(current_x)
new_fval = torch.empty_like(current_fval)
new_weight = torch.empty_like(current_weight)
# proposed_y = torch.empty_like(current_y)
# proposed_x = torch.empty_like(current_x)
# new_fval = torch.empty_like(current_fval)
# new_weight = torch.empty_like(current_weight)

f_size = len(current_fval) if isinstance(current_fval, (list, tuple)) else 1
type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
# type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
# mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
mean_ref = torch.zeros_like(mean)
var = torch.zeros(f_size, dtype=type_fval, device=self.device)
# var = torch.zeros(f_size, dtype=type_fval, device=self.device)
var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
var_ref = torch.zeros_like(mean)

result = RAvg(weighted=self.adapt)
result_ref = RAvg(weighted=self.adapt)
result = RAvg()
result_ref = RAvg()

epoch = self.neval // self.nbatch
n_meas = 0
for itn in range(self.nitn):
for i in range(epoch):
proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs)
proposed_x[:], new_jac = self.maps.forward(proposed_y)
new_jac = torch.exp(new_jac)

new_fval[:] = f(proposed_x)
new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
def _propose(current_y, current_fval, current_weight, current_jac):
proposed_y = proposal_dist(
self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs
)
proposed_x, new_jac = self.maps.forward(proposed_y)
new_jac = torch.exp(new_jac)

new_fval = f(proposed_x)
new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()

acceptance_probs = new_weight / current_weight * new_jac / current_jac
acceptance_probs = new_weight / current_weight * new_jac / current_jac

accept = (
torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
<= acceptance_probs
)
accept = (
torch.rand(self.nbatch, dtype=self.dtype, device=self.device)
<= acceptance_probs
)

current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
current_fval = torch.where(accept, new_fval, current_fval)
current_weight = torch.where(accept, new_weight, current_weight)
current_jac = torch.where(accept, new_jac, current_jac)

if i < self.nburnin and itn == 0:
continue
elif i % thinning == 0:
n_meas += 1
batch_results = current_fval / current_weight

mean += torch.mean(batch_results, dim=-1) / epoch
var += torch.var(batch_results, dim=-1) / epoch

batch_results_ref = 1 / (current_jac * current_weight)
mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
var_ref += torch.var(batch_results_ref, dim=-1) / epoch

result.sum_neval += self.neval
result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
result_ref.sum_neval += self.nbatch
result_ref.add(
gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())
current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
current_fval = torch.where(accept, new_fval, current_fval)
current_weight = torch.where(accept, new_weight, current_weight)
current_jac = torch.where(accept, new_jac, current_jac)
return current_y, current_fval, current_weight, current_jac

for i in range(self.nburnin):
current_y, current_fval, current_weight, current_jac = _propose(
current_y, current_fval, current_weight, current_jac
)
for i in range(epoch // thinning):
for j in range(thinning):
current_y, current_fval, current_weight, current_jac = _propose(
current_y, current_fval, current_weight, current_jac
)
n_meas += 1
batch_results = current_fval / current_weight

mean += torch.mean(batch_results, dim=-1) / epoch
var += torch.var(batch_results, dim=-1) / epoch

batch_results_ref = 1 / (current_jac * current_weight)
mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
var_ref += torch.var(batch_results_ref, dim=-1) / epoch

result.sum_neval += self.neval
result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
result_ref.sum_neval += self.nbatch
result_ref.add(gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()))

return result / result_ref * self._rangebounds.prod()
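The returned ratio is a self-normalizing estimator: the chain samples with density proportional to weight * jacobian, so result estimates the target integral and result_ref the reference volume, each up to the same unknown normalization, which the division cancels before the bounds volume is restored. A usage sketch of the new callable-based interface (not part of this diff):

from integrators import MCMC, random_walk

# Same integral as in the MonteCarlo sketch above; proposal_dist is
# now a callable rather than a string, and step_size is forwarded to
# it through **kwargs.
mcmc = MCMC([(0.0, 1.0), (0.0, 1.0)], neval=100000, nburnin=500)
result = mcmc(lambda x: (x ** 2).sum(dim=1),
              proposal_dist=random_walk, step_size=0.2, mix_rate=0.5)
print(result)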

def _propose(self, u, proposal_dist, **kwargs):
if proposal_dist == "random_walk":
step_size = kwargs.get("step_size", 0.2)
step_sizes = self._rangebounds * step_size
step = (
torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
)
new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
:, 0
]
return new_u
# return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
elif proposal_dist == "uniform":
# return torch.rand_like(u)
return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
# elif proposal_dist == "gaussian":
# mean = kwargs.get("mean", torch.zeros_like(u))
# std = kwargs.get("std", torch.ones_like(u))
# return torch.normal(mean, std)
else:
raise ValueError(f"Unknown proposal distribution: {proposal_dist}")
# def _propose(self, u, proposal_dist, **kwargs):
# if proposal_dist == "random_walk":
# step_size = kwargs.get("step_size", 0.2)
# step_sizes = self._rangebounds * step_size
# step = (
# torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
# )
# new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
# :, 0
# ]
# return new_u
# # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
# elif proposal_dist == "uniform":
# # return torch.rand_like(u)
# return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
# # elif proposal_dist == "gaussian":
# # mean = kwargs.get("mean", torch.zeros_like(u))
# # std = kwargs.get("std", torch.ones_like(u))
# # return torch.normal(mean, std)
# else:
# raise ValueError(f"Unknown proposal distribution: {proposal_dist}")