Merged
171 changes: 112 additions & 59 deletions src/integrators.py
@@ -38,6 +38,16 @@ def cleanup():
dist.destroy_process_group()


class Sample:
def __init__(self, nsample, dim, device="cpu", dtype=torch.float64):
self.dim = dim
self.nsample = nsample
self.u = torch.empty((nsample, dim), dtype=dtype, device=device)
self.x = torch.empty((nsample, dim), dtype=dtype, device=device)
self.weight = torch.empty(nsample, dtype=dtype, device=device)
self.jac = torch.empty(nsample, dtype=dtype, device=device)


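The Sample class added here is a plain buffer bundle: u holds the latent points drawn from q0, x the mapped points, weight the per-sample target weight used by the MCMC path, and jac the per-sample Jacobian factor. Allocating once and refilling in place avoids reallocating tensors on every epoch. A minimal usage sketch (the import path is an assumption based on the file touched by this PR):

import torch
from src.integrators import Sample  # import path assumed from this diff

# allocate buffers once for a batch of 1024 points in 3 dimensions
s = Sample(nsample=1024, dim=3)
print(s.u.shape, s.x.shape, s.weight.shape, s.jac.shape)
# torch.Size([1024, 3]) torch.Size([1024, 3]) torch.Size([1024]) torch.Size([1024])
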
class Integrator:
"""
Base class for all integrators. This class is designed to handle integration tasks
@@ -85,18 +95,21 @@ def __init__(
def __call__(self, f: Callable, **kwargs):
raise NotImplementedError("Subclasses must implement this method")

def sample(self, nsample, **kwargs):
u, log_detJ = self.q0.sample(nsample)
def sample(self, sample, **kwargs):
sample.u, sample.jac = self.q0.sample(sample.nsample)
if not self.maps:
return u, log_detJ
sample.x[:] = sample.u
else:
u, log_detj = self.maps.forward(u)
return u, log_detJ + log_detj
sample.x[:], log_detj = self.maps.forward(sample.u)
sample.jac += log_detj
sample.jac.exp_()

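The refactored sample() also changes the contract: instead of returning (u, log_detJ) it fills the Sample buffers in place, and sample.jac ends up holding the exponentiated Jacobian rather than its log. The accumulation step reflects that log-Jacobians of composed transformations add (our notation, treating the base distribution's log-weight as a log-Jacobian, exactly as the code does):

\log J_{\text{total}}(u) = \log\lvert\det J_{q_0}(u)\rvert + \log\lvert\det J_{\text{map}}(u)\rvert,
\qquad \texttt{sample.jac} = e^{\log J_{\text{total}}(u)}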

class MonteCarlo(Integrator):
def __init__(
self,
f: Callable,
f_dim=1,
maps=None,
bounds=None,
q0=None,
@@ -106,28 +119,32 @@ def __init__(
dtype=torch.float64,
):
super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)
self.f = f
self.f_dim = f_dim

def __call__(self, f: Callable, f_dim: int = 1, multigpu=False, **kwargs):
def __call__(self, multigpu=False, **kwargs):
if multigpu:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1

x, _ = self.sample(self.nbatch)
fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device)
sample = Sample(self.nbatch, self.dim, self.device, self.dtype)
self.sample(sample)
fx = torch.empty(
(self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
)

epoch = self.neval // self.nbatch
integ_values = torch.zeros(
(self.nbatch, f_dim), dtype=self.dtype, device=self.device
(self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
)
results = np.array([RAvg() for _ in range(f_dim)])
results = np.array([RAvg() for _ in range(self.f_dim)])

for _ in range(epoch):
x, log_detJ = self.sample(self.nbatch)
f(x, fx)
fx.mul_(log_detJ.exp_().unsqueeze_(1))
self.sample(sample)
self.f(sample.x, fx)
fx.mul_(sample.jac.unsqueeze(1))  # non-in-place unsqueeze keeps sample.jac's shape intact
integ_values += fx / epoch

results = self.statistics(integ_values, results, rank, world_size)
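
With f and f_dim moved into the constructor, __call__ now takes only runtime options such as multigpu. Each epoch refills the Sample in place, evaluates f into the pre-allocated fx, and multiplies by the Jacobian, so the returned estimate is in effect the sample mean of f(x)·J(u) over all epochs. A hypothetical end-to-end run (the import path, the bounds format, the convention that f writes into fx, and the constructor defaults are all assumptions based on what is visible in this diff):

import torch
from src.integrators import MonteCarlo  # import path assumed

def integrand(x, fx):
    # fill the pre-allocated output buffer fx of shape (nbatch, f_dim);
    # f(x, y) = x*y over the unit square integrates to 1/4
    fx[:, 0] = x[:, 0] * x[:, 1]
    return fx[:, 0]  # the return value is only used by the MCMC weighting path

mc = MonteCarlo(f=integrand, f_dim=1, bounds=[(0, 1), (0, 1)], neval=10**6)
result = mc()   # f is no longer passed per call
print(result)   # expect roughly 0.25
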
@@ -198,12 +215,15 @@ def gaussian(dim, bounds, device, dtype, u, **kwargs):
return torch.normal(mean, std)


class MCMC(MonteCarlo):
class MCMC(Integrator):
def __init__(
self,
f: Callable,
f_dim: int = 1,
maps=None,
bounds=None,
q0=None,
proposal_dist=None,
neval: int = 10000,
nbatch: int = None,
nburnin: int = 500,
@@ -212,17 +232,69 @@ def __init__(
):
super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)
self.nburnin = nburnin

self.f = f
self.f_dim = f_dim
self.fx = torch.empty(
[Review comment · Member] self.fx has already been assigned in __init__().

(self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
)
if not proposal_dist:
self.proposal_dist = uniform
else:
if not isinstance(proposal_dist, Callable):
raise TypeError("proposal_dist must be a callable function.")
self.proposal_dist = proposal_dist
# If no transformation maps are provided, use a linear map as default
if maps is None:
self.maps = Linear([(0, 1)] * self.dim, device=device)
self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]

def __setattr__(self, __name, __value):
super().__setattr__(__name, __value)
if __name == "f_dim":
super().__setattr__(
"fx",
torch.empty(
(self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
),
)

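The __setattr__ hook keeps the fx buffer in sync with f_dim: any assignment to f_dim, including the one inside __init__, reallocates fx to shape (nbatch, f_dim). This is also why the explicit allocation above is flagged as redundant in the review comment. A small sketch of the behaviour, reusing the hypothetical integrand from the MonteCarlo example above (constructor defaults assumed as before):

mcmc = MCMC(f=integrand, f_dim=1, bounds=[(0, 1), (0, 1)])
assert mcmc.fx.shape == (mcmc.nbatch, 1)
mcmc.f_dim = 3  # the hook reallocates the buffer
assert mcmc.fx.shape == (mcmc.nbatch, 3)
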
def sample(self, sample, niter=10, mix_rate=0, **kwargs):
for _ in range(niter):
self.metropolis_hastings(sample, mix_rate, **kwargs)

def metropolis_hastings(self, sample, mix_rate, **kwargs):
proposed_y = self.proposal_dist(
self.dim, self.bounds, self.device, self.dtype, sample.u, **kwargs
)
proposed_x, new_jac = self.maps.forward(proposed_y)
new_jac.exp_()

new_weight = (
mix_rate / new_jac + (1 - mix_rate) * self.f(proposed_x, self.fx).abs_()
)
new_weight.masked_fill_(new_weight < EPSILON, EPSILON)

acceptance_probs = new_weight / sample.weight * new_jac / sample.jac

accept = (
torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
<= acceptance_probs
)

accept_expanded = accept.unsqueeze(1)
sample.u.mul_(~accept_expanded).add_(proposed_y * accept_expanded)
sample.x.mul_(~accept_expanded).add_(proposed_x * accept_expanded)
sample.weight.mul_(~accept).add_(new_weight * accept)
sample.jac.mul_(~accept).add_(new_jac * accept)

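In effect the chain targets the mixed density π(u) ∝ w(x(u))·J(u), where w(x) = r/J + (1 − r)|f(x)| and r = mix_rate; the masked_fill_ guard keeps w bounded away from zero so the ratios cannot blow up. For a proposal that is symmetric in u, as the default uniform proposal appears to be, the Metropolis-Hastings acceptance probability reduces to exactly the ratio computed above (our notation, not the source's):

A(u \to u') = \min\left(1, \frac{w(x')\,J(u')}{w(x)\,J(u)}\right)

which matches acceptance_probs = new_weight / sample.weight * new_jac / sample.jac.
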
# def sample(self, nsample, **kwargs):
# return

def __call__(
self,
f: Callable,
f_dim: int = 1,
proposal_dist: Callable = uniform,
# f: Callable,
# f_dim: int = 1,
# proposal_dist: Callable = uniform,
mix_rate=0.5,
meas_freq: int = 1,
multigpu=False,
Expand All @@ -235,59 +307,40 @@ def __call__(
rank = 0
world_size = 1

# self.fx = torch.empty(
# (self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
# )
sample = Sample(self.nbatch, self.dim, self.device, self.dtype)
epoch = self.neval // self.nbatch
current_y, current_jac = self.q0.sample(self.nbatch)
current_x, detJ = self.maps.forward(current_y)
current_jac += detJ
current_jac.exp_()
fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device)

current_weight = (
mix_rate / current_jac + (1 - mix_rate) * f(current_x, fx).abs_()
sample.u, sample.jac = self.q0.sample(self.nbatch)
sample.x, detJ = self.maps.forward(sample.u)
sample.jac += detJ
sample.jac.exp_()
sample.weight = (
mix_rate / sample.jac + (1 - mix_rate) * self.f(sample.x, self.fx).abs_()
)
current_weight.masked_fill_(current_weight < EPSILON, EPSILON)
sample.weight.masked_fill_(sample.weight < EPSILON, EPSILON)

n_meas = epoch // meas_freq

def one_step(current_y, current_x, current_weight, current_jac):
proposed_y = proposal_dist(
self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs
)
proposed_x, new_jac = self.maps.forward(proposed_y)
new_jac.exp_()

new_weight = mix_rate / new_jac + (1 - mix_rate) * f(proposed_x, fx).abs_()
new_weight.masked_fill_(new_weight < EPSILON, EPSILON)

acceptance_probs = new_weight / current_weight * new_jac / current_jac

accept = (
torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
<= acceptance_probs
)

accept_expanded = accept.unsqueeze(1)
current_y.mul_(~accept_expanded).add_(proposed_y * accept_expanded)
current_x.mul_(~accept_expanded).add_(proposed_x * accept_expanded)
current_weight.mul_(~accept).add_(new_weight * accept)
current_jac.mul_(~accept).add_(new_jac * accept)

for _ in range(self.nburnin):
one_step(current_y, current_x, current_weight, current_jac)
self.metropolis_hastings(sample, mix_rate, **kwargs)

values = torch.zeros((self.nbatch, f_dim), dtype=self.dtype, device=self.device)
values = torch.zeros(
(self.nbatch, self.f_dim), dtype=self.dtype, device=self.device
)
refvalues = torch.zeros(self.nbatch, dtype=self.dtype, device=self.device)
results_unnorm = np.array([RAvg() for _ in range(f_dim)])
results_unnorm = np.array([RAvg() for _ in range(self.f_dim)])
results_ref = RAvg()

for _ in range(n_meas):
for _ in range(meas_freq):
one_step(current_y, current_x, current_weight, current_jac)
f(current_x, fx)
self.metropolis_hastings(sample, mix_rate, **kwargs)
self.f(sample.x, self.fx)

fx.div_(current_weight.unsqueeze(1))
values += fx / n_meas
refvalues += 1 / (current_jac * current_weight) / n_meas
self.fx.div_(sample.weight.unsqueeze(1))
values += self.fx / n_meas
refvalues += 1 / (sample.jac * sample.weight) / n_meas

results = self.statistics(
values, refvalues, results_unnorm, results_ref, rank, world_size
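
The two accumulators form a self-normalizing estimator: values averages f(x)/w(x) along the chain while refvalues averages 1/(J·w), whose expectation under the target is the latent volume divided by the normalization constant. Assuming q0 samples the unit hypercube, as the default Linear([(0, 1)] * self.dim) map suggests, dividing the two averages recovers the integral (our notation):

\hat{I} = \frac{\frac{1}{N}\sum_i f(x_i)/w(x_i)}{\frac{1}{N}\sum_i 1/\bigl(J_i\, w(x_i)\bigr)}
\;\longrightarrow\; \frac{\int f(x)\,dx}{\int_{[0,1]^d} du} = \int f(x)\,dx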