From 7123f22fb7d80f5412192d64eb8e05e13d5df9e6 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Tue, 22 Oct 2024 00:23:51 -0400 Subject: [PATCH] Refactor api wip --- src/base.py | 12 ++- src/integrators.py | 234 +++++++++++++++++++++++++-------------------- src/maps.py | 29 +++--- src/mc_test.py | 4 +- 4 files changed, 156 insertions(+), 123 deletions(-) diff --git a/src/base.py b/src/base.py index d6aa62c..ab7ffd0 100644 --- a/src/base.py +++ b/src/base.py @@ -9,11 +9,12 @@ class BaseDistribution(nn.Module): Parameters do not depend of target variable (as is the case for a VAE encoder) """ - def __init__(self, bounds, device="cpu"): + def __init__(self, bounds, device="cpu", dtype=torch.float64): super().__init__() + self.dtype = dtype # self.bounds = bounds if isinstance(bounds, (list, np.ndarray)): - self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) + self.bounds = torch.tensor(bounds, dtype=dtype, device=device) else: raise ValueError("Unsupported map specification") self.dim = self.bounds.shape[0] @@ -36,13 +37,14 @@ class Uniform(BaseDistribution): Multivariate uniform distribution """ - def __init__(self, bounds, device="cpu"): - super().__init__(bounds, device) + def __init__(self, bounds, device="cpu", dtype=torch.float64): + super().__init__(bounds, device, dtype) self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] def sample(self, nsamples=1, **kwargs): u = ( - torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds + torch.rand((nsamples, self.dim), device=self.device, dtype=self.dtype) + * self._rangebounds + self.bounds[:, 0] ) log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples) diff --git a/src/integrators.py b/src/integrators.py index 65fce77..9e19fc9 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -1,7 +1,7 @@ from typing import Callable, Union, List, Tuple, Dict import torch from utils import RAvg -from maps import Map, Affine, CompositeMap +from maps import Map, Linear, CompositeMap from base import Uniform import gvar import numpy as np @@ -9,25 +9,32 @@ class Integrator: """ - Base class for all integrators. + Base class for all integrators. This class is designed to handle integration tasks + over a specified domain (bounds) using a sampling method (q0) and optional + transformation maps. 
""" def __init__( self, - bounds: Union[List[Tuple[float, float]], np.ndarray], + bounds, q0=None, maps=None, neval: int = 1000, nbatch: int = None, device="cpu", - adapt=False, + dtype=torch.float64, ): - self.adapt = adapt + if not isinstance(bounds, (list, np.ndarray)): + raise TypeError("bounds must be a list or a NumPy array.") + self.dtype = dtype self.dim = len(bounds) if not q0: - q0 = Uniform(bounds, device=device) - self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) + q0 = Uniform(bounds, device=device, dtype=dtype) + self.bounds = torch.tensor(bounds, dtype=dtype, device=device) self.q0 = q0 + if maps: + if not self.dtype == maps.dtype: + raise ValueError("Float type of maps should be same as integrator.") self.maps = maps self.neval = neval if nbatch is None: @@ -54,46 +61,41 @@ def sample(self, nsample, **kwargs): class MonteCarlo(Integrator): def __init__( self, - bounds: Union[List[Tuple[float, float]], np.ndarray], + bounds, q0=None, maps=None, - nitn: int = 10, neval: int = 1000, nbatch: int = None, device="cpu", - adapt=False, + dtype=torch.float64, ): - super().__init__(bounds, q0, maps, neval, nbatch, device, adapt) - self.nitn = nitn + super().__init__(bounds, q0, maps, neval, nbatch, device, dtype) def __call__(self, f: Callable, **kwargs): x, _ = self.sample(self.nbatch) f_values = f(x) f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1 - type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype) - - mean = torch.zeros(f_size, dtype=type_fval, device=self.device) - var = torch.zeros(f_size, dtype=type_fval, device=self.device) + # type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype) + # mean = torch.zeros(f_size, dtype=type_fval, device=self.device) + # var = torch.zeros(f_size, dtype=type_fval, device=self.device) # var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device) - - result = RAvg(weighted=self.adapt) + mean = torch.zeros(f_size, dtype=self.dtype, device=self.device) + var = torch.zeros(f_size, dtype=self.dtype, device=self.device) + result = RAvg() epoch = self.neval // self.nbatch - for itn in range(self.nitn): - mean[:] = 0 - var[:] = 0 - for _ in range(epoch): - x, log_detJ = self.sample(self.nbatch) - f_values = f(x) - batch_results = self._multiply_by_jacobian( - f_values, torch.exp(log_detJ) - ) + mean[:] = 0 + var[:] = 0 + for _ in range(epoch): + x, log_detJ = self.sample(self.nbatch) + f_values = f(x) + batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ)) - mean += torch.mean(batch_results, dim=-1) / epoch - var += torch.var(batch_results, dim=-1) / (self.neval * epoch) + mean += torch.mean(batch_results, dim=-1) / epoch + var += torch.var(batch_results, dim=-1) / (self.neval * epoch) - result.sum_neval += self.neval - result.add(gvar.gvar(mean.item(), (var**0.5).item())) + result.sum_neval += self.neval + result.add(gvar.gvar(mean.item(), (var**0.5).item())) return result def _multiply_by_jacobian(self, values, jac): @@ -105,29 +107,48 @@ def _multiply_by_jacobian(self, values, jac): return values * jac +def random_walk(dim, bounds, device, dtype, u, **kwargs): + rangebounds = bounds[:, 1] - bounds[:, 0] + step_size = kwargs.get("step_size", 0.2) + step_sizes = rangebounds * step_size + step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes + new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0] + return new_u + + +def uniform(dim, bounds, device, dtype, u, **kwargs): + rangebounds = bounds[:, 1] - bounds[:, 
0] + return torch.rand_like(u) * rangebounds + bounds[:, 0] + + +def gaussian(dim, bounds, device, dtype, u, **kwargs): + mean = kwargs.get("mean", torch.zeros_like(u)) + std = kwargs.get("std", torch.ones_like(u)) + return torch.normal(mean, std) + + class MCMC(MonteCarlo): def __init__( self, - bounds: Union[List[Tuple[float, float]], np.ndarray], + bounds, q0=None, maps=None, - nitn: int = 10, neval=10000, nbatch=None, nburnin=500, device="cpu", - adapt=False, + dtype=torch.float64, ): - super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt) + super().__init__(bounds, q0, maps, neval, nbatch, device, dtype) self.nburnin = nburnin if maps is None: - self.maps = Affine([(0, 1)] * self.dim, device=device) + self.maps = Linear([(0, 1)] * self.dim, device=device) self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] def __call__( self, f: Callable, - proposal_dist="uniform", + proposal_dist: Callable = uniform, thinning=1, mix_rate=0.0, **kwargs, @@ -146,84 +167,93 @@ def __call__( current_weight.masked_fill_(current_weight < epsilon, epsilon) # current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon) - proposed_y = torch.empty_like(current_y) - proposed_x = torch.empty_like(current_x) - new_fval = torch.empty_like(current_fval) - new_weight = torch.empty_like(current_weight) + # proposed_y = torch.empty_like(current_y) + # proposed_x = torch.empty_like(current_x) + # new_fval = torch.empty_like(current_fval) + # new_weight = torch.empty_like(current_weight) f_size = len(current_fval) if isinstance(current_fval, (list, tuple)) else 1 - type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype) - mean = torch.zeros(f_size, dtype=type_fval, device=self.device) + # type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype) + # mean = torch.zeros(f_size, dtype=type_fval, device=self.device) + mean = torch.zeros(f_size, dtype=self.dtype, device=self.device) mean_ref = torch.zeros_like(mean) - var = torch.zeros(f_size, dtype=type_fval, device=self.device) + # var = torch.zeros(f_size, dtype=type_fval, device=self.device) + var = torch.zeros(f_size, dtype=self.dtype, device=self.device) var_ref = torch.zeros_like(mean) - result = RAvg(weighted=self.adapt) - result_ref = RAvg(weighted=self.adapt) + result = RAvg() + result_ref = RAvg() epoch = self.neval // self.nbatch n_meas = 0 - for itn in range(self.nitn): - for i in range(epoch): - proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs) - proposed_x[:], new_jac = self.maps.forward(proposed_y) - new_jac = torch.exp(new_jac) - new_fval[:] = f(proposed_x) - new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs() + def _propose(current_y, current_fval, current_weight, current_jac): + proposed_y = proposal_dist( + self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs + ) + proposed_x, new_jac = self.maps.forward(proposed_y) + new_jac = torch.exp(new_jac) + + new_fval = f(proposed_x) + new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs() - acceptance_probs = new_weight / current_weight * new_jac / current_jac + acceptance_probs = new_weight / current_weight * new_jac / current_jac - accept = ( - torch.rand(self.nbatch, dtype=torch.float64, device=self.device) - <= acceptance_probs - ) + accept = ( + torch.rand(self.nbatch, dtype=self.dtype, device=self.device) + <= acceptance_probs + ) - current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y) - current_fval = torch.where(accept, new_fval, current_fval) - current_weight = 
torch.where(accept, new_weight, current_weight) - current_jac = torch.where(accept, new_jac, current_jac) - - if i < self.nburnin and itn == 0: - continue - elif i % thinning == 0: - n_meas += 1 - batch_results = current_fval / current_weight - - mean += torch.mean(batch_results, dim=-1) / epoch - var += torch.var(batch_results, dim=-1) / epoch - - batch_results_ref = 1 / (current_jac * current_weight) - mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch - var_ref += torch.var(batch_results_ref, dim=-1) / epoch - - result.sum_neval += self.neval - result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item())) - result_ref.sum_neval += self.nbatch - result_ref.add( - gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()) + current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y) + current_fval = torch.where(accept, new_fval, current_fval) + current_weight = torch.where(accept, new_weight, current_weight) + current_jac = torch.where(accept, new_jac, current_jac) + return current_y, current_fval, current_weight, current_jac + + for i in range(self.nburnin): + current_y, current_fval, current_weight, current_jac = _propose( + current_y, current_fval, current_weight, current_jac ) + for i in range(epoch // thinning): + for j in range(thinning): + current_y, current_fval, current_weight, current_jac = _propose( + current_y, current_fval, current_weight, current_jac + ) + n_meas += 1 + batch_results = current_fval / current_weight + + mean += torch.mean(batch_results, dim=-1) / epoch + var += torch.var(batch_results, dim=-1) / epoch + + batch_results_ref = 1 / (current_jac * current_weight) + mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch + var_ref += torch.var(batch_results_ref, dim=-1) / epoch + + result.sum_neval += self.neval + result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item())) + result_ref.sum_neval += self.nbatch + result_ref.add(gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())) return result / result_ref * self._rangebounds.prod() - def _propose(self, u, proposal_dist, **kwargs): - if proposal_dist == "random_walk": - step_size = kwargs.get("step_size", 0.2) - step_sizes = self._rangebounds * step_size - step = ( - torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes - ) - new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[ - :, 0 - ] - return new_u - # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0 - elif proposal_dist == "uniform": - # return torch.rand_like(u) - return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0] - # elif proposal_dist == "gaussian": - # mean = kwargs.get("mean", torch.zeros_like(u)) - # std = kwargs.get("std", torch.ones_like(u)) - # return torch.normal(mean, std) - else: - raise ValueError(f"Unknown proposal distribution: {proposal_dist}") + # def _propose(self, u, proposal_dist, **kwargs): + # if proposal_dist == "random_walk": + # step_size = kwargs.get("step_size", 0.2) + # step_sizes = self._rangebounds * step_size + # step = ( + # torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes + # ) + # new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[ + # :, 0 + # ] + # return new_u + # # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0 + # elif proposal_dist == "uniform": + # # return torch.rand_like(u) + # return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0] + # # elif proposal_dist == "gaussian": + # # mean = kwargs.get("mean", torch.zeros_like(u)) + # # std = 
kwargs.get("std", torch.ones_like(u)) + # # return torch.normal(mean, std) + # else: + # raise ValueError(f"Unknown proposal distribution: {proposal_dist}") diff --git a/src/maps.py b/src/maps.py index ef4a884..2b6a6df 100644 --- a/src/maps.py +++ b/src/maps.py @@ -4,14 +4,15 @@ class Map(nn.Module): - def __init__(self, bounds, device="cpu"): + def __init__(self, bounds, device="cpu", dtype=torch.float64): super().__init__() if isinstance(bounds, (list, np.ndarray)): - self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device) + self.bounds = torch.tensor(bounds, dtype=dtype, device=device) else: raise ValueError("Unsupported map specification") self.dim = self.bounds.shape[0] self.device = device + self.dtype = dtype def forward(self, u): raise NotImplementedError("Subclasses must implement this method") @@ -21,30 +22,30 @@ def inverse(self, x): class CompositeMap(Map): - def __init__(self, maps, device="cpu"): + def __init__(self, maps, device="cpu", dtype=torch.float64): if not maps: raise ValueError("Maps can not be empty.") - super().__init__(maps[-1].bounds, device) + super().__init__(maps[-1].bounds, device, dtype) self.maps = maps def forward(self, u): - log_detJ = torch.zeros(len(u), device=u.device) + log_detJ = torch.zeros(len(u), device=u.device, dtype=self.dtype) for map in self.maps: u, log_detj = map.forward(u) log_detJ += log_detj return u, log_detJ def inverse(self, x): - log_detJ = torch.zeros(len(x), device=x.device) + log_detJ = torch.zeros(len(x), device=x.device, dtype=self.dtype) for i in range(len(self.maps) - 1, -1, -1): x, log_detj = self.maps[i].inverse(x) log_detJ += log_detj return x, log_detJ -class Affine(Map): - def __init__(self, bounds, device="cpu"): - super().__init__(bounds, device) +class Linear(Map): + def __init__(self, bounds, device="cpu", dtype=torch.float64): + super().__init__(bounds, device, dtype) self._A = self.bounds[:, 1] - self.bounds[:, 0] self._jac1 = torch.prod(self._A) @@ -58,8 +59,8 @@ def inverse(self, x): class Vegas(Map): - def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu"): - super().__init__(bounds, device) + def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu", dtype=torch.float64): + super().__init__(bounds, device, dtype) # self.nbin = nbin self.alpha = alpha if isinstance(ninc, int): @@ -68,10 +69,10 @@ def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu"): self.ninc = torch.tensor(ninc, dtype=torch.int32, device=device) self.inc = torch.empty( - self.dim, self.ninc.max(), dtype=torch.float64, device=self.device + self.dim, self.ninc.max(), dtype=self.dtype, device=self.device ) self.grid = torch.empty( - self.dim, self.ninc.max() + 1, dtype=torch.float64, device=self.device + self.dim, self.ninc.max() + 1, dtype=self.dtype, device=self.device ) for d in range(self.dim): @@ -79,7 +80,7 @@ def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu"): self.bounds[d, 0], self.bounds[d, 1], self.ninc[d] + 1, - dtype=torch.float64, + dtype=self.dtype, device=self.device, ) self.inc[d, : self.ninc[d]] = ( diff --git a/src/mc_test.py b/src/mc_test.py index 700946e..47740c0 100644 --- a/src/mc_test.py +++ b/src/mc_test.py @@ -1,6 +1,6 @@ import torch from integrators import MonteCarlo, MCMC -from maps import Vegas, Affine +from maps import Vegas, Linear from utils import set_seed, get_device set_seed(42) @@ -31,7 +31,7 @@ def half_sphere_integrand(x): dim = 2 bounds = [(-1, 1), (-1, 1)] -affine_map = Affine(bounds, device=device) +affine_map = Linear(bounds, device=device) # vegas_map = 
Vegas(bounds, device=device) # Monte Carlo integration
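
With this refactor, MCMC takes the proposal distribution as a callable (random_walk, uniform, or gaussian, now defined at module level in integrators.py) rather than a string, the nitn and adapt arguments are dropped, dtype is threaded through Integrator, BaseDistribution, and Map, and maps.Affine is renamed maps.Linear. The following is a minimal usage sketch of the new call signatures, assuming the module layout in this patch; the body of half_sphere_integrand is a hypothetical stand-in, since mc_test.py only shows its name in the hunk header:

    import torch
    from integrators import MonteCarlo, MCMC, random_walk
    from maps import Linear

    bounds = [(-1, 1), (-1, 1)]

    def half_sphere_integrand(x):
        # hypothetical integrand: height of the unit half-sphere over [-1, 1]^2,
        # clamped to zero outside the unit disk
        return 2 * torch.clamp(1 - (x**2).sum(dim=-1), min=0).sqrt()

    # plain Monte Carlo over the bounds (float64 by default)
    mc = MonteCarlo(bounds, neval=100_000, nbatch=10_000)
    print(mc(half_sphere_integrand))

    # MCMC with an explicit Linear map; the proposal is passed as a callable,
    # and proposal-specific options (step_size) travel through **kwargs
    mcmc = MCMC(bounds, maps=Linear(bounds), neval=100_000, nbatch=10_000, nburnin=500)
    print(mcmc(half_sphere_integrand, proposal_dist=random_walk, step_size=0.2, mix_rate=0.5))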