diff --git a/src/base.py b/src/base.py
new file mode 100644
index 0000000..d6aa62c
--- /dev/null
+++ b/src/base.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+import numpy as np
+
+
+class BaseDistribution(nn.Module):
+    """
+    Base distribution of a flow-based model.
+    Parameters do not depend on the target variable (as they would for a VAE encoder).
+    """
+
+    def __init__(self, bounds, device="cpu"):
+        super().__init__()
+        # self.bounds = bounds
+        if isinstance(bounds, (list, np.ndarray)):
+            self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
+        else:
+            raise ValueError("Unsupported bounds specification")
+        self.dim = self.bounds.shape[0]
+        self.device = device
+
+    def sample(self, nsamples=1, **kwargs):
+        """Samples from the base distribution.
+
+        Args:
+            nsamples: Number of samples to draw from the distribution.
+
+        Returns:
+            Samples drawn from the distribution.
+        """
+        raise NotImplementedError
+
+
+class Uniform(BaseDistribution):
+    """
+    Multivariate uniform distribution.
+    """
+
+    def __init__(self, bounds, device="cpu"):
+        super().__init__(bounds, device)
+        self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]
+
+    def sample(self, nsamples=1, **kwargs):
+        u = (
+            torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds
+            + self.bounds[:, 0]
+        )
+        log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples)
+        return u, log_detJ
diff --git a/src/integrators.py b/src/integrators.py
index 7084970..65fce77 100644
--- a/src/integrators.py
+++ b/src/integrators.py
@@ -1,8 +1,10 @@
 from typing import Callable, Union, List, Tuple, Dict
 import torch
 from utils import RAvg
-from maps import Map, Affine
+from maps import Map, Affine, CompositeMap
+from base import Uniform
 import gvar
+import numpy as np
 
 
 class Integrator:
@@ -12,51 +14,60 @@ class Integrator:
 
     def __init__(
         self,
-        # bounds: Union[List[Tuple[float, float]], np.ndarray],
-        map,
+        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        q0=None,
+        maps=None,
         neval: int = 1000,
-        batch_size: int = None,
+        nbatch: int = None,
         device="cpu",
-        # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        adapt=False,
    ):
-        if not isinstance(map, Map):
-            map = Affine(map)
-
-        self.dim = map.dim
-        self.map = map
+        self.adapt = adapt
+        self.dim = len(bounds)
+        if not q0:
+            q0 = Uniform(bounds, device=device)
+        self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
+        self.q0 = q0
+        self.maps = maps
         self.neval = neval
-        if batch_size is None:
-            self.batch_size = neval
+        if nbatch is None:
+            self.nbatch = neval
             self.neval = neval
         else:
-            self.batch_size = batch_size
-            self.neval = -(-neval // batch_size) * batch_size
+            self.nbatch = nbatch
+            self.neval = -(-neval // nbatch) * nbatch
         self.device = device
 
     def __call__(self, f: Callable, **kwargs):
         raise NotImplementedError("Subclasses must implement this method")
 
+    def sample(self, nsample, **kwargs):
+        u, log_detJ = self.q0.sample(nsample)
+        if not self.maps:
+            return u, log_detJ
+        else:
+            u, log_detj = self.maps.forward(u)
+            return u, log_detJ + log_detj
+
 
 class MonteCarlo(Integrator):
     def __init__(
         self,
-        map,
+        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        q0=None,
+        maps=None,
         nitn: int = 10,
         neval: int = 1000,
-        batch_size: int = None,
+        nbatch: int = None,
         device="cpu",
         adapt=False,
-        alpha=0.5,
    ):
-        super().__init__(map, neval, batch_size, device)
-        self.adapt = adapt
-        self.alpha = alpha
+        super().__init__(bounds, q0, maps, neval, nbatch, device, adapt)
         self.nitn = nitn
 
     def __call__(self, f: Callable, **kwargs):
-        u = torch.rand(self.batch_size, self.dim, device=self.device)
-        x, _ = self.map.forward(u)
+        x, _ = self.sample(self.nbatch)
         f_values = f(x)
         f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1
         type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
@@ -66,31 +77,23 @@ def __call__(self, f: Callable, **kwargs):
         # var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device)
 
         result = RAvg(weighted=self.adapt)
-        epoch = self.neval // self.batch_size
+        epoch = self.neval // self.nbatch
 
         for itn in range(self.nitn):
             mean[:] = 0
             var[:] = 0
             for _ in range(epoch):
-                y = torch.rand(
-                    self.batch_size, self.dim, dtype=torch.float64, device=self.device
-                )
-                x, jac = self.map.forward(y)
-
+                x, log_detJ = self.sample(self.nbatch)
                 f_values = f(x)
-                batch_results = self._multiply_by_jacobian(f_values, jac)
+                batch_results = self._multiply_by_jacobian(
+                    f_values, torch.exp(log_detJ)
+                )
 
                 mean += torch.mean(batch_results, dim=-1) / epoch
                 var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
 
-                if self.adapt:
-                    self.map.add_training_data(y, batch_results**2)
-
             result.sum_neval += self.neval
             result.add(gvar.gvar(mean.item(), (var**0.5).item()))
 
-            if self.adapt:
-                self.map.adapt(alpha=self.alpha)
-
         return result
 
     def _multiply_by_jacobian(self, values, jac):
@@ -105,37 +108,46 @@ class MCMC(MonteCarlo):
     def __init__(
         self,
-        map: Map,
+        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        q0=None,
+        maps=None,
         nitn: int = 10,
         neval=10000,
-        batch_size=None,
-        n_burnin=500,
+        nbatch=None,
+        nburnin=500,
         device="cpu",
         adapt=False,
-        alpha=0.5,
    ):
-        super().__init__(map, nitn, neval, batch_size, device, adapt, alpha)
-        self.n_burnin = n_burnin
+        super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt)
+        self.nburnin = nburnin
+        if maps is None:
+            self.maps = Affine([(0, 1)] * self.dim, device=device)
+        self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]
 
     def __call__(
         self,
         f: Callable,
-        proposal_dist="global_uniform",
+        proposal_dist="uniform",
         thinning=1,
         mix_rate=0.0,
         **kwargs,
    ):
         epsilon = 1e-16  # Small value to ensure numerical stability
-        vars_shape = (self.batch_size, self.dim)
-        current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device)
-        current_x, current_jac = self.map.forward(current_y)
+        # vars_shape = (self.nbatch, self.dim)
+        current_y, current_jac = self.q0.sample(self.nbatch)
+        # if self.maps:
+        current_x, detJ = self.maps.forward(current_y)
+        current_jac += detJ
+        # else:
+        #     current_x = current_y
+        current_jac = torch.exp(current_jac)
         current_fval = f(current_x)
         current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs()
         current_weight.masked_fill_(current_weight < epsilon, epsilon)
         # current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon)
 
         proposed_y = torch.empty_like(current_y)
-        proposed_x = torch.empty_like(proposed_y)
+        proposed_x = torch.empty_like(current_x)
         new_fval = torch.empty_like(current_fval)
         new_weight = torch.empty_like(current_weight)
@@ -149,12 +161,13 @@ def __call__(
         result = RAvg(weighted=self.adapt)
         result_ref = RAvg(weighted=self.adapt)
-        epoch = self.neval // self.batch_size
+        epoch = self.neval // self.nbatch
         n_meas = 0
         for itn in range(self.nitn):
             for i in range(epoch):
                 proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs)
-                proposed_x[:], new_jac = self.map.forward(proposed_y)
+                proposed_x[:], new_jac = self.maps.forward(proposed_y)
+                new_jac = torch.exp(new_jac)
                 new_fval[:] = f(proposed_x)
                 new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
@@ -162,7 +175,7 @@ def __call__(
 
                 acceptance_probs = new_weight / current_weight * new_jac / current_jac
                 accept = (
-                    torch.rand(self.batch_size, dtype=torch.float64, device=self.device)
+                    torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
                     <= acceptance_probs
                 )
 
@@ -171,7 +184,7 @@ def __call__(
                 current_weight = torch.where(accept, new_weight, current_weight)
                 current_jac = torch.where(accept, new_jac, current_jac)
 
-                if i < self.n_burnin and (self.adapt or itn == 0):
+                if i < self.nburnin and itn == 0:
                     continue
                 elif i % thinning == 0:
                     n_meas += 1
@@ -184,28 +197,30 @@ def __call__(
                     mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
                     var_ref += torch.var(batch_results_ref, dim=-1) / epoch
 
-                    if self.adapt:
-                        self.map.add_training_data(
-                            current_y, (current_fval * current_jac) ** 2
-                        )
-
             result.sum_neval += self.neval
             result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
-            result_ref.sum_neval += self.batch_size
+            result_ref.sum_neval += self.nbatch
             result_ref.add(
                 gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())
             )
 
-            if self.adapt:
-                self.map.adapt(alpha=self.alpha)
-
-        return result / result_ref
+        return result / result_ref * self._rangebounds.prod()
 
     def _propose(self, u, proposal_dist, **kwargs):
         if proposal_dist == "random_walk":
             step_size = kwargs.get("step_size", 0.2)
-            return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
-        elif proposal_dist == "global_uniform":
-            return torch.rand_like(u)
+            step_sizes = self._rangebounds * step_size
+            step = (
+                torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
+            )
+            new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
+                :, 0
+            ]
+            return new_u
+            # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
+        elif proposal_dist == "uniform":
+            # return torch.rand_like(u)
+            return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
         # elif proposal_dist == "gaussian":
         #     mean = kwargs.get("mean", torch.zeros_like(u))
         #     std = kwargs.get("std", torch.ones_like(u))
diff --git a/src/maps.py b/src/maps.py
index 9ab2881..ef4a884 100644
--- a/src/maps.py
+++ b/src/maps.py
@@ -1,63 +1,67 @@
 import torch
 import numpy as np
+from torch import nn
 
 
-class Map:
-    def __init__(self, map_spec, device="cpu"):
-        # if isinstance(map_spec, dict):
-        #     self.map_spec = {
-        #         k: torch.tensor(v, device=device) for k, v in map_spec.items()
-        #     }
-        if isinstance(map_spec, (list, np.ndarray)):
-            self.map_spec = torch.tensor(map_spec, dtype=torch.float64, device=device)
+class Map(nn.Module):
+    def __init__(self, bounds, device="cpu"):
+        super().__init__()
+        if isinstance(bounds, (list, np.ndarray)):
+            self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
         else:
             raise ValueError("Unsupported map specification")
-        self.dim = self.map_spec.shape[0]
+        self.dim = self.bounds.shape[0]
         self.device = device
 
-    def forward(self, y):
+    def forward(self, u):
         raise NotImplementedError("Subclasses must implement this method")
 
     def inverse(self, x):
         raise NotImplementedError("Subclasses must implement this method")
 
-    def log_det_jacobian(self, y):
-        raise NotImplementedError("Subclasses must implement this method")
-
 
-class Affine(Map):
-    def __init__(self, map_spec, device="cpu"):
-        super().__init__(map_spec, device)
-        self._A = self.map_spec[:, 1] - self.map_spec[:, 0]
-        self._jac1 = torch.prod(self._A)
+class CompositeMap(Map):
+    def __init__(self, maps, device="cpu"):
+        if not maps:
+            raise ValueError("Maps cannot be empty.")
+        super().__init__(maps[-1].bounds, device)
+        self.maps = maps
 
-    def forward(self, y):
-        return y * self._A + self.map_spec[:, 0], self._jac1.repeat(y.shape[0])
+    def forward(self, u):
+        log_detJ = torch.zeros(len(u), device=u.device)
+        for map in self.maps:
+            u, log_detj = map.forward(u)
+            log_detJ += log_detj
+        return u, log_detJ
 
     def inverse(self, x):
-        return (x - self.map_spec[:, 0]) / self._A, self._jac1.repeat(x.shape[0])
+        log_detJ = torch.zeros(len(x), device=x.device)
+        for i in range(len(self.maps) - 1, -1, -1):
+            x, log_detj = self.maps[i].inverse(x)
+            log_detJ += log_detj
+        return x, log_detJ
 
-    def log_det_jacobian(self, y):
-        return torch.log(self._jac1) * y.shape[0]
 
+class Affine(Map):
+    def __init__(self, bounds, device="cpu"):
+        super().__init__(bounds, device)
+        self._A = self.bounds[:, 1] - self.bounds[:, 0]
+        self._jac1 = torch.prod(self._A)
 
-class AdaptiveMap(Map):
-    def __init__(self, map_spec, alpha=0.5, device="cpu"):
-        super().__init__(map_spec, device)
-        self.alpha = alpha
-
-    def add_training_data(self, y, f):
-        pass
+    def forward(self, u):
+        return u * self._A + self.bounds[:, 0], torch.log(self._jac1.repeat(u.shape[0]))
 
-    def adapt(self, alpha=0.0):
-        pass
+    def inverse(self, x):
+        return (x - self.bounds[:, 0]) / self._A, torch.log(
+            self._jac1.repeat(x.shape[0])
+        )
 
 
-class Vegas(AdaptiveMap):
-    def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"):
-        super().__init__(map_spec, alpha, device)
+class Vegas(Map):
+    def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu"):
+        super().__init__(bounds, device)
         # self.nbin = nbin
-
+        self.alpha = alpha
         if isinstance(ninc, int):
             self.ninc = torch.ones(self.dim, dtype=torch.int32, device=device) * ninc
         else:
@@ -72,8 +76,8 @@ def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"):
 
         for d in range(self.dim):
             self.grid[d, : self.ninc[d] + 1] = torch.linspace(
-                self.map_spec[d, 0],
-                self.map_spec[d, 1],
+                self.bounds[d, 0],
+                self.bounds[d, 1],
                 self.ninc[d] + 1,
                 dtype=torch.float64,
                 device=self.device,
@@ -83,8 +87,8 @@ def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"):
             )
         self.clear()
 
-    def add_training_data(self, y, f):
-        """Add training data ``f`` for ``y``-space points ``y``.
+    def add_training_data(self, u, f):
+        """Add training data ``f`` for ``u``-space points ``u``.
 
         Accumulates training data for later use by ``self.adapt()``.
         Grid increments will be made smaller in regions where
@@ -93,19 +97,19 @@ def add_training_data(self, y, f):
         when ``f`` is constant across the grid.
 
         Args:
-            y (tensor): ``y`` values corresponding to the training data.
-                ``y`` is a contiguous 2-d tensor, where ``y[j, d]``
+            u (tensor): ``u`` values corresponding to the training data.
+                ``u`` is a contiguous 2-d tensor, where ``u[j, d]``
                 is for points along direction ``d``.
             f (tensor): Training function values. ``f[j]`` corresponds to
-                point ``y[j, d]`` in ``y``-space.
+                point ``u[j, d]`` in ``u``-space.
         """
         if self.sum_f is None:
             self.sum_f = torch.zeros_like(self.inc)
             self.n_f = torch.zeros_like(self.inc) + 1e-10
-        iy = torch.floor(y * self.ninc).long()
+        iu = torch.floor(u * self.ninc).long()
         for d in range(self.dim):
-            self.sum_f[d, iy[:, d]] += torch.abs(f)
-            self.n_f[d, iy[:, d]] += 1
+            self.sum_f[d, iu[:, d]] += torch.abs(f)
+            self.n_f[d, iu[:, d]] += 1
 
     def adapt(self, alpha=0.0):
         """Adapt grid to accumulated training data.
@@ -218,101 +222,78 @@ def clear(self):
         self.n_f = None
 
     @torch.no_grad()
-    def forward(self, y):
-        y = y.to(self.device)
-        y_ninc = y * self.ninc
-        iy = torch.floor(y_ninc).long()
-        dy_ninc = y_ninc - iy
-
-        x = torch.empty_like(y)
-        jac = torch.ones(y.shape[0], device=x.device)
+    def forward(self, u):
+        u = u.to(self.device)
+        u_ninc = u * self.ninc
+        iu = torch.floor(u_ninc).long()
+        du_ninc = u_ninc - iu
+
+        x = torch.empty_like(u)
+        jac = torch.ones(u.shape[0], device=x.device)
         # self.jac.fill_(1.0)
         for d in range(self.dim):
-            # Handle the case where iy < ninc
+            # Handle the case where iu < ninc
             ninc = self.ninc[d]
-            mask = iy[:, d] < ninc
+            mask = iu[:, d] < ninc
             if mask.any():
                 x[mask, d] = (
-                    self.grid[d, iy[mask, d]]
-                    + self.inc[d, iy[mask, d]] * dy_ninc[mask, d]
+                    self.grid[d, iu[mask, d]]
+                    + self.inc[d, iu[mask, d]] * du_ninc[mask, d]
                 )
-                jac[mask] *= self.inc[d, iy[mask, d]] * ninc
+                jac[mask] *= self.inc[d, iu[mask, d]] * ninc
 
-            # Handle the case where iy >= ninc
+            # Handle the case where iu >= ninc
             mask_inv = ~mask
             if mask_inv.any():
                 x[mask_inv, d] = self.grid[d, ninc]
                 jac[mask_inv] *= self.inc[d, ninc - 1] * ninc
 
-        return x, jac
+        return x, torch.log(jac)
 
     @torch.no_grad()
     def inverse(self, x):
         # self.jac.fill_(1.0)
         x = x.to(self.device)
-        y = torch.empty_like(x)
+        u = torch.empty_like(x)
         jac = torch.ones(x.shape[0], device=x.device)
 
         for d in range(self.dim):
             ninc = self.ninc[d]
-            iy = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True)
+            iu = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True)
 
-            mask_valid = (iy > 0) & (iy <= ninc)
-            mask_lower = iy <= 0
-            mask_upper = iy > ninc
+            mask_valid = (iu > 0) & (iu <= ninc)
+            mask_lower = iu <= 0
+            mask_upper = iu > ninc
 
-            # Handle valid range (0 < iy <= ninc)
+            # Handle valid range (0 < iu <= ninc)
             if mask_valid.any():
-                iyi_valid = iy[mask_valid] - 1
-                y[mask_valid, d] = (
-                    iyi_valid
-                    + (x[mask_valid, d] - self.grid[d, iyi_valid])
-                    / self.inc[d, iyi_valid]
+                iui_valid = iu[mask_valid] - 1
+                u[mask_valid, d] = (
+                    iui_valid
+                    + (x[mask_valid, d] - self.grid[d, iui_valid])
+                    / self.inc[d, iui_valid]
                 ) / ninc
-                jac[mask_valid] *= self.inc[d, iyi_valid] * ninc
+                jac[mask_valid] *= self.inc[d, iui_valid] * ninc
 
-            # Handle lower bound (iy <= 0)\
+            # Handle lower bound (iu <= 0)
             if mask_lower.any():
-                y[mask_lower, d] = 0.0
+                u[mask_lower, d] = 0.0
                 jac[mask_lower] *= self.inc[d, 0] * ninc
 
-            # Handle upper bound (iy > ninc)
+            # Handle upper bound (iu > ninc)
             if mask_upper.any():
-                y[mask_upper, d] = 1.0
+                u[mask_upper, d] = 1.0
                 jac[mask_upper] *= self.inc[d, ninc - 1] * ninc
 
-        return y, jac
+        return u, torch.log(jac)
 
-    @torch.no_grad()
-    def log_det_jacobian(self, y):
-        y = y.to(self.device)
-        y_ninc = y * self.ninc
-        iy = torch.floor(y_ninc).long()
-        jac = torch.ones(y.shape[0], device=x.device)
-        for d in range(self.dim):
-            # Handle the case where iy < ninc
-            mask = iy[:, d] < self.ninc
-            if mask.any():
-                jac[mask] *= self.inc[d, iy[mask, d]] * self.ninc
 
+# class NormalizingFlow(Map):
+#     def __init__(self, bounds, flow_model, device="cpu"):
+#         super().__init__(bounds, device)
+#         self.flow_model = flow_model.to(device)
 
-            # Handle the case where iy >= ninc
-            mask_inv = ~mask
-            if mask_inv.any():
-                jac[mask_inv] *= self.inc[d, self.ninc - 1] * self.ninc
-
-        return torch.sum(torch.log(jac), dim=-1)
-
-
-class NormalizingFlow(AdaptiveMap):
-    def __init__(self, map_spec, flow_model, alpha=0.5, device="cpu"):
-        super().__init__(map_spec, alpha, device)
-        self.flow_model = flow_model.to(device)
-
-    def forward(self, u):
-        return self.flow_model.forward(u)[0]
-
-    def inverse(self, x):
-        return self.flow_model.inverse(x)[0]
+#     def forward(self, u):
+#         return self.flow_model.forward(u)
 
-    def log_det_jacobian(self, u):
-        return self.flow_model.forward(u)[1]
+#     def inverse(self, x):
+#         return self.flow_model.inverse(x)
diff --git a/src/mc_test.py b/src/mc_test.py
index 932fe6a..700946e 100644
--- a/src/mc_test.py
+++ b/src/mc_test.py
@@ -5,6 +5,7 @@
 
 set_seed(42)
 device = get_device()
+# device = torch.device("cpu")
 
 
 def unit_circle_integrand(x):
@@ -28,30 +29,34 @@ def half_sphere_integrand(x):
 
 dim = 2
-map_spec = [(-1, 1), (-1, 1)]
+bounds = [(-1, 1), (-1, 1)]
 
-affine_map = Affine(map_spec, device=device)
-# vegas_map = Vegas(map_spec, device=device)
+affine_map = Affine(bounds, device=device)
+# vegas_map = Vegas(bounds, device=device)
 
 
 # Monte Carlo integration
 print("Calculate the area of the unit circle using Monte Carlo integration...")
-mc_integrator = MonteCarlo(affine_map, neval=400000, batch_size=1000, device=device)
+mc_integrator = MonteCarlo(
+    # bounds, maps=affine_map, neval=400000, nbatch=1000, device=device
+    bounds,
+    neval=400000,
+    nbatch=1000,
+    device=device,
+)
 res = mc_integrator(unit_circle_integrand)
 print("Plain MC Integral results:")
 print(f"  Integral: {res.mean}")
 print(f"  Error: {res.sdev}")
 
-res = MonteCarlo(map_spec, neval=400000, batch_size=1000, device=device)(
+res = MonteCarlo(bounds, neval=400000, nbatch=1000, device=device)(
     unit_circle_integrand
 )
 print("Plain MC Integral results:")
 print(f"  Integral: {res.mean}")
 print(f"  Error: {res.sdev}")
 
-mcmc_integrator = MCMC(
-    map_spec, neval=400000, batch_size=1000, n_burnin=100, device=device
-)
+mcmc_integrator = MCMC(bounds, neval=400000, nbatch=1000, nburnin=100, device=device)
 res = mcmc_integrator(unit_circle_integrand, mix_rate=0.5)
 print("MCMC Integral results:")
 print(f"  Integral: {res.mean}")
@@ -65,9 +70,7 @@ def half_sphere_integrand(x):
 print(f"  Integral: {res.mean}")
 print(f"  Error: {res.sdev}")
 
-mcmc_integrator = MCMC(
-    map_spec, neval=400000, batch_size=1000, n_burnin=100, device=device
-)
+mcmc_integrator = MCMC(bounds, neval=400000, nbatch=1000, nburnin=100, device=device)
 res = mcmc_integrator(half_sphere_integrand, mix_rate=0.5)
 print("MCMC Integral results:")
 print(f"  Integral: {res.mean}")
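
Usage sketch (not part of the patch): a minimal example of the refactored API, driving MonteCarlo through an explicit Vegas map. It assumes the modules are importable the same way mc_test.py uses them (run from src/); the integrand, ninc, and the sample counts are illustrative only, and the map is fed from a base Uniform distribution on the unit square because Vegas.forward() expects inputs in [0, 1]^dim.

    from base import Uniform
    from maps import Vegas
    from integrators import MonteCarlo

    def unit_circle_integrand(x):
        # 1 inside the unit circle, 0 outside; integrated over [-1, 1]^2 this gives pi
        return (x[:, 0] ** 2 + x[:, 1] ** 2 <= 1).double()

    bounds = [(-1, 1), (-1, 1)]          # integration region
    unit_cube = [(0, 1)] * len(bounds)   # domain of the base distribution

    # After this patch, Vegas.forward(u) returns (x, log_detJ); Integrator.sample()
    # adds that log-det to the one returned by q0.sample().
    vegas_map = Vegas(bounds, ninc=1000, device="cpu")

    mc = MonteCarlo(
        unit_cube,
        q0=Uniform(unit_cube, device="cpu"),
        maps=vegas_map,
        neval=400000,
        nbatch=1000,
        device="cpu",
    )
    res = mc(unit_circle_integrand)
    print(f"Integral: {res.mean} +/- {res.sdev}")   # roughly pi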