From 123ea666af1fe7431b28c7030b2e4785e6ccd6d2 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Fri, 4 Oct 2024 00:31:56 -0400 Subject: [PATCH 1/5] api refactoring wip --- src/integrators.py | 42 +++++------ src/maps.py | 181 ++++++++++++++++++--------------------------- 2 files changed, 95 insertions(+), 128 deletions(-) diff --git a/src/integrators.py b/src/integrators.py index 7084970..522def1 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -13,24 +13,24 @@ class Integrator: def __init__( self, # bounds: Union[List[Tuple[float, float]], np.ndarray], - map, + map=None, neval: int = 1000, - batch_size: int = None, + nbatch: int = None, device="cpu", # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): - if not isinstance(map, Map): - map = Affine(map) + #if not isinstance(map, Map): + # map = Affine(map) self.dim = map.dim self.map = map self.neval = neval - if batch_size is None: - self.batch_size = neval + if nbatch is None: + self.nbatch = neval self.neval = neval else: - self.batch_size = batch_size - self.neval = -(-neval // batch_size) * batch_size + self.nbatch = nbatch + self.neval = -(-neval // nbatch) * nbatch self.device = device @@ -44,18 +44,18 @@ def __init__( map, nitn: int = 10, neval: int = 1000, - batch_size: int = None, + nbatch: int = None, device="cpu", adapt=False, alpha=0.5, ): - super().__init__(map, neval, batch_size, device) + super().__init__(map, neval, nbatch, device) self.adapt = adapt self.alpha = alpha self.nitn = nitn def __call__(self, f: Callable, **kwargs): - u = torch.rand(self.batch_size, self.dim, device=self.device) + u = torch.rand(self.nbatch, self.dim, device=self.device) x, _ = self.map.forward(u) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 @@ -66,14 +66,14 @@ def __call__(self, f: Callable, **kwargs): # var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device) result = RAvg(weighted=self.adapt) - epoch = self.neval // self.batch_size + epoch = self.neval // self.nbatch for itn in range(self.nitn): mean[:] = 0 var[:] = 0 for _ in range(epoch): y = torch.rand( - self.batch_size, self.dim, dtype=torch.float64, device=self.device + self.nbatch, self.dim, dtype=torch.float64, device=self.device ) x, jac = self.map.forward(y) @@ -108,25 +108,25 @@ def __init__( map: Map, nitn: int = 10, neval=10000, - batch_size=None, + nbatch=None, n_burnin=500, device="cpu", adapt=False, alpha=0.5, ): - super().__init__(map, nitn, neval, batch_size, device, adapt, alpha) + super().__init__(map, nitn, neval, nbatch, device, adapt, alpha) self.n_burnin = n_burnin def __call__( self, f: Callable, - proposal_dist="global_uniform", + proposal_dist="uniform", thinning=1, mix_rate=0.0, **kwargs, ): epsilon = 1e-16 # Small value to ensure numerical stability - vars_shape = (self.batch_size, self.dim) + vars_shape = (self.nbatch, self.dim) current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device) current_x, current_jac = self.map.forward(current_y) current_fval = f(current_x) @@ -149,7 +149,7 @@ def __call__( result = RAvg(weighted=self.adapt) result_ref = RAvg(weighted=self.adapt) - epoch = self.neval // self.batch_size + epoch = self.neval // self.nbatch n_meas = 0 for itn in range(self.nitn): for i in range(epoch): @@ -162,7 +162,7 @@ def __call__( acceptance_probs = new_weight / current_weight * new_jac / current_jac accept = ( - torch.rand(self.batch_size, dtype=torch.float64, device=self.device) + torch.rand(self.nbatch, dtype=torch.float64, device=self.device) <= 
acceptance_probs ) @@ -190,7 +190,7 @@ def __call__( ) result.sum_neval += self.neval result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item())) - result_ref.sum_neval += self.batch_size + result_ref.sum_neval += self.nbatch result_ref.add( gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()) ) @@ -204,7 +204,7 @@ def _propose(self, u, proposal_dist, **kwargs): if proposal_dist == "random_walk": step_size = kwargs.get("step_size", 0.2) return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0 - elif proposal_dist == "global_uniform": + elif proposal_dist == "uniform": return torch.rand_like(u) # elif proposal_dist == "gaussian": # mean = kwargs.get("mean", torch.zeros_like(u)) diff --git a/src/maps.py b/src/maps.py index 9ab2881..10e9769 100644 --- a/src/maps.py +++ b/src/maps.py @@ -1,63 +1,47 @@ import torch import numpy as np +from torch import nn - -class Map: - def __init__(self, map_spec, device="cpu"): - # if isinstance(map_spec, dict): - # self.map_spec = { - # k: torch.tensor(v, device=device) for k, v in map_spec.items() - # } - if isinstance(map_spec, (list, np.ndarray)): - self.map_spec = torch.tensor(map_spec, dtype=torch.float64, device=device) +class Map(nn.Module): + def __init__(self, bounds, device="cpu"): + super().__init__() + if isinstance(bounds, (list, np.ndarray)): + self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) else: raise ValueError("Unsupported map specification") - self.dim = self.map_spec.shape[0] + self.dim = self.bounds.shape[0] self.device = device - def forward(self, y): + def forward(self, u): raise NotImplementedError("Subclasses must implement this method") - + def inverse(self, x): raise NotImplementedError("Subclasses must implement this method") - - def log_det_jacobian(self, y): + + def sample(self, nsample): raise NotImplementedError("Subclasses must implement this method") + class Affine(Map): - def __init__(self, map_spec, device="cpu"): - super().__init__(map_spec, device) - self._A = self.map_spec[:, 1] - self.map_spec[:, 0] + def __init__(self, bounds, device="cpu"): + super().__init__(bounds, device) + self._A = self.bounds[:, 1] - self.bounds[:, 0] self._jac1 = torch.prod(self._A) - def forward(self, y): - return y * self._A + self.map_spec[:, 0], self._jac1.repeat(y.shape[0]) + def forward(self, u): + return u * self._A + self.bounds[:, 0], torch.log(self._jac1.repeat(u.shape[0])) def inverse(self, x): - return (x - self.map_spec[:, 0]) / self._A, self._jac1.repeat(x.shape[0]) + return (x - self.bounds[:, 0]) / self._A, torch.log(self._jac1.repeat(x.shape[0])) - def log_det_jacobian(self, y): - return torch.log(self._jac1) * y.shape[0] -class AdaptiveMap(Map): - def __init__(self, map_spec, alpha=0.5, device="cpu"): - super().__init__(map_spec, device) - self.alpha = alpha - - def add_training_data(self, y, f): - pass - - def adapt(self, alpha=0.0): - pass - - -class Vegas(AdaptiveMap): - def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"): - super().__init__(map_spec, alpha, device) +class Vegas(Map): + def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu"): + super().__init__(bounds, device) # self.nbin = nbin - + self.alpha = alpha if isinstance(ninc, int): self.ninc = torch.ones(self.dim, dtype=torch.int32, device=device) * ninc else: @@ -72,8 +56,8 @@ def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"): for d in range(self.dim): self.grid[d, : self.ninc[d] + 1] = torch.linspace( - self.map_spec[d, 0], - self.map_spec[d, 1], + self.bounds[d, 0], + self.bounds[d, 
1], self.ninc[d] + 1, dtype=torch.float64, device=self.device, @@ -83,8 +67,8 @@ def __init__(self, map_spec, ninc=1000, alpha=0.5, device="cpu"): ) self.clear() - def add_training_data(self, y, f): - """Add training data ``f`` for ``y``-space points ``y``. + def add_training_data(self, u, f): + """Add training data ``f`` for ``u``-space points ``u``. Accumulates training data for later use by ``self.adapt()``. Grid increments will be made smaller in regions where @@ -93,19 +77,19 @@ def add_training_data(self, y, f): when ``f`` is constant across the grid. Args: - y (tensor): ``y`` values corresponding to the training data. - ``y`` is a contiguous 2-d tensor, where ``y[j, d]`` + u (tensor): ``u`` values corresponding to the training data. + ``u`` is a contiguous 2-d tensor, where ``u[j, d]`` is for points along direction ``d``. f (tensor): Training function values. ``f[j]`` corresponds to - point ``y[j, d]`` in ``y``-space. + point ``u[j, d]`` in ``u``-space. """ if self.sum_f is None: self.sum_f = torch.zeros_like(self.inc) self.n_f = torch.zeros_like(self.inc) + 1e-10 - iy = torch.floor(y * self.ninc).long() + iu = torch.floor(u * self.ninc).long() for d in range(self.dim): - self.sum_f[d, iy[:, d]] += torch.abs(f) - self.n_f[d, iy[:, d]] += 1 + self.sum_f[d, iu[:, d]] += torch.abs(f) + self.n_f[d, iu[:, d]] += 1 def adapt(self, alpha=0.0): """Adapt grid to accumulated training data. @@ -218,101 +202,84 @@ def clear(self): self.n_f = None @torch.no_grad() - def forward(self, y): - y = y.to(self.device) - y_ninc = y * self.ninc - iy = torch.floor(y_ninc).long() - dy_ninc = y_ninc - iy - - x = torch.empty_like(y) - jac = torch.ones(y.shape[0], device=x.device) + def forward(self, u): + u = u.to(self.device) + u_ninc = u * self.ninc + iu = torch.floor(u_ninc).long() + du_ninc = u_ninc - iu + + x = torch.empty_like(u) + jac = torch.ones(u.shape[0], device=x.device) # self.jac.fill_(1.0) for d in range(self.dim): - # Handle the case where iy < ninc + # Handle the case where iu < ninc ninc = self.ninc[d] - mask = iy[:, d] < ninc + mask = iu[:, d] < ninc if mask.any(): x[mask, d] = ( - self.grid[d, iy[mask, d]] - + self.inc[d, iy[mask, d]] * dy_ninc[mask, d] + self.grid[d, iu[mask, d]] + + self.inc[d, iu[mask, d]] * du_ninc[mask, d] ) - jac[mask] *= self.inc[d, iy[mask, d]] * ninc + jac[mask] *= self.inc[d, iu[mask, d]] * ninc - # Handle the case where iy >= ninc + # Handle the case where iu >= ninc mask_inv = ~mask if mask_inv.any(): x[mask_inv, d] = self.grid[d, ninc] jac[mask_inv] *= self.inc[d, ninc - 1] * ninc - return x, jac + return x, torch.log(jac) @torch.no_grad() def inverse(self, x): # self.jac.fill_(1.0) x = x.to(self.device) - y = torch.empty_like(x) + u = torch.empty_like(x) jac = torch.ones(x.shape[0], device=x.device) for d in range(self.dim): ninc = self.ninc[d] - iy = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True) + iu = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True) - mask_valid = (iy > 0) & (iy <= ninc) - mask_lower = iy <= 0 - mask_upper = iy > ninc + mask_valid = (iu > 0) & (iu <= ninc) + mask_lower = iu <= 0 + mask_upper = iu > ninc - # Handle valid range (0 < iy <= ninc) + # Handle valid range (0 < iu <= ninc) if mask_valid.any(): - iyi_valid = iy[mask_valid] - 1 - y[mask_valid, d] = ( - iyi_valid - + (x[mask_valid, d] - self.grid[d, iyi_valid]) - / self.inc[d, iyi_valid] + iui_valid = iu[mask_valid] - 1 + u[mask_valid, d] = ( + iui_valid + + (x[mask_valid, d] - self.grid[d, iui_valid]) + / self.inc[d, iui_valid] ) / ninc - 
jac[mask_valid] *= self.inc[d, iyi_valid] * ninc + jac[mask_valid] *= self.inc[d, iui_valid] * ninc - # Handle lower bound (iy <= 0)\ + # Handle lower bound (iu <= 0)\ if mask_lower.any(): - y[mask_lower, d] = 0.0 + u[mask_lower, d] = 0.0 jac[mask_lower] *= self.inc[d, 0] * ninc - # Handle upper bound (iy > ninc) + # Handle upper bound (iu > ninc) if mask_upper.any(): - y[mask_upper, d] = 1.0 + u[mask_upper, d] = 1.0 jac[mask_upper] *= self.inc[d, ninc - 1] * ninc - return y, jac - - @torch.no_grad() - def log_det_jacobian(self, y): - y = y.to(self.device) - y_ninc = y * self.ninc - iy = torch.floor(y_ninc).long() - - jac = torch.ones(y.shape[0], device=x.device) - for d in range(self.dim): - # Handle the case where iy < ninc - mask = iy[:, d] < self.ninc - if mask.any(): - jac[mask] *= self.inc[d, iy[mask, d]] * self.ninc - - # Handle the case where iy >= ninc - mask_inv = ~mask - if mask_inv.any(): - jac[mask_inv] *= self.inc[d, self.ninc - 1] * self.ninc + return u, torch.log(jac) - return torch.sum(torch.log(jac), dim=-1) + def sample(self, nsample): + return super() -class NormalizingFlow(AdaptiveMap): - def __init__(self, map_spec, flow_model, alpha=0.5, device="cpu"): - super().__init__(map_spec, alpha, device) +class NormalizingFlow(Map): + def __init__(self, bounds, flow_model, device="cpu"): + super().__init__(bounds, device) self.flow_model = flow_model.to(device) def forward(self, u): - return self.flow_model.forward(u)[0] + return self.flow_model.forward(u) def inverse(self, x): - return self.flow_model.inverse(x)[0] - - def log_det_jacobian(self, u): - return self.flow_model.forward(u)[1] + return self.flow_model.inverse(x) + + def sample(self, nsample): + return self.flow_model.sample(nsample) From cdde5234c82c77da37aa53374fb6e32a8a1d5eae Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Tue, 8 Oct 2024 20:44:10 -0400 Subject: [PATCH 2/5] wip --- src/integrators.py | 52 +++++++++++++++++++++------------------------- src/maps.py | 40 ++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/src/integrators.py b/src/integrators.py index 522def1..417e75f 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -1,7 +1,7 @@ from typing import Callable, Union, List, Tuple, Dict import torch from utils import RAvg -from maps import Map, Affine +from maps import Map, Affine, NormalizingFlow import gvar @@ -13,17 +13,19 @@ class Integrator: def __init__( self, # bounds: Union[List[Tuple[float, float]], np.ndarray], - map=None, + maps: NormalizingFlow, neval: int = 1000, nbatch: int = None, device="cpu", + dtype = torch.float32, # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): #if not isinstance(map, Map): # map = Affine(map) - self.dim = map.dim - self.map = map + self.dim = maps.dim + self.bounds = maps.bounds + self.maps = maps self.neval = neval if nbatch is None: self.nbatch = neval @@ -33,7 +35,7 @@ def __init__( self.neval = -(-neval // nbatch) * nbatch self.device = device - + self.dtype = dtype def __call__(self, f: Callable, **kwargs): raise NotImplementedError("Subclasses must implement this method") @@ -46,17 +48,14 @@ def __init__( neval: int = 1000, nbatch: int = None, device="cpu", - adapt=False, - alpha=0.5, ): super().__init__(map, neval, nbatch, device) - self.adapt = adapt - self.alpha = alpha self.nitn = nitn def __call__(self, f: Callable, **kwargs): - u = torch.rand(self.nbatch, self.dim, device=self.device) - x, _ = self.map.forward(u) + # u = torch.rand(self.nbatch, self.dim, device=self.device) 
+ # x, _ = self.map.forward(u) + x,_ = self.maps.sample(self.nbatch) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype) @@ -72,24 +71,19 @@ def __call__(self, f: Callable, **kwargs): mean[:] = 0 var[:] = 0 for _ in range(epoch): - y = torch.rand( - self.nbatch, self.dim, dtype=torch.float64, device=self.device - ) - x, jac = self.map.forward(y) - + # y = torch.rand( + # self.nbatch, self.dim, dtype=torch.float64, device=self.device + # ) + # x, jac = self.map.forward(y) + x, log_detJ = self.maps.sample(self.nbatch) f_values = f(x) - batch_results = self._multiply_by_jacobian(f_values, jac) + batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ) ) mean += torch.mean(batch_results, dim=-1) / epoch var += torch.var(batch_results, dim=-1) / (self.neval * epoch) - if self.adapt: - self.map.add_training_data(y, batch_results**2) - result.sum_neval += self.neval result.add(gvar.gvar(mean.item(), (var**0.5).item())) - if self.adapt: - self.map.adapt(alpha=self.alpha) return result @@ -109,13 +103,13 @@ def __init__( nitn: int = 10, neval=10000, nbatch=None, - n_burnin=500, + nburnin=500, device="cpu", adapt=False, alpha=0.5, ): super().__init__(map, nitn, neval, nbatch, device, adapt, alpha) - self.n_burnin = n_burnin + self.nburnin = nburnin def __call__( self, @@ -126,9 +120,11 @@ def __call__( **kwargs, ): epsilon = 1e-16 # Small value to ensure numerical stability - vars_shape = (self.nbatch, self.dim) - current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device) - current_x, current_jac = self.map.forward(current_y) + #vars_shape = (self.nbatch, self.dim) + # current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device) + # current_x, current_jac = self.map.forward(current_y) + current_x, current_jac = self.maps.sample(self.nbatch) + current_jac = torch.exp(current_jac) current_fval = f(current_x) current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs() current_weight.masked_fill_(current_weight < epsilon, epsilon) @@ -171,7 +167,7 @@ def __call__( current_weight = torch.where(accept, new_weight, current_weight) current_jac = torch.where(accept, new_jac, current_jac) - if i < self.n_burnin and (self.adapt or itn == 0): + if i < self.nburnin and (self.adapt or itn == 0): continue elif i % thinning == 0: n_meas += 1 diff --git a/src/maps.py b/src/maps.py index 10e9769..34e28dc 100644 --- a/src/maps.py +++ b/src/maps.py @@ -2,6 +2,37 @@ import numpy as np from torch import nn +class NormalizingFlow(nn.Module): + def __init__(self, q0, maps): + super().__init__() + if not maps: + raise ValueError("Maps can not be empty.") + self.q0 = q0 + self.maps = maps + self.dim = maps[0].dim + self.bounds = maps[0].bounds + def forward(self, u): + log_detJ = torch.zeros(len(u), device=u.device) + for map in self.flows: + u, log_detj = map.forward(u) + log_detJ += log_detj + return u, log_detJ + def inverse(self, x): + log_detJ = torch.zeros(len(x), device=x.device) + for i in range(len(self.maps) - 1, -1, -1): + x, log_detj = self.maps[i].inverse(x) + log_detJ += log_detj + return x, log_detJ + def sample(self, nsample): + u, log_detJ = self.q0.sample(nsample) + for map in self.maps: + u, log_detj = map(u) + log_detJ += log_detj + return u, log_detJ + + + + class Map(nn.Module): def __init__(self, bounds, device="cpu"): super().__init__() @@ -17,10 +48,6 @@ def forward(self, u): def inverse(self, x): raise 
NotImplementedError("Subclasses must implement this method") - - def sample(self, nsample): - raise NotImplementedError("Subclasses must implement this method") - class Affine(Map): @@ -266,8 +293,7 @@ def inverse(self, x): return u, torch.log(jac) - def sample(self, nsample): - return super() + class NormalizingFlow(Map): @@ -281,5 +307,3 @@ def forward(self, u): def inverse(self, x): return self.flow_model.inverse(x) - def sample(self, nsample): - return self.flow_model.sample(nsample) From 6e1468342233c4d28590061853cf02e615c5be73 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Thu, 10 Oct 2024 21:23:05 -0400 Subject: [PATCH 3/5] wip --- src/base.py | 29 ++++++++++++++++++++++++++++ src/integrators.py | 44 ++++++++++++++++++++++++++----------------- src/maps.py | 47 ++++++++++++++++++---------------------------- 3 files changed, 74 insertions(+), 46 deletions(-) create mode 100644 src/base.py diff --git a/src/base.py b/src/base.py new file mode 100644 index 0000000..6e9a8fe --- /dev/null +++ b/src/base.py @@ -0,0 +1,29 @@ +import torch +from torch import nn +class BaseDistribution(nn.Module): + """ + Base distribution of a flow-based model + Parameters do not depend of target variable (as is the case for a VAE encoder) + """ + + def __init__(self, bounds): + super().__init__() + self.bounds = bounds + def sample(self, nsamples=1, **kwargs): + """Samples from base distribution + + Args: + num_samples: Number of samples to draw from the distriubtion + + Returns: + Samples drawn from the distribution + """ + raise NotImplementedError + +class Uniform(BaseDistribution): + """ + Multivariate uniform distribution + """ + + def __init__(self, bounds): + super().__init__(bounds) diff --git a/src/integrators.py b/src/integrators.py index 417e75f..c9ec5bc 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -1,7 +1,8 @@ from typing import Callable, Union, List, Tuple, Dict import torch from utils import RAvg -from maps import Map, Affine, NormalizingFlow +from maps import Map, Affine, CompositeMap +from base import Uniform import gvar @@ -12,19 +13,17 @@ class Integrator: def __init__( self, - # bounds: Union[List[Tuple[float, float]], np.ndarray], - maps: NormalizingFlow, + bounds: List[Tuple[float, float]], #, np.ndarray], + q0 = None, + maps = None, neval: int = 1000, nbatch: int = None, device="cpu", - dtype = torch.float32, # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): - #if not isinstance(map, Map): - # map = Affine(map) - - self.dim = maps.dim - self.bounds = maps.bounds + self.dim = len(bounds) + if not q0: + q0 = Uniform(bounds) self.maps = maps self.neval = neval if nbatch is None: @@ -35,10 +34,15 @@ def __init__( self.neval = -(-neval // nbatch) * nbatch self.device = device - self.dtype = dtype def __call__(self, f: Callable, **kwargs): raise NotImplementedError("Subclasses must implement this method") - + def sample(self, nsample, **kwargs): + u, log_detJ= self.q0.sample(nsample, **kwargs) + if not self.maps: + return u, log_detJ + else: + u, log_detj = self.maps.forward(u) + return u, log_detJ + log_detj class MonteCarlo(Integrator): def __init__( @@ -51,11 +55,11 @@ def __init__( ): super().__init__(map, neval, nbatch, device) self.nitn = nitn - + def __call__(self, f: Callable, **kwargs): # u = torch.rand(self.nbatch, self.dim, device=self.device) # x, _ = self.map.forward(u) - x,_ = self.maps.sample(self.nbatch) + x,_ = self.sample(self.nbatch) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 type_fval = 
f_values.dtype if f_size == 1 else type(f_values[0].dtype) @@ -75,7 +79,7 @@ def __call__(self, f: Callable, **kwargs): # self.nbatch, self.dim, dtype=torch.float64, device=self.device # ) # x, jac = self.map.forward(y) - x, log_detJ = self.maps.sample(self.nbatch) + x, log_detJ = self.sample(self.nbatch) f_values = f(x) batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ) ) @@ -84,9 +88,15 @@ def __call__(self, f: Callable, **kwargs): result.sum_neval += self.neval result.add(gvar.gvar(mean.item(), (var**0.5).item())) - return result - + + # def sample(self, nsample): + # u, log_detJ = self.q0.sample(nsample) + # for map in self.maps: + # u, log_detj = map(u) + # log_detJ += log_detj + # return u, log_detJ + def _multiply_by_jacobian(self, values, jac): # if isinstance(values, dict): # return {k: v * torch.exp(log_det_J) for k, v in values.items()} @@ -123,7 +133,7 @@ def __call__( #vars_shape = (self.nbatch, self.dim) # current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device) # current_x, current_jac = self.map.forward(current_y) - current_x, current_jac = self.maps.sample(self.nbatch) + current_x, current_jac = self.sample(self.nbatch) current_jac = torch.exp(current_jac) current_fval = f(current_x) current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs() diff --git a/src/maps.py b/src/maps.py index 34e28dc..e4fdc0f 100644 --- a/src/maps.py +++ b/src/maps.py @@ -2,35 +2,6 @@ import numpy as np from torch import nn -class NormalizingFlow(nn.Module): - def __init__(self, q0, maps): - super().__init__() - if not maps: - raise ValueError("Maps can not be empty.") - self.q0 = q0 - self.maps = maps - self.dim = maps[0].dim - self.bounds = maps[0].bounds - def forward(self, u): - log_detJ = torch.zeros(len(u), device=u.device) - for map in self.flows: - u, log_detj = map.forward(u) - log_detJ += log_detj - return u, log_detJ - def inverse(self, x): - log_detJ = torch.zeros(len(x), device=x.device) - for i in range(len(self.maps) - 1, -1, -1): - x, log_detj = self.maps[i].inverse(x) - log_detJ += log_detj - return x, log_detJ - def sample(self, nsample): - u, log_detJ = self.q0.sample(nsample) - for map in self.maps: - u, log_detj = map(u) - log_detJ += log_detj - return u, log_detJ - - class Map(nn.Module): @@ -49,6 +20,24 @@ def forward(self, u): def inverse(self, x): raise NotImplementedError("Subclasses must implement this method") +class CompositeMap(Map): + def __init__(self, maps, device="cpu"): + if not maps: + raise ValueError("Maps can not be empty.") + super().__init__(maps[-1].bounds, device) + self.maps = maps + def forward(self, u): + log_detJ = torch.zeros(len(u), device=u.device) + for map in self.flows: + u, log_detj = map.forward(u) + log_detJ += log_detj + return u, log_detJ + def inverse(self, x): + log_detJ = torch.zeros(len(x), device=x.device) + for i in range(len(self.maps) - 1, -1, -1): + x, log_detj = self.maps[i].inverse(x) + log_detJ += log_detj + return x, log_detJ class Affine(Map): def __init__(self, bounds, device="cpu"): From 03e38e7dbe67b2a7bfc2ab6e5e152536f1b7a9c8 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Sat, 12 Oct 2024 01:09:00 -0400 Subject: [PATCH 4/5] mc_test is working --- src/base.py | 22 +++++++++++++--- src/integrators.py | 65 ++++++++++++++++++++++++++-------------------- src/maps.py | 37 +++++++++++++------------- src/mc_test.py | 16 +++++++----- 4 files changed, 83 insertions(+), 57 deletions(-) diff --git a/src/base.py b/src/base.py index 6e9a8fe..98764ad 100644 --- a/src/base.py +++ 
b/src/base.py @@ -1,14 +1,18 @@ import torch from torch import nn + + class BaseDistribution(nn.Module): """ Base distribution of a flow-based model Parameters do not depend of target variable (as is the case for a VAE encoder) """ - def __init__(self, bounds): + def __init__(self, bounds, device="cpu"): super().__init__() self.bounds = bounds + self.device = device + def sample(self, nsamples=1, **kwargs): """Samples from base distribution @@ -19,11 +23,21 @@ def sample(self, nsamples=1, **kwargs): Samples drawn from the distribution """ raise NotImplementedError - + + class Uniform(BaseDistribution): """ Multivariate uniform distribution """ - def __init__(self, bounds): - super().__init__(bounds) + def __init__(self, bounds, device="cpu"): + super().__init__(bounds, device) + + def sample(self, nsamples=1, **kwargs): + dim = len(self.bounds) + u = torch.rand((nsamples, dim), device=self.device) + log_detJ = torch.zeros(nsamples, device=self.device) + for i, bound in enumerate(self.bounds): + u[:, i] = (bound[1] - bound[0]) * u[:, i] + bound[0] + log_detJ += -torch.log(torch.tensor(bound[1] - bound[0])) + return u, log_detJ diff --git a/src/integrators.py b/src/integrators.py index c9ec5bc..435be5d 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -13,17 +13,20 @@ class Integrator: def __init__( self, - bounds: List[Tuple[float, float]], #, np.ndarray], - q0 = None, - maps = None, + bounds: List[Tuple[float, float]], # , np.ndarray], + q0=None, + maps=None, neval: int = 1000, nbatch: int = None, device="cpu", + adapt=False, # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): + self.adapt = adapt self.dim = len(bounds) if not q0: q0 = Uniform(bounds) + self.q0 = q0 self.maps = maps self.neval = neval if nbatch is None: @@ -34,32 +37,38 @@ def __init__( self.neval = -(-neval // nbatch) * nbatch self.device = device + def __call__(self, f: Callable, **kwargs): raise NotImplementedError("Subclasses must implement this method") + def sample(self, nsample, **kwargs): - u, log_detJ= self.q0.sample(nsample, **kwargs) + u, log_detJ = self.q0.sample(nsample) if not self.maps: return u, log_detJ else: u, log_detj = self.maps.forward(u) - return u, log_detJ + log_detj + return u, log_detJ + log_detj + class MonteCarlo(Integrator): def __init__( self, - map, + bounds: List[Tuple[float, float]], # , np.ndarray], + q0=None, + maps=None, nitn: int = 10, neval: int = 1000, nbatch: int = None, device="cpu", + adapt=False, ): - super().__init__(map, neval, nbatch, device) + super().__init__(bounds, q0, maps, neval, nbatch, device, adapt) self.nitn = nitn - + def __call__(self, f: Callable, **kwargs): # u = torch.rand(self.nbatch, self.dim, device=self.device) # x, _ = self.map.forward(u) - x,_ = self.sample(self.nbatch) + x, _ = self.sample(self.nbatch) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype) @@ -81,7 +90,9 @@ def __call__(self, f: Callable, **kwargs): # x, jac = self.map.forward(y) x, log_detJ = self.sample(self.nbatch) f_values = f(x) - batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ) ) + batch_results = self._multiply_by_jacobian( + f_values, torch.exp(log_detJ) + ) mean += torch.mean(batch_results, dim=-1) / epoch var += torch.var(batch_results, dim=-1) / (self.neval * epoch) @@ -89,14 +100,14 @@ def __call__(self, f: Callable, **kwargs): result.sum_neval += self.neval result.add(gvar.gvar(mean.item(), (var**0.5).item())) return 
result - + # def sample(self, nsample): # u, log_detJ = self.q0.sample(nsample) # for map in self.maps: # u, log_detj = map(u) # log_detJ += log_detj # return u, log_detJ - + def _multiply_by_jacobian(self, values, jac): # if isinstance(values, dict): # return {k: v * torch.exp(log_det_J) for k, v in values.items()} @@ -109,16 +120,17 @@ def _multiply_by_jacobian(self, values, jac): class MCMC(MonteCarlo): def __init__( self, - map: Map, + bounds: List[Tuple[float, float]], # , np.ndarray], + q0=None, + maps=None, nitn: int = 10, neval=10000, nbatch=None, nburnin=500, device="cpu", adapt=False, - alpha=0.5, ): - super().__init__(map, nitn, neval, nbatch, device, adapt, alpha) + super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt) self.nburnin = nburnin def __call__( @@ -130,10 +142,14 @@ def __call__( **kwargs, ): epsilon = 1e-16 # Small value to ensure numerical stability - #vars_shape = (self.nbatch, self.dim) - # current_y = torch.rand(vars_shape, dtype=torch.float64, device=self.device) - # current_x, current_jac = self.map.forward(current_y) - current_x, current_jac = self.sample(self.nbatch) + # vars_shape = (self.nbatch, self.dim) + current_y, current_jac = self.q0.sample(self.nbatch) + if self.maps: + current_x, jac = self.maps.forward(current_y) + current_jac += jac + else: + current_x = current_y + # current_x, current_jac = self.sample(self.nbatch) current_jac = torch.exp(current_jac) current_fval = f(current_x) current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs() @@ -160,7 +176,7 @@ def __call__( for itn in range(self.nitn): for i in range(epoch): proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs) - proposed_x[:], new_jac = self.map.forward(proposed_y) + proposed_x[:], new_jac = self.maps.forward(proposed_y) new_fval[:] = f(proposed_x) new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs() @@ -177,7 +193,7 @@ def __call__( current_weight = torch.where(accept, new_weight, current_weight) current_jac = torch.where(accept, new_jac, current_jac) - if i < self.nburnin and (self.adapt or itn == 0): + if i < self.nburnin and itn == 0: continue elif i % thinning == 0: n_meas += 1 @@ -190,10 +206,6 @@ def __call__( mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch var_ref += torch.var(batch_results_ref, dim=-1) / epoch - if self.adapt: - self.map.add_training_data( - current_y, (current_fval * current_jac) ** 2 - ) result.sum_neval += self.neval result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item())) result_ref.sum_neval += self.nbatch @@ -201,9 +213,6 @@ def __call__( gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()) ) - if self.adapt: - self.map.adapt(alpha=self.alpha) - return result / result_ref def _propose(self, u, proposal_dist, **kwargs): diff --git a/src/maps.py b/src/maps.py index e4fdc0f..ef4a884 100644 --- a/src/maps.py +++ b/src/maps.py @@ -3,7 +3,6 @@ from torch import nn - class Map(nn.Module): def __init__(self, bounds, device="cpu"): super().__init__() @@ -16,22 +15,25 @@ def __init__(self, bounds, device="cpu"): def forward(self, u): raise NotImplementedError("Subclasses must implement this method") - + def inverse(self, x): raise NotImplementedError("Subclasses must implement this method") + class CompositeMap(Map): def __init__(self, maps, device="cpu"): if not maps: raise ValueError("Maps can not be empty.") super().__init__(maps[-1].bounds, device) self.maps = maps + def forward(self, u): log_detJ = torch.zeros(len(u), device=u.device) - for map in self.flows: - u, 
log_detj = map.forward(u) - log_detJ += log_detj + for map in self.maps: + u, log_detj = map.forward(u) + log_detJ += log_detj return u, log_detJ + def inverse(self, x): log_detJ = torch.zeros(len(x), device=x.device) for i in range(len(self.maps) - 1, -1, -1): @@ -39,6 +41,7 @@ def inverse(self, x): log_detJ += log_detj return x, log_detJ + class Affine(Map): def __init__(self, bounds, device="cpu"): super().__init__(bounds, device) @@ -49,8 +52,9 @@ def forward(self, u): return u * self._A + self.bounds[:, 0], torch.log(self._jac1.repeat(u.shape[0])) def inverse(self, x): - return (x - self.bounds[:, 0]) / self._A, torch.log(self._jac1.repeat(x.shape[0])) - + return (x - self.bounds[:, 0]) / self._A, torch.log( + self._jac1.repeat(x.shape[0]) + ) class Vegas(Map): @@ -282,17 +286,14 @@ def inverse(self, x): return u, torch.log(jac) - - -class NormalizingFlow(Map): - def __init__(self, bounds, flow_model, device="cpu"): - super().__init__(bounds, device) - self.flow_model = flow_model.to(device) +# class NormalizingFlow(Map): +# def __init__(self, bounds, flow_model, device="cpu"): +# super().__init__(bounds, device) +# self.flow_model = flow_model.to(device) - def forward(self, u): - return self.flow_model.forward(u) +# def forward(self, u): +# return self.flow_model.forward(u) - def inverse(self, x): - return self.flow_model.inverse(x) - +# def inverse(self, x): +# return self.flow_model.inverse(x) diff --git a/src/mc_test.py b/src/mc_test.py index 932fe6a..c55df20 100644 --- a/src/mc_test.py +++ b/src/mc_test.py @@ -28,21 +28,23 @@ def half_sphere_integrand(x): dim = 2 -map_spec = [(-1, 1), (-1, 1)] +bounds = [(-1, 1), (-1, 1)] -affine_map = Affine(map_spec, device=device) -# vegas_map = Vegas(map_spec, device=device) +affine_map = Affine(bounds, device=device) +# vegas_map = Vegas(bounds, device=device) # Monte Carlo integration print("Calculate the area of the unit circle using Monte Carlo integration...") -mc_integrator = MonteCarlo(affine_map, neval=400000, batch_size=1000, device=device) +mc_integrator = MonteCarlo( + bounds, maps=affine_map, neval=400000, nbatch=1000, device=device +) res = mc_integrator(unit_circle_integrand) print("Plain MC Integral results:") print(f" Integral: {res.mean}") print(f" Error: {res.sdev}") -res = MonteCarlo(map_spec, neval=400000, batch_size=1000, device=device)( +res = MonteCarlo(bounds, neval=400000, nbatch=1000, device=device)( unit_circle_integrand ) print("Plain MC Integral results:") @@ -50,7 +52,7 @@ def half_sphere_integrand(x): print(f" Error: {res.sdev}") mcmc_integrator = MCMC( - map_spec, neval=400000, batch_size=1000, n_burnin=100, device=device + bounds, maps=affine_map, neval=400000, nbatch=1000, nburnin=100, device=device ) res = mcmc_integrator(unit_circle_integrand, mix_rate=0.5) print("MCMC Integral results:") @@ -66,7 +68,7 @@ def half_sphere_integrand(x): print(f" Error: {res.sdev}") mcmc_integrator = MCMC( - map_spec, neval=400000, batch_size=1000, n_burnin=100, device=device + bounds, maps=affine_map, neval=400000, nbatch=1000, nburnin=100, device=device ) res = mcmc_integrator(half_sphere_integrand, mix_rate=0.5) print("MCMC Integral results:") From 18610331d25f7de71122c918d6a19c863bfe038d Mon Sep 17 00:00:00 2001 From: houpc Date: Tue, 15 Oct 2024 02:21:02 +0800 Subject: [PATCH 5/5] bugfix and GPU compatibility --- src/base.py | 20 +++++++++++------ src/integrators.py | 56 +++++++++++++++++++++++----------------------- src/mc_test.py | 15 +++++++------ 3 files changed, 49 insertions(+), 42 deletions(-) diff --git 
a/src/base.py b/src/base.py index 98764ad..d6aa62c 100644 --- a/src/base.py +++ b/src/base.py @@ -1,5 +1,6 @@ import torch from torch import nn +import numpy as np class BaseDistribution(nn.Module): @@ -10,7 +11,12 @@ class BaseDistribution(nn.Module): def __init__(self, bounds, device="cpu"): super().__init__() - self.bounds = bounds + # self.bounds = bounds + if isinstance(bounds, (list, np.ndarray)): + self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) + else: + raise ValueError("Unsupported map specification") + self.dim = self.bounds.shape[0] self.device = device def sample(self, nsamples=1, **kwargs): @@ -32,12 +38,12 @@ class Uniform(BaseDistribution): def __init__(self, bounds, device="cpu"): super().__init__(bounds, device) + self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] def sample(self, nsamples=1, **kwargs): - dim = len(self.bounds) - u = torch.rand((nsamples, dim), device=self.device) - log_detJ = torch.zeros(nsamples, device=self.device) - for i, bound in enumerate(self.bounds): - u[:, i] = (bound[1] - bound[0]) * u[:, i] + bound[0] - log_detJ += -torch.log(torch.tensor(bound[1] - bound[0])) + u = ( + torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds + + self.bounds[:, 0] + ) + log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples) return u, log_detJ diff --git a/src/integrators.py b/src/integrators.py index 435be5d..65fce77 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -4,6 +4,7 @@ from maps import Map, Affine, CompositeMap from base import Uniform import gvar +import numpy as np class Integrator: @@ -13,19 +14,19 @@ class Integrator: def __init__( self, - bounds: List[Tuple[float, float]], # , np.ndarray], + bounds: Union[List[Tuple[float, float]], np.ndarray], q0=None, maps=None, neval: int = 1000, nbatch: int = None, device="cpu", adapt=False, - # device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ): self.adapt = adapt self.dim = len(bounds) if not q0: - q0 = Uniform(bounds) + q0 = Uniform(bounds, device=device) + self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) self.q0 = q0 self.maps = maps self.neval = neval @@ -53,7 +54,7 @@ def sample(self, nsample, **kwargs): class MonteCarlo(Integrator): def __init__( self, - bounds: List[Tuple[float, float]], # , np.ndarray], + bounds: Union[List[Tuple[float, float]], np.ndarray], q0=None, maps=None, nitn: int = 10, @@ -66,8 +67,6 @@ def __init__( self.nitn = nitn def __call__(self, f: Callable, **kwargs): - # u = torch.rand(self.nbatch, self.dim, device=self.device) - # x, _ = self.map.forward(u) x, _ = self.sample(self.nbatch) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 @@ -84,10 +83,6 @@ def __call__(self, f: Callable, **kwargs): mean[:] = 0 var[:] = 0 for _ in range(epoch): - # y = torch.rand( - # self.nbatch, self.dim, dtype=torch.float64, device=self.device - # ) - # x, jac = self.map.forward(y) x, log_detJ = self.sample(self.nbatch) f_values = f(x) batch_results = self._multiply_by_jacobian( @@ -101,13 +96,6 @@ def __call__(self, f: Callable, **kwargs): result.add(gvar.gvar(mean.item(), (var**0.5).item())) return result - # def sample(self, nsample): - # u, log_detJ = self.q0.sample(nsample) - # for map in self.maps: - # u, log_detj = map(u) - # log_detJ += log_detj - # return u, log_detJ - def _multiply_by_jacobian(self, values, jac): # if isinstance(values, dict): # return {k: v * torch.exp(log_det_J) for k, v in values.items()} @@ -120,7 +108,7 @@ def 
_multiply_by_jacobian(self, values, jac): class MCMC(MonteCarlo): def __init__( self, - bounds: List[Tuple[float, float]], # , np.ndarray], + bounds: Union[List[Tuple[float, float]], np.ndarray], q0=None, maps=None, nitn: int = 10, @@ -132,6 +120,9 @@ def __init__( ): super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt) self.nburnin = nburnin + if maps is None: + self.maps = Affine([(0, 1)] * self.dim, device=device) + self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] def __call__( self, @@ -144,12 +135,11 @@ def __call__( epsilon = 1e-16 # Small value to ensure numerical stability # vars_shape = (self.nbatch, self.dim) current_y, current_jac = self.q0.sample(self.nbatch) - if self.maps: - current_x, jac = self.maps.forward(current_y) - current_jac += jac - else: - current_x = current_y - # current_x, current_jac = self.sample(self.nbatch) + # if self.maps: + current_x, detJ = self.maps.forward(current_y) + current_jac += detJ + # else: + # current_x = current_y current_jac = torch.exp(current_jac) current_fval = f(current_x) current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs() @@ -157,7 +147,7 @@ def __call__( # current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon) proposed_y = torch.empty_like(current_y) - proposed_x = torch.empty_like(proposed_y) + proposed_x = torch.empty_like(current_x) new_fval = torch.empty_like(current_fval) new_weight = torch.empty_like(current_weight) @@ -177,6 +167,7 @@ def __call__( for i in range(epoch): proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs) proposed_x[:], new_jac = self.maps.forward(proposed_y) + new_jac = torch.exp(new_jac) new_fval[:] = f(proposed_x) new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs() @@ -213,14 +204,23 @@ def __call__( gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()) ) - return result / result_ref + return result / result_ref * self._rangebounds.prod() def _propose(self, u, proposal_dist, **kwargs): if proposal_dist == "random_walk": step_size = kwargs.get("step_size", 0.2) - return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0 + step_sizes = self._rangebounds * step_size + step = ( + torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes + ) + new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[ + :, 0 + ] + return new_u + # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0 elif proposal_dist == "uniform": - return torch.rand_like(u) + # return torch.rand_like(u) + return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0] # elif proposal_dist == "gaussian": # mean = kwargs.get("mean", torch.zeros_like(u)) # std = kwargs.get("std", torch.ones_like(u)) diff --git a/src/mc_test.py b/src/mc_test.py index c55df20..700946e 100644 --- a/src/mc_test.py +++ b/src/mc_test.py @@ -5,6 +5,7 @@ set_seed(42) device = get_device() +# device = torch.device("cpu") def unit_circle_integrand(x): @@ -37,7 +38,11 @@ def half_sphere_integrand(x): print("Calculate the area of the unit circle using Monte Carlo integration...") mc_integrator = MonteCarlo( - bounds, maps=affine_map, neval=400000, nbatch=1000, device=device + # bounds, maps=affine_map, neval=400000, nbatch=1000, device=device + bounds, + neval=400000, + nbatch=1000, + device=device, ) res = mc_integrator(unit_circle_integrand) print("Plain MC Integral results:") @@ -51,9 +56,7 @@ def half_sphere_integrand(x): print(f" Integral: {res.mean}") print(f" Error: {res.sdev}") -mcmc_integrator = MCMC( - bounds, maps=affine_map, 
neval=400000, nbatch=1000, nburnin=100, device=device -) +mcmc_integrator = MCMC(bounds, neval=400000, nbatch=1000, nburnin=100, device=device) res = mcmc_integrator(unit_circle_integrand, mix_rate=0.5) print("MCMC Integral results:") print(f" Integral: {res.mean}") @@ -67,9 +70,7 @@ def half_sphere_integrand(x): print(f" Integral: {res.mean}") print(f" Error: {res.sdev}") -mcmc_integrator = MCMC( - bounds, maps=affine_map, neval=400000, nbatch=1000, nburnin=100, device=device -) +mcmc_integrator = MCMC(bounds, neval=400000, nbatch=1000, nburnin=100, device=device) res = mcmc_integrator(half_sphere_integrand, mix_rate=0.5) print("MCMC Integral results:") print(f" Integral: {res.mean}")
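
Usage note: the net effect of this patch series on the public API is that base distributions and maps return log-det-Jacobians (not Jacobians), and integrators are built from (bounds, q0, maps) instead of a single map object. The sketch below illustrates those conventions against the post-[PATCH 5/5] state of src/base.py and src/maps.py; it assumes src/ is on the import path, and the assertions are illustrative checks rather than code taken from the patches.

    import math
    import torch
    from base import Uniform
    from maps import Affine

    bounds = [(-1.0, 1.0), (-1.0, 1.0)]

    # Base distribution: samples land directly inside `bounds`, and the returned
    # log-det-Jacobian is the log-volume of the box (log 4 here).
    q0 = Uniform(bounds)
    u, log_det0 = q0.sample(5)
    assert u.shape == (5, 2) and torch.all((u >= -1) & (u <= 1))
    assert torch.allclose(log_det0, torch.full((5,), math.log(4.0), dtype=torch.float64))

    # Maps: forward() and inverse() both return the log-det-Jacobian of the
    # forward (unit-cube -> bounds) transformation, replacing the raw Jacobians
    # of the pre-refactor API.
    amap = Affine(bounds)                        # maps [0, 1]^2 -> [-1, 1]^2
    z = torch.rand(5, 2, dtype=torch.float64)
    x, log_det_fwd = amap.forward(z)
    z_back, log_det_inv = amap.inverse(x)
    assert torch.allclose(z_back, z)             # round trip recovers the input
    assert torch.allclose(log_det_fwd, log_det_inv)  # both equal log(4)

With these pieces, MonteCarlo(bounds, neval=..., nbatch=...) and MCMC(bounds, neval=..., nbatch=..., nburnin=...) are constructed as shown in src/mc_test.py above, with Uniform(bounds) supplied automatically whenever q0 is omitted.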