121 changes: 107 additions & 14 deletions src/integrators.py
@@ -6,6 +6,38 @@
import gvar
import numpy as np

import os
import torch.distributed as dist
import socket


def get_ip() -> str:
    # Determine the host's outward-facing IP by connecting a UDP socket;
    # the target address does not need to be reachable and no data is sent.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]


def get_open_port() -> int:
    # Ask the OS for a free ephemeral port by binding to port 0.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def setup():
    # Initialize the default process group. With no init_method given, torch
    # falls back to the env:// rendezvous, which reads MASTER_ADDR, MASTER_PORT,
    # WORLD_SIZE, and RANK from the environment (set by torchrun or SLURM).
    # A TCP rendezvous could instead be built from the helpers above, e.g.
    #   init_method = f"tcp://{get_ip()}:{get_open_port()}"
    dist.init_process_group(backend="gloo")
    # Bind this process to the GPU reserved for it.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def cleanup():
    # Tear down the process group at the end of the run.
    dist.destroy_process_group()
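
For reference, a minimal sketch of how the setup/cleanup pair might be driven from a launcher script (the file name and `main` entry point below are illustrative and not part of this PR; `torchrun` exports WORLD_SIZE, RANK, MASTER_ADDR/MASTER_PORT, and LOCAL_RANK, which `setup()` relies on):

# hypothetical launch: torchrun --nproc_per_node=2 run_integration.py
def main():
    setup()
    try:
        ...  # build the maps/integrators here and call them with multigpu=True
    finally:
        cleanup()


if __name__ == "__main__":
    main()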


class Integrator:
"""
@@ -76,7 +108,7 @@ def __init__(
):
super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)

def __call__(self, f: Callable, f_dim: int = 1, **kwargs):
def __call__(self, f: Callable, f_dim: int = 1, multigpu=False, **kwargs):
x, _ = self.sample(self.nbatch)
fx = torch.empty((self.nbatch, f_dim), dtype=self.dtype, device=self.device)

@@ -92,15 +124,38 @@ def __call__(self, f: Callable, f_dim: int = 1, **kwargs):
integ_values += fx / epoch

results = np.array([RAvg() for _ in range(f_dim)])
for i in range(f_dim):
_mean = integ_values[:, i].mean().item()
_var = integ_values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)
if multigpu:
self.multi_gpu_statistic(integ_values, f_dim, results)
else:
for i in range(f_dim):
_mean = integ_values[:, i].mean().item()
_var = integ_values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)
if f_dim == 1:
return results[0]
else:
return results

    def multi_gpu_statistic(self, values, f_dim, results):
        # Per-rank mean of the batch values and variance of that mean,
        # for each integrand component.
        _mean = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        _total_mean = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        _var = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        for i in range(f_dim):
            _mean[i] = values[:, i].mean()
            _var[i] = values[:, i].var() / self.nbatch
        _total_mean.copy_(_mean)

        # Global mean: average the per-rank means across all ranks.
        dist.all_reduce(_total_mean, op=dist.ReduceOp.SUM)
        _total_mean /= dist.get_world_size()
        # Between-rank term: spread of the rank means around the global mean.
        _var_between_batch = torch.square(_mean - _total_mean)
        dist.all_reduce(_var_between_batch, op=dist.ReduceOp.SUM)
        _var_between_batch /= dist.get_world_size()
        # Within-rank term: average of the per-rank variances of the mean.
        dist.all_reduce(_var, op=dist.ReduceOp.SUM)
        _var /= dist.get_world_size()
        _var = _var + _var_between_batch
        for i in range(f_dim):
            results[i].update(_total_mean[i].item(), _var[i].item(), self.neval)
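
The reduction above combines a within-rank term (the averaged per-rank variances of the mean) with a between-rank term (the spread of the per-rank means around the global mean), in the spirit of the law of total variance. A single-process sketch of that combination, with no distributed backend; all names below are illustrative:

import torch

rank_batches = [torch.rand(1000), torch.rand(1000)]                   # one batch of integrand values per "rank"
rank_means = torch.stack([b.mean() for b in rank_batches])
rank_vars = torch.stack([b.var() / b.numel() for b in rank_batches])  # variance of each rank's mean

total_mean = rank_means.mean()                                # plays the role of all_reduce(SUM) / world_size
var_within = rank_vars.mean()                                 # averaged per-rank variances
var_between = torch.square(rank_means - total_mean).mean()    # spread of the rank means
var_total = var_within + var_between                          # what results[i].update(...) receives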


def random_walk(dim, bounds, device, dtype, u, **kwargs):
rangebounds = bounds[:, 1] - bounds[:, 0]
@@ -147,6 +202,7 @@ def __call__(
proposal_dist: Callable = uniform,
mix_rate=0.5,
meas_freq: int = 1,
multigpu=False,
**kwargs,
):
epsilon = 1e-16 # Small value to ensure numerical stability
@@ -211,19 +267,56 @@ def one_step(current_y, current_x, current_weight, current_jac):

results = np.array([RAvg() for _ in range(f_dim)])
results_ref = RAvg()

mean_ref = refvalues.mean().item()
var_ref = refvalues.var().item() / self.nbatch

results_ref.update(mean_ref, var_ref, self.neval)
for i in range(f_dim):
_mean = values[:, i].mean().item()
_var = values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)
if multigpu:
self.multi_gpu_statistic(values, refvalues, results, results_ref, f_dim)
else:
mean_ref = refvalues.mean().item()
var_ref = refvalues.var().item() / self.nbatch
results_ref.update(mean_ref, var_ref, self.neval)
for i in range(f_dim):
_mean = values[:, i].mean().item()
_var = values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)

if f_dim == 1:
res = results[0] / results_ref * self._rangebounds.prod()
result = RAvg(itn_results=[res], sum_neval=self.neval)
return result
else:
return results / results_ref * self._rangebounds.prod().item()

    def multi_gpu_statistic(self, values, refvalues, results, results_ref, f_dim):
        # Collect multi-GPU statistics for the integrand values.
        _mean = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        _total_mean = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        _var = torch.zeros(f_dim, device=self.device, dtype=self.dtype)
        for i in range(f_dim):
            _mean[i] = values[:, i].mean()
            _var[i] = values[:, i].var() / self.nbatch
        _total_mean.copy_(_mean)

        # Global mean over all ranks.
        dist.all_reduce(_total_mean, op=dist.ReduceOp.SUM)
        _total_mean /= dist.get_world_size()
        # Between-rank and within-rank contributions to the variance of the mean.
        _var_between_batch = torch.square(_mean - _total_mean)
        dist.all_reduce(_var_between_batch, op=dist.ReduceOp.SUM)
        _var_between_batch /= dist.get_world_size()
        dist.all_reduce(_var, op=dist.ReduceOp.SUM)
        _var /= dist.get_world_size()
        _var = _var + _var_between_batch
        for i in range(f_dim):
            results[i].update(_total_mean[i].item(), _var[i].item(), self.neval)

        # Collect multi-GPU statistics for the reference values.
        _mean_ref = refvalues.mean()
        _total_mean_ref = _mean_ref.clone().detach()
        _var_ref = refvalues.var() / self.nbatch
        dist.all_reduce(_total_mean_ref, op=dist.ReduceOp.SUM)
        _total_mean_ref /= dist.get_world_size()
        _var_ref_between_batch = torch.square(_mean_ref - _total_mean_ref)
        dist.all_reduce(_var_ref_between_batch, op=dist.ReduceOp.SUM)
        _var_ref_between_batch /= dist.get_world_size()
        dist.all_reduce(_var_ref, op=dist.ReduceOp.SUM)
        _var_ref /= dist.get_world_size()
        _var_ref = _var_ref + _var_ref_between_batch
        results_ref.update(_total_mean_ref.item(), _var_ref.item(), self.neval)
16 changes: 14 additions & 2 deletions src/maps.py
@@ -92,7 +92,16 @@ def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu", dtype=torch.float
self._A = self.bounds[:, 1] - self.bounds[:, 0]
self._jaclinear = torch.prod(self._A)

def train(self, nsamples, f, f_dim=1, dtype=torch.float64, epoch=5, alpha=0.5):
def train(
self,
nsamples,
f,
f_dim=1,
dtype=torch.float64,
epoch=5,
alpha=0.5,
multigpu=False,
):
q0 = Uniform(self.bounds, device=self.device, dtype=self.dtype)
u, log_detJ0 = q0.sample(nsamples)

@@ -105,7 +114,7 @@ def train(self, nsamples, f, f_dim=1, dtype=torch.float64, epoch=5, alpha=0.5):
self.add_training_data(u, f2)
self.adapt(alpha)

def add_training_data(self, u, fval):
def add_training_data(self, u, fval, multigpu=False):
"""Add training data ``f`` for ``u``-space points ``u``.

Accumulates training data for later use by ``self.adapt()``.
@@ -130,6 +139,9 @@ def add_training_data(self, u, fval):
indices = iu[:, d]
self.sum_f[d].scatter_add_(0, indices, fval.abs())
self.n_f[d].scatter_add_(0, indices, torch.ones_like(fval))
        if multigpu:
            # Synchronize the accumulated grid statistics across ranks so that
            # every process adapts to the same training data.
            torch.distributed.all_reduce(self.sum_f, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(self.n_f, op=torch.distributed.ReduceOp.SUM)

def adapt(self, alpha=0.0):
"""Adapt grid to accumulated training data.
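
Taken together, the new multigpu path would be exercised per rank roughly as below. This is only a sketch: the class names `Vegas` and `MonteCarlo` and the constructor arguments are placeholders standing in for the map and integrator classes touched by this diff, and `f`/`bounds` are assumed to be a batched integrand and its integration limits.

# sketch of one rank's work; class names and constructor arguments are placeholders
setup()
vegas_map = Vegas(bounds, ninc=1000, device=f"cuda:{os.environ['LOCAL_RANK']}")
vegas_map.train(20000, f, f_dim=1, multigpu=True)    # all_reduce is meant to keep the grid identical on every rank
mc = MonteCarlo(maps=vegas_map, bounds=bounds, neval=10**6, nbatch=10**4)
result = mc(f, f_dim=1, multigpu=True)               # per-rank statistics pooled in multi_gpu_statistic
cleanup()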