From e915bd6b4c85ef7225977480e9fb5cef2a1d70ed Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Wed, 20 Nov 2024 23:23:30 -0500 Subject: [PATCH 1/2] refactor --- src/integrators.py | 165 +++++++++++++++++++++++++++++---------------- src/mc_test.py | 74 +++++++++++++------- 2 files changed, 156 insertions(+), 83 deletions(-) diff --git a/src/integrators.py b/src/integrators.py index 0169468..ac16daf 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -38,6 +38,16 @@ def cleanup(): dist.destroy_process_group() +class Sample: + def __init__(self, nsample, dim, device="cpu", dtype=torch.float64): + self.dim = dim + self.nsample = nsample + self.u = torch.empty((nsample, dim), dtype=dtype, device=device) + self.x = torch.empty((nsample, dim), dtype=dtype, device=device) + self.weight = torch.empty(nsample, dtype=dtype, device=device) + self.jac = torch.empty((nsample, dim), dtype=dtype, device=device) + + class Integrator: """ Base class for all integrators. This class is designed to handle integration tasks @@ -85,18 +95,21 @@ def __init__( def __call__(self, f: Callable, **kwargs): raise NotImplementedError("Subclasses must implement this method") - def sample(self, nsample, **kwargs): - u, log_detJ = self.q0.sample(nsample) + def sample(self, sample, **kwargs): + sample.u, sample.jac = self.q0.sample(sample.nsample) if not self.maps: - return u, log_detJ + sample.x = sample.u else: - u, log_detj = self.maps.forward(u) - return u, log_detJ + log_detj + sample.x, log_detj = self.maps.forward(sample.u) + sample.jac += log_detj + sample.jac.exp_() class MonteCarlo(Integrator): def __init__( self, + f: Callable, + f_dim=1, maps=None, bounds=None, q0=None, @@ -106,28 +119,32 @@ def __init__( dtype=torch.float64, ): super().__init__(maps, bounds, q0, neval, nbatch, device, dtype) + self.f = f + self.f_dim = f_dim - def __call__(self, f: Callable, f_dim: int = 1, multigpu=False, **kwargs): + def __call__(self, multigpu=False, **kwargs): if multigpu: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 - - x, _ = self.sample(self.nbatch) - fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device) + sample = Sample(self.nbatch, self.dim, self.device, self.dtype) + self.sample(sample) + fx = torch.empty( + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + ) epoch = self.neval // self.nbatch integ_values = torch.zeros( - (self.nbatch, f_dim), dtype=self.dtype, device=self.device + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device ) - results = np.array([RAvg() for _ in range(f_dim)]) + results = np.array([RAvg() for _ in range(self.f_dim)]) for _ in range(epoch): - x, log_detJ = self.sample(self.nbatch) - f(x, fx) - fx.mul_(log_detJ.exp_().unsqueeze_(1)) + self.sample(sample) + self.f(sample.x, fx) + fx.mul_(sample.jac.unsqueeze_(1)) integ_values += fx / epoch results = self.statistics(integ_values, results, rank, world_size) @@ -198,12 +215,15 @@ def gaussian(dim, bounds, device, dtype, u, **kwargs): return torch.normal(mean, std) -class MCMC(MonteCarlo): +class MCMC(Integrator): def __init__( self, + f: Callable, + f_dim: int = 1, maps=None, bounds=None, q0=None, + proposal_dist=None, neval: int = 10000, nbatch: int = None, nburnin: int = 500, @@ -212,17 +232,59 @@ def __init__( ): super().__init__(maps, bounds, q0, neval, nbatch, device, dtype) self.nburnin = nburnin - + self.f = f + self.f_dim = f_dim + self.fx = torch.empty( + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + ) + if not proposal_dist: + self.proposal_dist = uniform + else: + if not isinstance(proposal_dist, Callable): + raise TypeError("proposal_dist must be a callable function.") + self.proposal_dist = proposal_dist # If no transformation maps are provided, use a linear map as default if maps is None: self.maps = Linear([(0, 1)] * self.dim, device=device) self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] + def sample(self, sample, niter=10, mix_rate=0, **kwargs): + for _ in range(niter): + self.metropolis_hastings(sample, mix_rate, **kwargs) + + def metropolis_hastings(self, sample, mix_rate, **kwargs): + proposed_y = self.proposal_dist( + self.dim, self.bounds, self.device, self.dtype, sample.u, **kwargs + ) + proposed_x, new_jac = self.maps.forward(proposed_y) + new_jac.exp_() + + new_weight = ( + mix_rate / new_jac + (1 - mix_rate) * self.f(proposed_x, self.fx).abs_() + ) + new_weight.masked_fill_(new_weight < EPSILON, EPSILON) + + acceptance_probs = new_weight / sample.weight * new_jac / sample.jac + + accept = ( + torch.rand(self.nbatch, dtype=torch.float64, device=self.device) + <= acceptance_probs + ) + + accept_expanded = accept.unsqueeze(1) + sample.u.mul_(~accept_expanded).add_(proposed_y * accept_expanded) + sample.x.mul_(~accept_expanded).add_(proposed_x * accept_expanded) + sample.weight.mul_(~accept).add_(new_weight * accept) + sample.jac.mul_(~accept).add_(new_jac * accept) + + # def sample(self, nsample, **kwargs): + # return + def __call__( self, - f: Callable, - f_dim: int = 1, - proposal_dist: Callable = uniform, + # f: Callable, + # f_dim: int = 1, + # proposal_dist: Callable = uniform, mix_rate=0.5, meas_freq: int = 1, multigpu=False, @@ -235,59 +297,44 @@ def __call__( rank = 0 world_size = 1 + self.fx = torch.empty( + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + ) + sample = Sample(self.nbatch, self.dim, self.device, self.dtype) epoch = self.neval // self.nbatch - current_y, current_jac = self.q0.sample(self.nbatch) - current_x, detJ = self.maps.forward(current_y) - current_jac += detJ - current_jac.exp_() - fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device) - - current_weight = ( - mix_rate / current_jac + (1 - mix_rate) * f(current_x, fx).abs_() + sample.u, sample.jac = self.q0.sample(self.nbatch) + sample.x, detJ = self.maps.forward(sample.u) + sample.jac += detJ + sample.jac.exp_() + # self.fx = torch.empty( + # (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + # ) + + sample.weight = ( + mix_rate / sample.jac + (1 - mix_rate) * self.f(sample.x, self.fx).abs_() ) - current_weight.masked_fill_(current_weight < EPSILON, EPSILON) + sample.weight.masked_fill_(sample.weight < EPSILON, EPSILON) n_meas = epoch // meas_freq - def one_step(current_y, current_x, current_weight, current_jac): - proposed_y = proposal_dist( - self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs - ) - proposed_x, new_jac = self.maps.forward(proposed_y) - new_jac.exp_() - - new_weight = mix_rate / new_jac + (1 - mix_rate) * f(proposed_x, fx).abs_() - new_weight.masked_fill_(new_weight < EPSILON, EPSILON) - - acceptance_probs = new_weight / current_weight * new_jac / current_jac - - accept = ( - torch.rand(self.nbatch, dtype=torch.float64, device=self.device) - <= acceptance_probs - ) - - accept_expanded = accept.unsqueeze(1) - current_y.mul_(~accept_expanded).add_(proposed_y * accept_expanded) - current_x.mul_(~accept_expanded).add_(proposed_x * accept_expanded) - current_weight.mul_(~accept).add_(new_weight * accept) - current_jac.mul_(~accept).add_(new_jac * accept) - for _ in range(self.nburnin): - one_step(current_y, current_x, current_weight, current_jac) + self.metropolis_hastings(sample, mix_rate, **kwargs) - values = torch.zeros((self.nbatch, f_dim), dtype=self.dtype, device=self.device) + values = torch.zeros( + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + ) refvalues = torch.zeros(self.nbatch, dtype=self.dtype, device=self.device) - results_unnorm = np.array([RAvg() for _ in range(f_dim)]) + results_unnorm = np.array([RAvg() for _ in range(self.f_dim)]) results_ref = RAvg() for _ in range(n_meas): for _ in range(meas_freq): - one_step(current_y, current_x, current_weight, current_jac) - f(current_x, fx) + self.metropolis_hastings(sample, mix_rate, **kwargs) + self.f(sample.x, self.fx) - fx.div_(current_weight.unsqueeze(1)) - values += fx / n_meas - refvalues += 1 / (current_jac * current_weight) / n_meas + self.fx.div_(sample.weight.unsqueeze(1)) + values += self.fx / n_meas + refvalues += 1 / (sample.jac * sample.weight) / n_meas results = self.statistics( values, refvalues, results_unnorm, results_ref, rank, world_size diff --git a/src/mc_test.py b/src/mc_test.py index 48b7d24..b2e2469 100644 --- a/src/mc_test.py +++ b/src/mc_test.py @@ -52,73 +52,84 @@ def sharp_integrands(x, f): # Start Monte Carlo integration, including plain-MC, MCMC, vegas, and vegas-MCMC mc_integrator = MonteCarlo( + f=unit_circle_integrand, bounds=bounds, neval=n_eval, nbatch=n_batch, device=device, ) mcmc_integrator = MCMC( - bounds=bounds, neval=n_eval, nbatch=n_batch, nburnin=n_therm, device=device + f=unit_circle_integrand, + bounds=bounds, + neval=n_eval, + nbatch=n_batch, + nburnin=n_therm, + device=device, ) print("Calculate the area of the unit circle f(x1, x2) in the bounds [-1, 1]^2...") -res = mc_integrator(unit_circle_integrand) +res = mc_integrator() print("Plain MC Integral results: ", res) -res = mcmc_integrator(unit_circle_integrand, mix_rate=0.5) +res = mcmc_integrator(mix_rate=0.5) print("MCMC Integral results: ", res) vegas_map.train(20000, unit_circle_integrand, alpha=0.5) vegas_integrator = MonteCarlo( + f=unit_circle_integrand, maps=vegas_map, neval=n_eval, nbatch=n_batch, # nbatch=n_eval, device=device, ) -res = vegas_integrator(unit_circle_integrand) +res = vegas_integrator() print("VEGAS Integral results: ", res) vegasmcmc_integrator = MCMC( + f=unit_circle_integrand, maps=vegas_map, neval=n_eval, nbatch=n_batch, nburnin=n_therm, device=device, ) -res = vegasmcmc_integrator(unit_circle_integrand, mix_rate=0.5) +res = vegasmcmc_integrator(mix_rate=0.5) print("VEGAS-MCMC Integral results: ", res, "\n") print( r"Calculate the integral g(x1, x2) = $2 \max(1-(x_1^2+x_2^2), 0)$ in the bounds [-1, 1]^2..." ) - -res = mc_integrator(half_sphere_integrand) +mc_integrator.f = half_sphere_integrand +res = mc_integrator() print("Plain MC Integral results: ", res) - -res = mcmc_integrator(half_sphere_integrand, mix_rate=0.5) +mcmc_integrator.f = half_sphere_integrand +res = mcmc_integrator(mix_rate=0.5) print("MCMC Integral results:", res) vegas_map.make_uniform() # train the vegas map vegas_map.train(20000, half_sphere_integrand, epoch=10, alpha=0.5) - -res = vegas_integrator(half_sphere_integrand) +vegas_integrator.f = half_sphere_integrand +res = vegas_integrator() print("VEGAS Integral results: ", res) - -res = vegasmcmc_integrator(half_sphere_integrand, mix_rate=0.5) +vegasmcmc_integrator.f = half_sphere_integrand +res = vegasmcmc_integrator(mix_rate=0.5) print("VEGAS-MCMC Integral results: ", res) print("\nCalculate the integral [f(x1, x2), g(x1, x2)/2] in the bounds [-1, 1]^2") # Two integrands -res = mc_integrator(two_integrands, f_dim=2) +mc_integrator.f = two_integrands +mc_integrator.f_dim = 2 +res = mc_integrator() print("Plain MC Integral results:") print(" Integral 1: ", res[0]) print(" Integral 2: ", res[1]) - -res = mcmc_integrator(two_integrands, f_dim=2, mix_rate=0.5) +mcmc_integrator.f = two_integrands +mcmc_integrator.f_dim = 2 +res = mcmc_integrator(mix_rate=0.5) print("MCMC Integral results:") print(f" Integral 1: ", res[0]) print(f" Integral 2: ", res[1]) @@ -126,12 +137,15 @@ def sharp_integrands(x, f): # print("VEAGS map is trained for g(x1, x2)") vegas_map.make_uniform() vegas_map.train(20000, two_integrands, f_dim=2, epoch=10, alpha=0.5) -res = vegas_integrator(two_integrands, f_dim=2) +vegas_integrator.f = two_integrands +vegas_integrator.f_dim = 2 +res = vegas_integrator() print("VEGAS Integral results:") print(" Integral 1: ", res[0]) print(" Integral 2: ", res[1]) - -res = vegasmcmc_integrator(two_integrands, f_dim=2, mix_rate=0.5) +vegasmcmc_integrator.f = two_integrands +vegasmcmc_integrator.f_dim = 2 +res = vegasmcmc_integrator(mix_rate=0.5) print("VEGAS-MCMC Integral results:") print(" Integral 1: ", res[0]) print(" Integral 2: ", res[1]) @@ -141,6 +155,8 @@ def sharp_integrands(x, f): bounds = [(0, 1)] * 4 mc_integrator = MonteCarlo( + f=sharp_integrands, + f_dim=3, bounds=bounds, neval=n_eval, nbatch=n_batch, @@ -148,10 +164,16 @@ def sharp_integrands(x, f): device=device, ) mcmc_integrator = MCMC( - bounds=bounds, neval=n_eval, nbatch=n_batch, nburnin=n_therm, device=device + f=sharp_integrands, + f_dim=3, + bounds=bounds, + neval=n_eval, + nbatch=n_batch, + nburnin=n_therm, + device=device, ) print("Plain MC Integral results:") -res = mc_integrator(sharp_integrands, f_dim=3) +res = mc_integrator() print( " I[0] =", res[0], @@ -163,7 +185,7 @@ def sharp_integrands(x, f): res[1] / res[0], ) print("MCMC Integral results:") -res = mcmc_integrator(sharp_integrands, f_dim=3, mix_rate=0.5) +res = mcmc_integrator(mix_rate=0.5) print( " I[0] =", res[0], @@ -181,13 +203,15 @@ def sharp_integrands(x, f): print("VEGAS Integral results:") vegas_integrator = MonteCarlo( + f=sharp_integrands, + f_dim=3, maps=vegas_map, neval=n_eval, nbatch=n_batch, # nbatch=n_eval, device=device, ) -res = vegas_integrator(sharp_integrands, f_dim=3) +res = vegas_integrator() print( " I[0] =", res[0], @@ -201,13 +225,15 @@ def sharp_integrands(x, f): print("VEGAS-MCMC Integral results:") vegasmcmc_integrator = MCMC( + f=sharp_integrands, + f_dim=3, maps=vegas_map, neval=n_eval, nbatch=n_batch, nburnin=n_therm, device=device, ) -res = vegasmcmc_integrator(sharp_integrands, f_dim=3, mix_rate=0.5) +res = vegasmcmc_integrator(mix_rate=0.5) print( " I[0] =", res[0], From abc1312a7f2a307a3f789a4f3090d9584832291e Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Mon, 25 Nov 2024 00:31:24 -0500 Subject: [PATCH 2/2] address github comments --- src/integrators.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/integrators.py b/src/integrators.py index ac16daf..cf83f2a 100644 --- a/src/integrators.py +++ b/src/integrators.py @@ -98,9 +98,9 @@ def __call__(self, f: Callable, **kwargs): def sample(self, sample, **kwargs): sample.u, sample.jac = self.q0.sample(sample.nsample) if not self.maps: - sample.x = sample.u + sample.x[:] = sample.u else: - sample.x, log_detj = self.maps.forward(sample.u) + sample.x[:], log_detj = self.maps.forward(sample.u) sample.jac += log_detj sample.jac.exp_() @@ -248,6 +248,16 @@ def __init__( self.maps = Linear([(0, 1)] * self.dim, device=device) self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0] + def __setattr__(self, __name, __value): + super().__setattr__(__name, __value) + if __name == "f_dim": + super().__setattr__( + "fx", + torch.empty( + (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + ), + ) + def sample(self, sample, niter=10, mix_rate=0, **kwargs): for _ in range(niter): self.metropolis_hastings(sample, mix_rate, **kwargs) @@ -297,19 +307,15 @@ def __call__( rank = 0 world_size = 1 - self.fx = torch.empty( - (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device - ) + # self.fx = torch.empty( + # (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device + # ) sample = Sample(self.nbatch, self.dim, self.device, self.dtype) epoch = self.neval // self.nbatch sample.u, sample.jac = self.q0.sample(self.nbatch) sample.x, detJ = self.maps.forward(sample.u) sample.jac += detJ sample.jac.exp_() - # self.fx = torch.empty( - # (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device - # ) - sample.weight = ( mix_rate / sample.jac + (1 - mix_rate) * self.f(sample.x, self.fx).abs_() )