43 changes: 41 additions & 2 deletions MCintegration/base.py
@@ -2,6 +2,7 @@
 from torch import nn
 import numpy as np
 import sys
+from MCintegration.utils import get_device
 
 MINVAL = 10 ** (sys.float_info.min_10_exp + 50)
 MAXVAL = 10 ** (sys.float_info.max_10_exp - 50)
@@ -14,7 +15,7 @@ class BaseDistribution(nn.Module):
     Parameters do not depend on the target variable (as they would for a VAE encoder)
     """
 
-    def __init__(self, dim, device="cpu", dtype=torch.float64):
+    def __init__(self, dim, device="cpu", dtype=torch.float32):
         super().__init__()
         self.dtype = dtype
         self.dim = dim
@@ -42,11 +43,49 @@ class Uniform(BaseDistribution):
     Multivariate uniform distribution
     """
 
-    def __init__(self, dim, device="cpu", dtype=torch.float64):
+    def __init__(self, dim, device="cpu", dtype=torch.float32):
         super().__init__(dim, device, dtype)
 
     def sample(self, batch_size=1, **kwargs):
         # torch.manual_seed(0) # test seed
         u = torch.rand((batch_size, self.dim), device=self.device, dtype=self.dtype)
         log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)
         return u, log_detJ
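A note on the zeros returned for log_detJ: Uniform samples directly on the unit hypercube [0, 1)^dim, where the sampling density is identically 1 and the map is the identity, so log|det J| = 0 for every draw; the zeros tensor is exactly that constant.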


+class LinearMap(nn.Module):
+    def __init__(self, A, b, device=None, dtype=torch.float32):
+        super().__init__()  # initialize nn.Module internals before assigning attributes
+        if device is None:
+            self.device = get_device()
+        else:
+            self.device = device
+        self.dtype = dtype
+
+        assert len(A) == len(b), "A and b must have the same dimension."
+        if isinstance(A, (list, np.ndarray)):
+            self.A = torch.tensor(A, dtype=self.dtype, device=self.device)
+        elif isinstance(A, torch.Tensor):
+            self.A = A.to(dtype=self.dtype, device=self.device)
+        else:
+            raise ValueError("'A' must be a list, numpy array, or torch tensor.")
+
+        if isinstance(b, (list, np.ndarray)):
+            self.b = torch.tensor(b, dtype=self.dtype, device=self.device)
+        elif isinstance(b, torch.Tensor):
+            self.b = b.to(dtype=self.dtype, device=self.device)
+        else:
+            raise ValueError("'b' must be a list, numpy array, or torch tensor.")
+
+        self._detJ = torch.prod(self.A)  # |det J| of x = A * u + b
+
+    def forward(self, u):
+        return u * self.A + self.b, torch.log(self._detJ.repeat(u.shape[0]))
+
+    def forward_with_detJ(self, u):
+        u, detJ = self.forward(u)
+        detJ.exp_()
+        return u, detJ
+
+    def inverse(self, x):
+        return (x - self.b) / self.A, torch.log(self._detJ.repeat(x.shape[0]))
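Taken together, the two classes compose: Uniform draws points on the unit hypercube and LinearMap stretches them into an axis-aligned box while tracking the log-Jacobian volume factor. Below is a minimal usage sketch, not part of this PR; it assumes the package import path MCintegration.base implied by the import above, and the integrand is illustrative:

import torch
from MCintegration.base import Uniform, LinearMap

# Draw on the unit square, then map to the box [1, 3] x [2, 5]:
# x = u * A + b, with A the edge lengths and b the lower corner.
base = Uniform(dim=2, device="cpu", dtype=torch.float32)
box = LinearMap(A=[2.0, 3.0], b=[1.0, 2.0], device="cpu", dtype=torch.float32)

u, log_detJ_u = base.sample(batch_size=1024)  # log_detJ_u is all zeros
x, log_detJ_x = box.forward(u)                # each entry equals log(2.0 * 3.0)

# exp(total log-Jacobian) is the box volume, so plain Monte Carlo over the
# box with f == 1 recovers that volume exactly.
def f(x):
    return torch.ones(x.shape[0], dtype=x.dtype)

estimate = torch.mean(f(x) * torch.exp(log_detJ_u + log_detJ_x))
print(estimate)  # tensor(6.)

Note that forward returns log|det J| rather than the determinant itself; forward_with_detJ exponentiates it in place for callers that want the raw factor.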
149 changes: 96 additions & 53 deletions MCintegration/base_test.py
@@ -1,37 +1,25 @@
 # base_test.py
 import unittest
 import torch
 import numpy as np
-from base import BaseDistribution, Uniform
+from base import BaseDistribution, Uniform, LinearMap
 
 
 class TestBaseDistribution(unittest.TestCase):
     def setUp(self):
         # Common setup for all tests
-        self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
-        self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
+        self.dim = 2
         self.device = "cpu"
-        self.dtype = torch.float64
+        self.dtype = torch.float32
 
-    def test_init_with_list(self):
-        base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
-        self.assertEqual(base_dist.bounds.tolist(), self.bounds_list)
-        self.assertEqual(base_dist.dim, 2)
+    def test_init(self):
+        base_dist = BaseDistribution(self.dim, self.device, self.dtype)
+        self.assertEqual(base_dist.dim, self.dim)
         self.assertEqual(base_dist.device, self.device)
         self.assertEqual(base_dist.dtype, self.dtype)
 
-    def test_init_with_tensor(self):
-        base_dist = BaseDistribution(self.bounds_tensor, self.device, self.dtype)
-        self.assertTrue(torch.equal(base_dist.bounds, self.bounds_tensor))
-        self.assertEqual(base_dist.dim, 2)
-        self.assertEqual(base_dist.device, self.device)
-        self.assertEqual(base_dist.dtype, self.dtype)
-
-    def test_init_with_invalid_bounds(self):
-        with self.assertRaises(ValueError):
-            BaseDistribution("invalid_bounds", self.device, self.dtype)
-
     def test_sample_not_implemented(self):
-        base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
+        base_dist = BaseDistribution(self.dim, self.device, self.dtype)
         with self.assertRaises(NotImplementedError):
             base_dist.sample()

@@ -43,61 +31,116 @@ def tearDown(self):
 class TestUniform(unittest.TestCase):
     def setUp(self):
         # Common setup for all tests
-        self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
-        self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
+        self.dim = 2
         self.device = "cpu"
-        self.dtype = torch.float64
-        self.uniform_dist = Uniform(self.bounds_list, self.device, self.dtype)
+        self.dtype = torch.float32
+        self.uniform_dist = Uniform(self.dim, self.device, self.dtype)
 
-    def test_init_with_list(self):
-        self.assertEqual(self.uniform_dist.bounds.tolist(), self.bounds_list)
-        self.assertEqual(self.uniform_dist.dim, 2)
+    def test_init(self):
+        self.assertEqual(self.uniform_dist.dim, self.dim)
         self.assertEqual(self.uniform_dist.device, self.device)
         self.assertEqual(self.uniform_dist.dtype, self.dtype)
 
-    def test_init_with_tensor(self):
-        uniform_dist = Uniform(self.bounds_tensor, self.device, self.dtype)
-        self.assertTrue(torch.equal(uniform_dist.bounds, self.bounds_tensor))
-        self.assertEqual(uniform_dist.dim, 2)
-        self.assertEqual(uniform_dist.device, self.device)
-        self.assertEqual(uniform_dist.dtype, self.dtype)
-
     def test_sample_within_bounds(self):
         batch_size = 1000
         samples, log_detJ = self.uniform_dist.sample(batch_size)
-        self.assertEqual(samples.shape, (batch_size, 2))
-        self.assertTrue(torch.all(samples[:, 0] >= 0.0))
-        self.assertTrue(torch.all(samples[:, 0] <= 1.0))
-        self.assertTrue(torch.all(samples[:, 1] >= 2.0))
-        self.assertTrue(torch.all(samples[:, 1] <= 3.0))
+        self.assertEqual(samples.shape, (batch_size, self.dim))
+        self.assertTrue(torch.all(samples >= 0.0))
+        self.assertTrue(torch.all(samples <= 1.0))
         self.assertEqual(log_detJ.shape, (batch_size,))
-        self.assertTrue(
-            torch.allclose(
-                log_detJ, torch.tensor([np.log(1.0) + np.log(1.0)] * batch_size)
-            )
-        )
+        self.assertTrue(torch.allclose(log_detJ, torch.tensor([0.0] * batch_size)))
 
     def test_sample_with_single_sample(self):
         samples, log_detJ = self.uniform_dist.sample(1)
-        self.assertEqual(samples.shape, (1, 2))
-        self.assertTrue(torch.all(samples[:, 0] >= 0.0))
-        self.assertTrue(torch.all(samples[:, 0] <= 1.0))
-        self.assertTrue(torch.all(samples[:, 1] >= 2.0))
-        self.assertTrue(torch.all(samples[:, 1] <= 3.0))
+        self.assertEqual(samples.shape, (1, self.dim))
+        self.assertTrue(torch.all(samples >= 0.0))
+        self.assertTrue(torch.all(samples <= 1.0))
         self.assertEqual(log_detJ.shape, (1,))
-        self.assertTrue(
-            torch.allclose(log_detJ, torch.tensor([np.log(1.0) + np.log(1.0)]))
-        )
+        self.assertTrue(torch.allclose(log_detJ, torch.tensor([0.0])))
 
     def test_sample_with_zero_samples(self):
         samples, log_detJ = self.uniform_dist.sample(0)
-        self.assertEqual(samples.shape, (0, 2))
+        self.assertEqual(samples.shape, (0, self.dim))
         self.assertEqual(log_detJ.shape, (0,))
 
     def tearDown(self):
         # Common teardown for all tests
         pass


+class TestLinearMap(unittest.TestCase):
+    def setUp(self):
+        _A = torch.tensor([2.0, 3.0], dtype=torch.float32)
+        _b = torch.tensor([1.0, 2.0], dtype=torch.float32)
+        self.linear_map = LinearMap(_A, _b, dtype=torch.float32)
+        self.A = self.linear_map.A
+        self.b = self.linear_map.b
+        self.device = self.linear_map.device
+
+    def test_forward(self):
+        u = torch.tensor(
+            [[1.0, 1.0], [2.0, 2.0]], dtype=torch.float32, device=self.device
+        )
+        expected_x = torch.tensor(
+            [[3.0, 5.0], [5.0, 8.0]], dtype=torch.float32, device=self.device
+        )
+        expected_detJ = torch.tensor(
+            [np.log(6.0), np.log(6.0)], dtype=torch.float32, device=self.device
+        )
+
+        x, detJ = self.linear_map.forward(u)
+
+        self.assertTrue(torch.allclose(x, expected_x))
+        self.assertTrue(torch.allclose(detJ, expected_detJ))
+
+    def test_forward_with_detJ(self):
+        u = torch.tensor(
+            [[1.0, 1.0], [2.0, 2.0]], dtype=torch.float32, device=self.device
+        )
+        expected_x = torch.tensor(
+            [[3.0, 5.0], [5.0, 8.0]], dtype=torch.float32, device=self.device
+        )
+        expected_detJ = torch.tensor(
+            [6.0, 6.0], dtype=torch.float32, device=self.device
+        )
+
+        x, detJ = self.linear_map.forward_with_detJ(u)
+
+        self.assertTrue(torch.allclose(x, expected_x))
+        self.assertTrue(torch.allclose(detJ, expected_detJ))
+
+    def test_inverse(self):
+        x = torch.tensor(
+            [[3.0, 5.0], [5.0, 8.0]], dtype=torch.float32, device=self.device
+        )
+        expected_u = torch.tensor(
+            [[1.0, 1.0], [2.0, 2.0]], dtype=torch.float32, device=self.device
+        )
+        expected_detJ = torch.tensor(
+            [np.log(6.0), np.log(6.0)], dtype=torch.float32, device=self.device
+        )
+
+        u, detJ = self.linear_map.inverse(x)
+
+        self.assertTrue(torch.allclose(u, expected_u))
+        self.assertTrue(torch.allclose(detJ, expected_detJ))
+
+    def test_init_with_list(self):
+        A_list = [2.0, 3.0]
+        b_list = [1.0, 2.0]
+        linear_map = LinearMap(A_list, b_list, dtype=torch.float32)
+
+        self.assertTrue(torch.allclose(linear_map.A, self.A))
+        self.assertTrue(torch.allclose(linear_map.b, self.b))
+
+    def test_init_with_numpy_array(self):
+        A_np = np.array([2.0, 3.0])
+        b_np = np.array([1.0, 2.0])
+        linear_map = LinearMap(A_np, b_np, dtype=torch.float32)
+
+        self.assertTrue(torch.allclose(linear_map.A, self.A))
+        self.assertTrue(torch.allclose(linear_map.b, self.b))
+
+
 if __name__ == "__main__":
     unittest.main()
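Since the test module imports from base rather than from MCintegration.base, the suite presumably has to be run from inside the MCintegration/ directory, e.g. with python base_test.py through the __main__ guard above; that run location is an assumption based on the imports shown, not something this PR configures.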