src/base.py (+49, -0)
@@ -0,0 +1,49 @@
import torch
from torch import nn
import numpy as np


class BaseDistribution(nn.Module):
    """
    Base distribution of a flow-based model.
    Parameters do not depend on the target variable (as they would for a VAE encoder).
    """

    def __init__(self, bounds, device="cpu"):
        super().__init__()
        if isinstance(bounds, (list, np.ndarray)):
            self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
        else:
            raise ValueError("Unsupported bounds specification")
        self.dim = self.bounds.shape[0]
        self.device = device

    def sample(self, nsamples=1, **kwargs):
        """Samples from the base distribution.

        Args:
            nsamples: Number of samples to draw from the distribution

        Returns:
            Samples drawn from the distribution
        """
        raise NotImplementedError


class Uniform(BaseDistribution):
    """
    Multivariate uniform distribution over a rectangular box.
    """

    def __init__(self, bounds, device="cpu"):
        super().__init__(bounds, device)
        self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]

    def sample(self, nsamples=1, **kwargs):
        # Scale and shift unit-cube samples into the box; the log-Jacobian
        # is the log of the box volume, identical for every sample.
        u = (
            torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds
            + self.bounds[:, 0]
        )
        log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples)
        return u, log_detJ
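A minimal usage sketch of the new `Uniform` base distribution (assuming `src/` is on the import path): samples land inside the box, and the returned log-Jacobian is the log-volume of the box, constant across samples.

```python
import torch
from base import Uniform

# Two-dimensional box [0, 2] x [-1, 1]; its volume is 4.
q0 = Uniform([(0.0, 2.0), (-1.0, 1.0)])
u, log_detJ = q0.sample(5)

assert u.shape == (5, 2)
assert (u[:, 0] >= 0).all() and (u[:, 1] >= -1).all()
# log|det J| of the scale-and-shift from the unit cube is log(4) for each sample.
assert torch.allclose(log_detJ, torch.log(torch.tensor(4.0, dtype=torch.float64)))
```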
src/integrators.py (+77, -62)
@@ -1,8 +1,10 @@
from typing import Callable, Union, List, Tuple, Dict
import torch
from utils import RAvg
from maps import Map, Affine
from maps import Map, Affine, CompositeMap
from base import Uniform
import gvar
import numpy as np


class Integrator:
@@ -12,51 +14,60 @@ class Integrator:

    def __init__(
        self,
        # bounds: Union[List[Tuple[float, float]], np.ndarray],
        map,
        bounds: Union[List[Tuple[float, float]], np.ndarray],
        q0=None,
        maps=None,
        neval: int = 1000,
        batch_size: int = None,
        nbatch: int = None,
        device="cpu",
        # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        adapt=False,
    ):
        if not isinstance(map, Map):
            map = Affine(map)

        self.dim = map.dim
        self.map = map
        self.adapt = adapt
        self.dim = len(bounds)
        if not q0:
            q0 = Uniform(bounds, device=device)
        self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
        self.q0 = q0
        self.maps = maps
        self.neval = neval
        if batch_size is None:
            self.batch_size = neval
        if nbatch is None:
            self.nbatch = neval
            self.neval = neval
        else:
            self.batch_size = batch_size
            self.neval = -(-neval // batch_size) * batch_size
            self.nbatch = nbatch
            self.neval = -(-neval // nbatch) * nbatch

        self.device = device

    def __call__(self, f: Callable, **kwargs):
        raise NotImplementedError("Subclasses must implement this method")

    def sample(self, nsample, **kwargs):
        u, log_detJ = self.q0.sample(nsample)
        if not self.maps:
            return u, log_detJ
        else:
            u, log_detj = self.maps.forward(u)
            return u, log_detJ + log_detj


class MonteCarlo(Integrator):
    def __init__(
        self,
        map,
        bounds: Union[List[Tuple[float, float]], np.ndarray],
        q0=None,
        maps=None,
        nitn: int = 10,
        neval: int = 1000,
        batch_size: int = None,
        nbatch: int = None,
        device="cpu",
        adapt=False,
        alpha=0.5,
    ):
        super().__init__(map, neval, batch_size, device)
        self.adapt = adapt
        self.alpha = alpha
        super().__init__(bounds, q0, maps, neval, nbatch, device, adapt)
        self.nitn = nitn

    def __call__(self, f: Callable, **kwargs):
        u = torch.rand(self.batch_size, self.dim, device=self.device)
        x, _ = self.map.forward(u)
        x, _ = self.sample(self.nbatch)
        f_values = f(x)
        f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1
        type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
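The refactor routes all sampling through the new `Integrator.sample`: the base distribution supplies points plus its log-Jacobian, an optional map transforms them, and the two log-determinants add. A sketch of that composition, assuming `Affine` over `[(0, 1)] * dim` acts as the identity map (it is constructed exactly this way in `MCMC.__init__` below):

```python
import torch
from base import Uniform
from maps import Affine

q0 = Uniform([(0.0, 1.0), (0.0, 1.0)])
affine = Affine([(0.0, 1.0)] * 2)  # assumed identity on the unit square

u, log_detJ = q0.sample(4)       # u ~ q0, with log|det dU/dZ|
x, log_detj = affine.forward(u)  # x = T(u), with log|det dX/dU|
total = log_detJ + log_detj      # log-Jacobians add under composition
```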
@@ -66,31 +77,23 @@ def __call__(self, f: Callable, **kwargs):
        # var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device)

        result = RAvg(weighted=self.adapt)
        epoch = self.neval // self.batch_size
        epoch = self.neval // self.nbatch

        for itn in range(self.nitn):
            mean[:] = 0
            var[:] = 0
            for _ in range(epoch):
                y = torch.rand(
                    self.batch_size, self.dim, dtype=torch.float64, device=self.device
                )
                x, jac = self.map.forward(y)

                x, log_detJ = self.sample(self.nbatch)
                f_values = f(x)
                batch_results = self._multiply_by_jacobian(f_values, jac)
                batch_results = self._multiply_by_jacobian(
                    f_values, torch.exp(log_detJ)
                )

                mean += torch.mean(batch_results, dim=-1) / epoch
                var += torch.var(batch_results, dim=-1) / (self.neval * epoch)

                if self.adapt:
                    self.map.add_training_data(y, batch_results**2)

            result.sum_neval += self.neval
            result.add(gvar.gvar(mean.item(), (var**0.5).item()))
            if self.adapt:
                self.map.adapt(alpha=self.alpha)

        return result

    def _multiply_by_jacobian(self, values, jac):
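Each batch now multiplies `f(x)` by `exp(log_detJ)`, so the batch mean estimates the integral directly. A hypothetical end-to-end use of the new signature (`RAvg` is assumed to behave like `vegas`'s running average, with `mean` and `sdev`):

```python
import torch
from integrators import MonteCarlo

def f(x):
    # Integrand on a (nbatch, dim) tensor, returning shape (nbatch,).
    return torch.exp(-((x - 0.5) ** 2).sum(dim=-1))

mc = MonteCarlo([(0.0, 1.0)] * 2, nitn=10, neval=100_000)
result = mc(f)  # weighted running average over nitn iterations
print(result.mean, result.sdev)
```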
@@ -105,37 +108,46 @@ def _multiply_by_jacobian(self, values, jac):
class MCMC(MonteCarlo):
    def __init__(
        self,
        map: Map,
        bounds: Union[List[Tuple[float, float]], np.ndarray],
        q0=None,
        maps=None,
        nitn: int = 10,
        neval=10000,
        batch_size=None,
        n_burnin=500,
        nbatch=None,
        nburnin=500,
        device="cpu",
        adapt=False,
        alpha=0.5,
    ):
        super().__init__(map, nitn, neval, batch_size, device, adapt, alpha)
        self.n_burnin = n_burnin
        super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt)
        self.nburnin = nburnin
        if maps is None:
            self.maps = Affine([(0, 1)] * self.dim, device=device)
        self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]

    def __call__(
        self,
        f: Callable,
        proposal_dist="global_uniform",
        proposal_dist="uniform",
        thinning=1,
        mix_rate=0.0,
        **kwargs,
    ):
        epsilon = 1e-16  # Small value to ensure numerical stability
        vars_shape = (self.batch_size, self.dim)
        current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device)
        current_x, current_jac = self.map.forward(current_y)
        # vars_shape = (self.nbatch, self.dim)
        current_y, current_jac = self.q0.sample(self.nbatch)
        # if self.maps:
        current_x, detJ = self.maps.forward(current_y)
        current_jac += detJ
        # else:
        #     current_x = current_y
        current_jac = torch.exp(current_jac)
        current_fval = f(current_x)
        current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs()
        current_weight.masked_fill_(current_weight < epsilon, epsilon)
        # current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon)

        proposed_y = torch.empty_like(current_y)
        proposed_x = torch.empty_like(proposed_y)
        proposed_x = torch.empty_like(current_x)
        new_fval = torch.empty_like(current_fval)
        new_weight = torch.empty_like(current_weight)
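The chain's stationary weight mixes the flat (map-induced) density with the integrand: `w(x) = mix_rate / J(x) + (1 - mix_rate) * |f(x)|`, with `J = exp(log_detJ)`; `mix_rate=0` samples proportionally to `|f|`, while `mix_rate=1` reduces to the flat distribution. A standalone sketch of the weight and the Metropolis ratio used below:

```python
import torch

def weight(fval: torch.Tensor, jac: torch.Tensor, mix_rate: float) -> torch.Tensor:
    # Mixture of the flat measure (1/J) and the integrand magnitude |f|.
    return mix_rate / jac + (1.0 - mix_rate) * fval.abs()

def acceptance(new_w, new_jac, cur_w, cur_jac):
    # Ratio w' J' / (w J), as formed in the update loop below.
    return new_w / cur_w * new_jac / cur_jac
```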

@@ -149,20 +161,21 @@ def __call__(
        result = RAvg(weighted=self.adapt)
        result_ref = RAvg(weighted=self.adapt)

        epoch = self.neval // self.batch_size
        epoch = self.neval // self.nbatch
        n_meas = 0
        for itn in range(self.nitn):
            for i in range(epoch):
                proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs)
                proposed_x[:], new_jac = self.map.forward(proposed_y)
                proposed_x[:], new_jac = self.maps.forward(proposed_y)
                new_jac = torch.exp(new_jac)

                new_fval[:] = f(proposed_x)
                new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()

                acceptance_probs = new_weight / current_weight * new_jac / current_jac

                accept = (
                    torch.rand(self.batch_size, dtype=torch.float64, device=self.device)
                    torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
                    <= acceptance_probs
                )

@@ -171,7 +184,7 @@ def __call__(
                current_weight = torch.where(accept, new_weight, current_weight)
                current_jac = torch.where(accept, new_jac, current_jac)

                if i < self.n_burnin and (self.adapt or itn == 0):
                if i < self.nburnin and itn == 0:
                    continue
                elif i % thinning == 0:
                    n_meas += 1
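Acceptance is vectorized over the batch of chains: each chain draws its own uniform variate and `torch.where` keeps or replaces its state independently, so no Python loop over walkers is needed; measurements are skipped for the first `nburnin` steps of the first iteration and then thinned by `thinning`. A toy sketch of the per-chain select:

```python
import torch

nbatch = 4
acceptance_probs = torch.tensor([1.5, 0.2, 0.9, 0.05], dtype=torch.float64)
accept = torch.rand(nbatch, dtype=torch.float64) <= acceptance_probs

current = torch.zeros(nbatch, dtype=torch.float64)
proposed = torch.ones(nbatch, dtype=torch.float64)
current = torch.where(accept, proposed, current)  # chains update independently
```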
@@ -184,28 +197,30 @@
                    mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
                    var_ref += torch.var(batch_results_ref, dim=-1) / epoch

            if self.adapt:
                self.map.add_training_data(
                    current_y, (current_fval * current_jac) ** 2
                )
            result.sum_neval += self.neval
            result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
            result_ref.sum_neval += self.batch_size
            result_ref.sum_neval += self.nbatch
            result_ref.add(
                gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())
            )

            if self.adapt:
                self.map.adapt(alpha=self.alpha)

        return result / result_ref
        return result / result_ref * self._rangebounds.prod()

    def _propose(self, u, proposal_dist, **kwargs):
        if proposal_dist == "random_walk":
            step_size = kwargs.get("step_size", 0.2)
            return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
        elif proposal_dist == "global_uniform":
            return torch.rand_like(u)
            step_sizes = self._rangebounds * step_size
            step = (
                torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
            )
            new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
                :, 0
            ]
            return new_u
            # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
        elif proposal_dist == "uniform":
            # return torch.rand_like(u)
            return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
        # elif proposal_dist == "gaussian":
        #     mean = kwargs.get("mean", torch.zeros_like(u))
        #     std = kwargs.get("std", torch.ones_like(u))
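The random-walk proposal now works directly in the integration box: a step uniform in `±step_size * (box widths)` is added and wrapped periodically back into the bounds (note that one step vector is shared by all chains in a call, mirroring the code above), and the final estimate is returned as the ratio `result / result_ref` rescaled by the box volume `self._rangebounds.prod()`. A standalone sketch of the wrapped proposal over a hypothetical 2-d box:

```python
import torch

bounds = torch.tensor([[0.0, 2.0], [-1.0, 1.0]], dtype=torch.float64)
ranges = bounds[:, 1] - bounds[:, 0]

def propose_random_walk(u: torch.Tensor, step_size: float = 0.2) -> torch.Tensor:
    # One step vector for the whole batch, uniform in +-(step_size * widths),
    # then wrapped periodically back into the box.
    step = torch.empty(u.shape[-1], dtype=u.dtype).uniform_(-1, 1) * ranges * step_size
    return (u + step - bounds[:, 0]) % ranges + bounds[:, 0]

u = torch.rand((4, 2), dtype=torch.float64) * ranges + bounds[:, 0]
u_new = propose_random_walk(u)
assert (u_new >= bounds[:, 0]).all() and (u_new <= bounds[:, 1]).all()
```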