134 changes: 50 additions & 84 deletions src/integrators.py
@@ -1,7 +1,7 @@
from typing import Callable, Union, List, Tuple, Dict
from typing import Callable
import torch
from utils import RAvg
from maps import Map, Linear, CompositeMap
from maps import Linear
from base import Uniform
import gvar
import numpy as np
@@ -27,7 +27,9 @@ def __init__(
self.dtype = dtype
if maps:
if not self.dtype == maps.dtype:
raise ValueError("Float type of maps should be same as integrator.")
raise ValueError(
    "The dtype of the integrator must match the dtype of maps."
)
self.bounds = maps.bounds
else:
if not isinstance(bounds, (list, np.ndarray)):
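The tightened check above requires the map and the integrator to be constructed with the same dtype. A minimal sketch of a conforming setup; the class name MonteCarlo and the Linear constructor arguments are assumptions, since this diff does not show them:

import torch
from maps import Linear
from integrators import MonteCarlo  # assumed class name

bounds = [[0.0, 1.0], [0.0, 1.0]]
linear_map = Linear(bounds, dtype=torch.float64)  # assumed signature
mc = MonteCarlo(maps=linear_map, neval=10000, nbatch=100, dtype=torch.float64)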
@@ -74,50 +76,31 @@ def __init__(
):
super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)

def __call__(self, f: Callable, **kwargs):
x, _ = self.sample(1)
f_values = f(x)
if isinstance(f_values, (list, tuple)) and isinstance(
f_values[0], torch.Tensor
):
f_size = len(f_values)
type_fval = f_values[0].dtype
elif isinstance(f_values, torch.Tensor):
f_size = 1
type_fval = f_values.dtype
else:
raise TypeError(
"f must return a torch.Tensor or a list/tuple of torch.Tensor."
)
def __call__(self, f: Callable, f_dim: int = 1, **kwargs):
x, _ = self.sample(self.nbatch)
fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device)

epoch = self.neval // self.nbatch
values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
integ_values = torch.zeros(
(self.nbatch, f_dim), dtype=self.dtype, device=self.device
)

for iepoch in range(epoch):
for _ in range(epoch):
x, log_detJ = self.sample(self.nbatch)
f_values = f(x)
batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))

values += batch_results / epoch

results = np.array([RAvg() for _ in range(f_size)])
for i in range(f_size):
_mean = values[:, i].mean().item()
_var = values[:, i].var().item() / self.nbatch
f(x, fx)
fx.mul_(log_detJ.exp_().unsqueeze_(1))
integ_values += fx / epoch

results = np.array([RAvg() for _ in range(f_dim)])
for i in range(f_dim):
_mean = integ_values[:, i].mean().item()
_var = integ_values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)
if f_size == 1:
if f_dim == 1:
return results[0]
else:
return results

def _multiply_by_jacobian(self, values, jac):
# if isinstance(values, dict):
# return {k: v * torch.exp(log_det_J) for k, v in values.items()}
if isinstance(values, (list, tuple)):
return torch.stack([v * jac for v in values], dim=-1)
else:
return torch.stack([values * jac], dim=-1)
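With this rewrite the integrand no longer returns tensors to be inspected and stacked: it receives a preallocated buffer fx of shape (nbatch, f_dim) and writes each component into it in place, which is why _multiply_by_jacobian (removed above) and the per-call dtype probing are gone. A minimal integrand under the new convention; the function body is purely illustrative:

# Sketch: the integrand fills fx in place, one column per component.
def f(x, fx):
    fx[:, 0] = (x ** 2).sum(dim=1)
    fx[:, 1] = x.prod(dim=1)

results = mc(f, f_dim=2)  # array of two RAvg objects; f_dim=1 returns a single RAvg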


def random_walk(dim, bounds, device, dtype, u, **kwargs):
rangebounds = bounds[:, 1] - bounds[:, 0]
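random_walk (the rest of its body is elided from this diff) and uniform share the proposal signature used by the MCMC __call__, so a user-defined proposal_dist can be substituted. A hedged sketch of a Gaussian proposal with that signature; the step size and the wrap back into the domain are assumptions, not behavior taken from this PR:

# Hypothetical proposal with the random_walk/uniform signature.
def gaussian_proposal(dim, bounds, device, dtype, u, step_size=0.1, **kwargs):
    rangebounds = bounds[:, 1] - bounds[:, 0]
    step = step_size * rangebounds * torch.randn(u.shape, dtype=dtype, device=device)
    # Wrap the proposed point back into the integration domain.
    return (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]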
@@ -145,9 +128,9 @@ def __init__(
maps=None,
bounds=None,
q0=None,
neval=10000,
nbatch=None,
nburnin=500,
neval: int = 10000,
nbatch: int = None,
nburnin: int = 500,
device="cpu",
dtype=torch.float64,
):
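Apart from the added type hints, the constructor is unchanged: nburnin update sweeps are discarded before measurements begin. A construction sketch, again assuming the class name, which the diff does not show:

# Hypothetical construction using bounds instead of a trained map.
mcmc = MarkovChainMonteCarlo(
    bounds=[[0.0, 1.0]] * 2, neval=100000, nbatch=1000, nburnin=500
)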
@@ -160,102 +143,85 @@ def __init__(
def __call__(
self,
f: Callable,
f_dim: int = 1,
proposal_dist: Callable = uniform,
thinning=1,
mix_rate=0.5,
meas_freq: int = 1,
**kwargs,
):
epsilon = 1e-16 # Small value to ensure numerical stability
epoch = self.neval // self.nbatch
current_y, current_jac = self.q0.sample(self.nbatch)
current_x, detJ = self.maps.forward(current_y)
current_jac += detJ
current_jac = torch.exp(current_jac)
current_fval = f(current_x)
if isinstance(current_fval, (list, tuple)) and isinstance(
current_fval[0], torch.Tensor
):
f_size = len(current_fval)
current_fval = sum(current_fval)

def _integrand(x):
return sum(f(x))
elif isinstance(current_fval, torch.Tensor):
f_size = 1

def _integrand(x):
return f(x)
else:
raise TypeError(
"f must return a torch.Tensor or a list/tuple of torch.Tensor."
)
type_fval = current_fval.dtype
current_jac.exp_()
fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device)
fx_weight = torch.empty(self.nbatch, dtype=self.dtype, device=self.device)
fx_weight[:] = f(current_x, fx)
fx_weight.abs_()

current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs()
current_weight = mix_rate / current_jac + (1 - mix_rate) * fx_weight
current_weight.masked_fill_(current_weight < epsilon, epsilon)
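# The sampling weight mixes the flat density (1 / jacobian) with the
# integrand magnitude: w = mix_rate / J + (1 - mix_rate) * |f(x)|,
# clamped below by epsilon so 1 / w and the acceptance ratio stay finite.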

n_meas = epoch // thinning
n_meas = epoch // meas_freq

def one_step(current_y, current_x, current_weight, current_jac):
proposed_y = proposal_dist(
self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs
)
proposed_x, new_jac = self.maps.forward(proposed_y)
new_jac = torch.exp(new_jac)
new_jac.exp_()

new_fval = _integrand(proposed_x)
new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
fx_weight[:] = f(proposed_x, fx)
fx_weight.abs_()
new_weight = mix_rate / new_jac + (1 - mix_rate) * fx_weight
new_weight.masked_fill_(new_weight < epsilon, epsilon)

acceptance_probs = new_weight / current_weight * new_jac / current_jac
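# Metropolis acceptance for a symmetric proposal: the chain targets a
# density proportional to weight * jacobian in y-space, so the ratio is
# (w_new / w_old) * (J_new / J_old).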

accept = (
torch.rand(self.nbatch, dtype=self.dtype, device=self.device)
torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
<= acceptance_probs
)

current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
# current_fval = torch.where(accept, new_fval, current_fval)
current_x = torch.where(accept.unsqueeze(1), proposed_x, current_x)
current_weight = torch.where(accept, new_weight, current_weight)
current_jac = torch.where(accept, new_jac, current_jac)
return current_y, current_x, current_weight, current_jac

for i in range(self.nburnin):
for _ in range(self.nburnin):
current_y, current_x, current_weight, current_jac = one_step(
current_y, current_x, current_weight, current_jac
)

values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
refvalues = torch.zeros(self.nbatch, dtype=type_fval, device=self.device)
values = torch.zeros((self.nbatch, f_dim), dtype=self.dtype, device=self.device)
refvalues = torch.zeros(self.nbatch, dtype=self.dtype, device=self.device)

for imeas in range(n_meas):
for j in range(thinning):
for _ in range(n_meas):
for _ in range(meas_freq):
current_y, current_x, current_weight, current_jac = one_step(
current_y, current_x, current_weight, current_jac
)
f(current_x, fx)

batch_results = self._multiply_by_jacobian(
f(current_x), 1.0 / current_weight
)
batch_results_ref = 1 / (current_jac * current_weight)

values += batch_results / n_meas
refvalues += batch_results_ref / n_meas
fx.div_(current_weight.unsqueeze(1))
values += fx / n_meas
refvalues += 1 / (current_jac * current_weight) / n_meas

results = np.array([RAvg() for _ in range(f_size)])
results = np.array([RAvg() for _ in range(f_dim)])
results_ref = RAvg()

mean_ref = refvalues.mean().item()
var_ref = refvalues.var().item() / self.nbatch

results_ref.update(mean_ref, var_ref, self.neval)
for i in range(f_size):
for i in range(f_dim):
_mean = values[:, i].mean().item()
_var = values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)

if f_size == 1:
if f_dim == 1:
res = results[0] / results_ref * self._rangebounds.prod()
result = RAvg(itn_results=[res], sum_neval=self.neval)
return result
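Since samples are drawn from the mixed weight rather than the flat distribution, the raw column means must be divided by the reference estimate and rescaled by the domain volume, which is what the final ratio does. An end-to-end sketch under the assumptions above; the integrand returns a per-sample magnitude, following the fx_weight[:] = f(current_x, fx) convention seen earlier:

# Hypothetical usage of the updated MCMC API.
def f(x, fx):
    fx[:, 0] = (-x.pow(2).sum(dim=1)).exp()
    return fx[:, 0]  # per-sample magnitude used in the mixed weight

result = mcmc(f, f_dim=1, mix_rate=0.5, meas_freq=2)
print(result)  # RAvg with mean, error, and total neval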