16 changes: 5 additions & 11 deletions MCintegration/base.py
@@ -9,10 +9,8 @@
from MCintegration.utils import get_device

# Constants for numerical stability
# Small but safe non-zero value
MINVAL = 10 ** (sys.float_info.min_10_exp + 50)
MAXVAL = 10 ** (sys.float_info.max_10_exp - 50) # Large but safe value
EPSILON = 1e-16 # Small value to ensure numerical stability
# EPSILON = sys.float_info.epsilon * 1e4 # Small value to ensure numerical stability


class BaseDistribution(nn.Module):
@@ -98,10 +96,8 @@ def sample(self, batch_size=1, **kwargs):
tuple: (uniform samples, log_det_jacobian=0)
"""
# torch.manual_seed(0) # test seed
u = torch.rand((batch_size, self.dim),
device=self.device, dtype=self.dtype)
log_detJ = torch.zeros(
batch_size, device=self.device, dtype=self.dtype)
u = torch.rand((batch_size, self.dim), device=self.device, dtype=self.dtype)
log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)
return u, log_detJ


@@ -133,16 +129,14 @@ def __init__(self, A, b, device=None, dtype=torch.float32):
elif isinstance(A, torch.Tensor):
self.A = A.to(dtype=self.dtype, device=self.device)
else:
raise ValueError(
"'A' must be a list, numpy array, or torch tensor.")
raise ValueError("'A' must be a list, numpy array, or torch tensor.")

if isinstance(b, (list, np.ndarray)):
self.b = torch.tensor(b, dtype=self.dtype, device=self.device)
elif isinstance(b, torch.Tensor):
self.b = b.to(dtype=self.dtype, device=self.device)
else:
raise ValueError(
"'b' must be a list, numpy array, or torch tensor.")
raise ValueError("'b' must be a list, numpy array, or torch tensor.")

# Pre-compute determinant of Jacobian for efficiency
self._detJ = torch.prod(self.A)
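
A minimal sketch of the affine map these hunks touch may help when reading the diff. It assumes the behaviour implied by the precomputed `self._detJ = torch.prod(self.A)` above and by the `forward_with_detJ` expectations in `maps_test.py` further down (x = u * A + b elementwise); the class name `LinearMapSketch` and its method body are illustrative, not the library's exact implementation.

```python
import torch

class LinearMapSketch:
    """Toy stand-in for MCintegration.base.LinearMap: x = u * A + b elementwise."""

    def __init__(self, A, b, dtype=torch.float64):
        self.A = torch.tensor(A, dtype=dtype)
        self.b = torch.tensor(b, dtype=dtype)
        self._detJ = torch.prod(self.A)  # constant Jacobian determinant of the affine map

    def forward_with_detJ(self, u):
        x = u * self.A + self.b               # elementwise affine transform
        detJ = self._detJ.expand(u.shape[0])  # same determinant for every sample
        return x, detJ

u = torch.tensor([[0.5, 0.5]], dtype=torch.float64)
x, detJ = LinearMapSketch([2, 3], [1, 1]).forward_with_detJ(u)
print(x)     # [[2.0, 2.5]]
print(detJ)  # [6.0]
```
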
70 changes: 26 additions & 44 deletions MCintegration/maps.py
@@ -9,7 +9,8 @@
from MCintegration.utils import get_device
import sys

TINY = 10 ** (sys.float_info.min_10_exp + 50) # Small but safe non-zero value
# TINY = 10 ** (sys.float_info.min_10_exp + 50) # Small but safe non-zero value
TINY = 1e-45


class Configuration:
@@ -38,14 +39,10 @@ def __init__(self, batch_size, dim, f_dim, device=None, dtype=torch.float32):
self.f_dim = f_dim
self.batch_size = batch_size
# Initialize tensors for storing samples and results
self.u = torch.empty(
(batch_size, dim), dtype=dtype, device=self.device)
self.x = torch.empty(
(batch_size, dim), dtype=dtype, device=self.device)
self.fx = torch.empty((batch_size, f_dim),
dtype=dtype, device=self.device)
self.weight = torch.empty(
(batch_size,), dtype=dtype, device=self.device)
self.u = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
self.x = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
self.fx = torch.empty((batch_size, f_dim), dtype=dtype, device=self.device)
self.weight = torch.empty((batch_size,), dtype=dtype, device=self.device)
self.detJ = torch.empty((batch_size,), dtype=dtype, device=self.device)


@@ -202,8 +199,7 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
(self.dim,), ninc, dtype=torch.int32, device=self.device
)
elif isinstance(ninc, (list, np.ndarray)):
self.ninc = torch.tensor(
ninc, dtype=torch.int32, device=self.device)
self.ninc = torch.tensor(ninc, dtype=torch.int32, device=self.device)
elif isinstance(ninc, torch.Tensor):
self.ninc = ninc.to(dtype=torch.int32, device=self.device)
else:
@@ -223,8 +219,9 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
self.sum_f = torch.zeros(
self.dim, self.max_ninc, dtype=self.dtype, device=self.device
)
self.n_f = torch.zeros(
self.dim, self.max_ninc, dtype=self.dtype, device=self.device
self.n_f = (
torch.zeros(self.dim, self.max_ninc, dtype=self.dtype, device=self.device)
+ TINY
)
self.avg_f = torch.ones(
(self.dim, self.max_ninc), dtype=self.dtype, device=self.device
@@ -308,19 +305,16 @@ def adapt(self, alpha=0.5):
"""
# Aggregate training data across distributed processes if applicable
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
self.sum_f, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(
self.n_f, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(self.sum_f, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(self.n_f, op=torch.distributed.ReduceOp.SUM)

# Initialize a new grid tensor
new_grid = torch.empty(
(self.dim, self.max_ninc + 1), dtype=self.dtype, device=self.device
)

if alpha > 0:
tmp_f = torch.empty(
self.max_ninc, dtype=self.dtype, device=self.device)
tmp_f = torch.empty(self.max_ninc, dtype=self.dtype, device=self.device)

# avg_f = torch.ones(self.inc.shape[1], dtype=self.dtype, device=self.device)
# print(self.ninc.shape, self.dim)
@@ -338,14 +332,12 @@ def adapt(self, alpha=0.5):

if alpha > 0:
# Smooth avg_f
tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]
).abs() / 8.0 # Shape: ()
tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0 # Shape: ()
tmp_f[ninc - 1] = (
7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]
).abs() / 8.0 # Shape: ()
tmp_f[1: ninc - 1] = (
6.0 * avg_f[1: ninc - 1] +
avg_f[: ninc - 2] + avg_f[2:ninc]
tmp_f[1 : ninc - 1] = (
6.0 * avg_f[1 : ninc - 1] + avg_f[: ninc - 2] + avg_f[2:ninc]
).abs() / 8.0

# Normalize tmp_f to ensure the sum is 1
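
The smoothing above is easier to see with concrete numbers. This stand-alone snippet (arbitrary example values, not data from the integrator) applies the same edge kernel (7/8, 1/8) and interior kernel (1/8, 6/8, 1/8) to a per-bin average and then normalizes it, which is the step the comment above refers to.

```python
import torch

avg_f = torch.tensor([0.0, 4.0, 0.0, 4.0, 0.0], dtype=torch.float64)  # toy per-bin averages
tmp_f = torch.empty_like(avg_f)

tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0
tmp_f[-1] = (7.0 * avg_f[-1] + avg_f[-2]).abs() / 8.0
tmp_f[1:-1] = (6.0 * avg_f[1:-1] + avg_f[:-2] + avg_f[2:]).abs() / 8.0

print(tmp_f)                # [0.5, 3.0, 1.0, 3.0, 0.5]: weight leaks into the empty bins
print(tmp_f / tmp_f.sum())  # normalized so the bin importances sum to 1
```
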
@@ -393,8 +385,7 @@ def adapt(self, alpha=0.5):
) / avg_f_relevant

# Calculate the new grid points using vectorized operations
new_grid[d, 1:ninc] = grid_left + \
fractional_positions * inc_relevant
new_grid[d, 1:ninc] = grid_left + fractional_positions * inc_relevant
else:
# If alpha == 0 or no training data, retain the existing grid
new_grid[d, :] = self.grid[d, :]
@@ -407,8 +398,7 @@ def adapt(self, alpha=0.5):
self.inc.zero_() # Reset increments to zero
for d in range(self.dim):
self.inc[d, : self.ninc[d]] = (
self.grid[d, 1: self.ninc[d] + 1] -
self.grid[d, : self.ninc[d]]
self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
)

# Clear accumulated training data for the next adaptation cycle
@@ -432,8 +422,7 @@ def make_uniform(self):
device=self.device,
)
self.inc[d, : self.ninc[d]] = (
self.grid[d, 1: self.ninc[d] + 1] -
self.grid[d, : self.ninc[d]]
self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
)
self.clear()

@@ -459,8 +448,7 @@ def forward(self, u):

batch_size = u.size(0)
# Clamp iu to [0, ninc-1] to handle out-of-bounds indices
min_tensor = torch.zeros(
(1, self.dim), dtype=iu.dtype, device=self.device)
min_tensor = torch.zeros((1, self.dim), dtype=iu.dtype, device=self.device)
# Shape: (1, dim)
max_tensor = (self.ninc - 1).unsqueeze(0).to(iu.dtype)
iu_clamped = torch.clamp(iu, min=min_tensor, max=max_tensor)
@@ -471,8 +459,7 @@ def forward(self, u):
grid_gather = torch.gather(grid_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(
2
) # Shape: (batch_size, dim)
inc_gather = torch.gather(
inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)
inc_gather = torch.gather(inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)

x = grid_gather + inc_gather * du_ninc
log_detJ = (inc_gather * self.ninc).log_().sum(dim=1)
@@ -484,17 +471,15 @@ def forward(self, u):
# For each sample and dimension, set x to grid[d, ninc[d]]
# and log_detJ += log(inc[d, ninc[d]-1] * ninc[d])
boundary_grid = (
self.grid[torch.arange(
self.dim, device=self.device), self.ninc]
self.grid[torch.arange(self.dim, device=self.device), self.ninc]
.unsqueeze(0)
.expand(batch_size, -1)
)
# x = torch.where(out_of_bounds, boundary_grid, x)
x[out_of_bounds] = boundary_grid[out_of_bounds]

boundary_inc = (
self.inc[torch.arange(
self.dim, device=self.device), self.ninc - 1]
self.inc[torch.arange(self.dim, device=self.device), self.ninc - 1]
.unsqueeze(0)
.expand(batch_size, -1)
)
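
The unsqueeze/expand/gather/squeeze chain in `forward` is hard to follow in diff form. This self-contained snippet (made-up shapes and values) shows the same indexing pattern: for each sample and dimension, pick one entry from a per-dimension grid.

```python
import torch

batch_size, dim, ninc = 3, 2, 4
grid = torch.linspace(0, 1, ninc + 1).repeat(dim, 1)  # (dim, ninc + 1) grid edges
iu = torch.tensor([[0, 3], [1, 2], [2, 0]])           # bin index per (sample, dim)

grid_expanded = grid.unsqueeze(0).expand(batch_size, -1, -1)              # (batch, dim, ninc + 1)
grid_gather = torch.gather(grid_expanded, 2, iu.unsqueeze(2)).squeeze(2)  # (batch, dim)
print(grid_gather)  # left grid edge of each sample's bin in each dimension
```
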
@@ -522,8 +507,7 @@ def inverse(self, x):

# Initialize output tensors
u = torch.empty_like(x)
log_detJ = torch.zeros(
batch_size, device=self.device, dtype=self.dtype)
log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)

# Loop over each dimension to perform inverse mapping
for d in range(dim):
@@ -537,8 +521,7 @@ def inverse(self, x):
# Perform searchsorted to find indices where x should be inserted to maintain order
# torch.searchsorted returns indices in [0, max_ninc +1]
iu = (
torch.searchsorted(
grid_d, x[:, d].contiguous(), right=True) - 1
torch.searchsorted(grid_d, x[:, d].contiguous(), right=True) - 1
) # Shape: (batch_size,)

# Clamp indices to [0, ninc_d - 1] to ensure they are within valid range
@@ -551,8 +534,7 @@ def inverse(self, x):
inc_gather = inc_d[iu_clamped] # Shape: (batch_size,)

# Compute du: fractional part within the increment
du = (x[:, d] - grid_gather) / \
(inc_gather + TINY) # Shape: (batch_size,)
du = (x[:, d] - grid_gather) / (inc_gather + TINY) # Shape: (batch_size,)

# Compute u for dimension d
u[:, d] = (du + iu_clamped) / ninc_d # Shape: (batch_size,)
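
One behavioural note on this file: TINY changes from a value derived from sys.float_info to the literal 1e-45, it is now added to n_f at initialization, and it remains the guard in the division above. The snippet below (arbitrary values, not the library's code) illustrates why the guard matters; presumably the n_f offset likewise keeps the later sum_f / n_f averages finite for bins that received no samples.

```python
import torch

TINY = 1e-45
x_minus_grid = torch.tensor([0.0, 0.2])
inc = torch.tensor([0.0, 0.5])      # first increment has collapsed to zero width

print(x_minus_grid / inc)           # [nan, 0.4]: the 0/0 bin poisons the sample
print(x_minus_grid / (inc + TINY))  # [0.0, 0.4]: guarded division stays finite
```
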
105 changes: 102 additions & 3 deletions MCintegration/maps_test.py
@@ -1,11 +1,45 @@
import unittest
import torch

# import numpy as np
import numpy as np
from maps import Map, CompositeMap, Vegas, Configuration
from base import LinearMap


class TestConfiguration(unittest.TestCase):
def setUp(self):
self.batch_size = 5
self.dim = 3
self.f_dim = 2
self.device = "cpu"
self.dtype = torch.float64

def test_configuration_initialization(self):
config = Configuration(
batch_size=self.batch_size,
dim=self.dim,
f_dim=self.f_dim,
device=self.device,
dtype=self.dtype,
)

self.assertEqual(config.batch_size, self.batch_size)
self.assertEqual(config.dim, self.dim)
self.assertEqual(config.f_dim, self.f_dim)
self.assertEqual(config.device, self.device)

self.assertEqual(config.u.shape, (self.batch_size, self.dim))
self.assertEqual(config.x.shape, (self.batch_size, self.dim))
self.assertEqual(config.fx.shape, (self.batch_size, self.f_dim))
self.assertEqual(config.weight.shape, (self.batch_size,))
self.assertEqual(config.detJ.shape, (self.batch_size,))

self.assertEqual(config.u.dtype, self.dtype)
self.assertEqual(config.x.dtype, self.dtype)
self.assertEqual(config.fx.dtype, self.dtype)
self.assertEqual(config.weight.dtype, self.dtype)
self.assertEqual(config.detJ.dtype, self.dtype)


class TestMap(unittest.TestCase):
def setUp(self):
self.device = "cpu"
@@ -24,6 +58,35 @@ def test_inverse_not_implemented(self):
with self.assertRaises(NotImplementedError):
self.map.inverse(torch.tensor([0.5, 0.5], dtype=self.dtype))

def test_forward_with_detJ(self):
# Create a simple linear map for testing: x = u * A + b
# With A=[1, 1] and b=[0, 0], we have x = u
linear_map = LinearMap([1, 1], [0, 0], device=self.device)

# Test forward_with_detJ method
u = torch.tensor([[0.5, 0.5]], dtype=torch.float64, device=self.device)
x, detJ = linear_map.forward_with_detJ(u)

# Since it's a linear map from [0,0] to [1,1], x should equal u
self.assertTrue(torch.allclose(x, u))

# Determinant of Jacobian should be 1 for linear map with slope 1
# forward_with_detJ returns actual determinant, not log
self.assertAlmostEqual(detJ.item(), 1.0)

# Test with a different linear map: x = u * [2, 3] + [1, 1]
# So u = [0.5, 0.5] should give x = [0.5*2+1, 0.5*3+1] = [2, 2.5]
linear_map2 = LinearMap([2, 3], [1, 1], device=self.device)
u2 = torch.tensor([[0.5, 0.5]], dtype=torch.float64, device=self.device)
x2, detJ2 = linear_map2.forward_with_detJ(u2)
expected_x2 = torch.tensor(
[[2.0, 2.5]], dtype=torch.float64, device=self.device
)
self.assertTrue(torch.allclose(x2, expected_x2))

# Determinant should be 2 * 3 = 6
self.assertAlmostEqual(detJ2.item(), 6.0)


class TestCompositeMap(unittest.TestCase):
def setUp(self):
@@ -99,6 +162,32 @@ def test_initialization(self):
self.assertTrue(torch.equal(self.vegas.grid, self.init_grid))
self.assertEqual(self.vegas.inc.shape, (2, self.ninc))

def test_ninc_initialization_types(self):
# Test ninc initialization with int
vegas_int = Vegas(self.dim, ninc=5)
self.assertEqual(vegas_int.ninc.tolist(), [5, 5])

# Test ninc initialization with list
vegas_list = Vegas(self.dim, ninc=[5, 10])
self.assertEqual(vegas_list.ninc.tolist(), [5, 10])

# Test ninc initialization with numpy array
vegas_np = Vegas(self.dim, ninc=np.array([3, 7]))
self.assertEqual(vegas_np.ninc.tolist(), [3, 7])

# Test ninc initialization with torch tensor
vegas_tensor = Vegas(self.dim, ninc=torch.tensor([4, 6]))
self.assertEqual(vegas_tensor.ninc.tolist(), [4, 6])

# Test ninc initialization with invalid type
with self.assertRaises(ValueError):
Vegas(self.dim, ninc="invalid")

def test_ninc_shape_validation(self):
# Test ninc shape validation
with self.assertRaises(ValueError):
Vegas(self.dim, ninc=[1, 2, 3]) # Wrong length

def test_add_training_data(self):
# Test adding training data
self.vegas.add_training_data(self.sample)
@@ -137,6 +226,16 @@ def test_forward(self):
self.assertEqual(x.shape, u.shape)
self.assertEqual(log_jac.shape, (u.shape[0],))

def test_forward_with_detJ(self):
# Test forward_with_detJ transformation
u = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float64)
x, det_jac = self.vegas.forward_with_detJ(u)
self.assertEqual(x.shape, u.shape)
self.assertEqual(det_jac.shape, (u.shape[0],))

# Determinant should be positive
self.assertTrue(torch.all(det_jac > 0))

def test_forward_out_of_bounds(self):
# Test forward transformation with out-of-bounds u values
u = torch.tensor(
@@ -220,4 +319,4 @@ def test_edge_cases(self):


if __name__ == "__main__":
unittest.main()
unittest.main()
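
The new ninc-initialization tests above exercise Vegas accepting an int, a list, a NumPy array, or a tensor. A quick check mirroring them might look like the following; the package-level import path is an assumption (inside the package directory the tests use the bare `from maps import Vegas`).

```python
import numpy as np
import torch
from MCintegration.maps import Vegas  # or `from maps import Vegas` when run in-package

dim = 2
for ninc in (5, [5, 10], np.array([3, 7]), torch.tensor([4, 6])):
    v = Vegas(dim, ninc=ninc)
    print(type(ninc).__name__, "->", v.ninc.tolist())  # per-dimension increment counts
```
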
3 changes: 3 additions & 0 deletions MCintegration/mc_multicpu_test.py
@@ -79,6 +79,9 @@ def two_integrands(x, f):
if dist.is_initialized():
dist.destroy_process_group()

def test_mcmc_singlethread():
world_size = 1
init_process(rank=0, world_size=world_size, fn=run_mcmc, backend=backend)

def test_mcmc(world_size=2):
# Use fewer processes than CPU cores to avoid resource contention
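
init_process itself is not shown in this diff. For orientation, here is a hedged sketch of what a helper with this call shape typically does; the address, port, backend default, and the fn call signature are assumptions, not the file's actual code.

```python
import os
import torch.distributed as dist

def init_process(rank, world_size, fn, backend="gloo"):
    # Single-node rendezvous; the address and port are illustrative defaults.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(world_size)  # hypothetical signature; teardown appears to happen inside the wrapped function, per the context above
```

With world_size=1 the process group still forms in a single process, so test_mcmc_singlethread appears to cover the distributed code path without spawning extra workers.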