diff --git a/src/mrsimulator/simulator/sampling_scheme.py b/src/mrsimulator/simulator/sampling_scheme.py index ad9b7088..d3a24df9 100644 --- a/src/mrsimulator/simulator/sampling_scheme.py +++ b/src/mrsimulator/simulator/sampling_scheme.py @@ -23,7 +23,17 @@ def generate_custom_sampling(alpha, beta, weight, triangle_mesh=False): return sampling -def step_averaging(N_alpha: int, N_beta: int, triangle_mesh=True): +def check_triangulation(triangle_mesh: bool, integration_volume: str): + """Check if triangulation can be applied""" + if triangle_mesh and integration_volume != "sphere": + raise NotImplementedError( + "Triangulation of non sphere geometry is not implemented." + ) + + +def step_averaging( + N_alpha: int, N_beta: int, triangle_mesh=True, integration_volume="sphere" +): """Generate STEP averaging samples. Args: @@ -31,15 +41,16 @@ def step_averaging(N_alpha: int, N_beta: int, triangle_mesh=True): beta: number of points along the beta dimension. triangle_mesh: generate ConvexHull triangulation of points. """ - sphereType = "sphere" + check_triangulation(triangle_mesh, integration_volume) + norm_step = 0.0 - if sphereType == "sphere": + if integration_volume == "sphere": a = [1.0, 0.0, 1.0] b = 1.0 - if sphereType == "hemisphere": + if integration_volume == "hemisphere": a = [1.0, 0.0, 1.0] b = 2.0 - if sphereType == "octant": + if integration_volume == "octant": a = [2.0, 1.0, 8.0] b = 2.0 @@ -64,29 +75,30 @@ def step_averaging(N_alpha: int, N_beta: int, triangle_mesh=True): def get_zcw_number(M): """ZCW number""" # returns the number of ZCW angles for the given integer M=2,3,4,... - gM = 5 - gMminus1 = 3 - sum = 5 + g_m = 5 + g_minus1 = 3 + local_sum = 5 for _ in range(M + 1): - sum = gM + gMminus1 - gMminus1 = gM - gM = sum - return sum + local_sum = g_m + g_minus1 + g_minus1 = g_m + g_m = local_sum + return local_sum -def zcw_averaging(M: int, triangle_mesh=True): +def zcw_averaging(M: int, triangle_mesh=True, integration_volume="sphere"): """Generate ZCW averaging samples. Args: M: ZCW point generation factor. triangle_mesh: generate ConvexHull triangulation of points. """ - sphereType = "sphere" - if sphereType == "sphere": + check_triangulation(triangle_mesh, integration_volume) + + if integration_volume == "sphere": c = [1.0, 2.0, 1.0] - if sphereType == "hemisphere": + if integration_volume == "hemisphere": c = [-1.0, 1.0, 1.0] - if sphereType == "octant": + if integration_volume == "octant": c = [-1.0, 1.0, 4.0] N = get_zcw_number(M) @@ -99,11 +111,11 @@ def zcw_averaging(M: int, triangle_mesh=True): return generate_custom_sampling(alpha, beta, weight, triangle_mesh) -if __name__ == "__main__": - sampling = zcw_averaging(M=21) - rad2deg = 180.0 / np.pi - nd_array = np.array( - [sampling.alpha * rad2deg, sampling.beta * rad2deg, sampling.weight] - ).T - size = sampling.alpha.size - np.savetxt(f"zcw{size}.cry", nd_array, header=str(size), fmt="%.6e") +# if __name__ == "__main__": +# sampling = zcw_averaging(M=21) +# rad2deg = 180.0 / np.pi +# nd_array = np.array( +# [sampling.alpha * rad2deg, sampling.beta * rad2deg, sampling.weight] +# ).T +# size = sampling.alpha.size +# np.savetxt(f"zcw{size}.cry", nd_array, header=str(size), fmt="%.6e") diff --git a/tests/spectral_integration_tests/test_custom_sampling_lineshape.py b/tests/spectral_integration_tests/test_custom_sampling_lineshape.py index 5f320d7a..fb5f6466 100644 --- a/tests/spectral_integration_tests/test_custom_sampling_lineshape.py +++ b/tests/spectral_integration_tests/test_custom_sampling_lineshape.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from mrsimulator import Simulator from mrsimulator import Site from mrsimulator import SpinSystem @@ -40,10 +41,12 @@ def test_one_d(): spec_asg = sim.methods[0].simulation.y[0].components[0].real sim.config.custom_sampling = zcw_averaging(M=12) + assert sim.config.get_orientations_count() == 2584 sim.run() spec_zcw = sim.methods[0].simulation.y[0].components[0].real sim.config.custom_sampling = step_averaging(N_alpha=51, N_beta=51) + assert sim.config.get_orientations_count() == 2601 sim.run() spec_step = sim.methods[0].simulation.y[0].components[0].real @@ -115,3 +118,25 @@ def test_internal_external_averaging_spectrum(): np.testing.assert_almost_equal(spec_step_interp, spec_step_bin, decimal=2) np.testing.assert_almost_equal(spec_asg_interp, spec_step_interp, decimal=3) np.testing.assert_almost_equal(spec_asg_interp, spec_zcw_interp, decimal=3) + + +def test_sampling_triangulation(): + error = "Triangulation of non sphere geometry is not implemented." + for vol in ["octant", "hemisphere"]: + with pytest.raises(NotImplementedError, match=error): + _ = step_averaging( + N_alpha=50, N_beta=50, triangle_mesh=True, integration_volume=vol + ) + + with pytest.raises(NotImplementedError, match=error): + _ = zcw_averaging(M=5, triangle_mesh=True, integration_volume=vol) + + +def test_sampling_non_sphere(): + for vol in ["octant", "hemisphere"]: + step_s = step_averaging( + N_alpha=50, N_beta=50, triangle_mesh=False, integration_volume=vol + ) + assert step_s.alpha.size == 2500 + zcw_s = zcw_averaging(M=5, triangle_mesh=False, integration_volume=vol) + assert zcw_s.alpha.size == 89 diff --git a/tests/spectral_integration_tests/test_lineshapes.py b/tests/spectral_integration_tests/test_lineshapes.py index 35b204c5..cfd5b836 100644 --- a/tests/spectral_integration_tests/test_lineshapes.py +++ b/tests/spectral_integration_tests/test_lineshapes.py @@ -1,4 +1,6 @@ """Lineshape Test.""" +from copy import deepcopy +from os import mkdir from os import path from pprint import pformat @@ -43,6 +45,9 @@ def report_pdf(report): def compile_plots(dim, rep, info, title=None, report=None, label="simpson"): """Test report plot""" + if not __GENERATE_REPORT__: + return + _, ax = plt.subplots(1, 2, figsize=(9, 4), gridspec_kw={"width_ratios": [1, 1]}) for i, res in enumerate(rep[:1]): @@ -79,6 +84,27 @@ def compile_plots(dim, rep, info, title=None, report=None, label="simpson"): plt.close() +def test_pdf(): + global __GENERATE_REPORT__ + temp_status = deepcopy(__GENERATE_REPORT__) + __GENERATE_REPORT__ = True + + is_present = path.isdir("_temp") + if not is_present: + mkdir("_temp") + filename = "_temp/lineshapes_report_scrap.pdf" + report_file = PdfPages(filename) + dim = np.arange(10) + res = [np.arange(10), np.arange(10)] + info = {"config": 1, "spin_systems": 2, "methods": 3} + compile_plots(dim, [res], info, title="Shielding Sidebands", report=report_file) + report_file.close() + is_file = path.isfile(filename) + assert is_file + + __GENERATE_REPORT__ = temp_status + + def check_all_close(res, message, rel_limit): """Check if the vectos in res are all close within relative limits""" for item in res: @@ -110,8 +136,7 @@ def test_pure_shielding_sideband_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots(dim, res, info, title="Shielding Sidebands", report=report) + compile_plots(dim, res, info, title="Shielding Sidebands", report=report) message = f"{error_message} test0{i}.json" check_all_close(res, message, rel_limit=2.5) @@ -137,8 +162,7 @@ def test_pure_quadrupolar_sidebands_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots(dim, res, info, title="Quad Sidebands", report=report) + compile_plots(dim, res, info, title="Quad Sidebands", report=report) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=1.5) @@ -150,7 +174,6 @@ def test_csa_plus_quadrupolar_lineshape_simpson(report): ) path_ = path.join(SIMPSON_TEST_PATH, "csa_quad") for i in range(10): - print(i) filename = path.join(path_, f"test{i:02d}", f"test{i:02d}.json") res = [] for volume in VOLUMES: @@ -159,14 +182,13 @@ def test_csa_plus_quadrupolar_lineshape_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, - res, - info, - title="Quad + Shielding Sidebands", - report=report, - ) + compile_plots( + dim, + res, + info, + title="Quad + Shielding Sidebands", + report=report, + ) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=0.9) @@ -186,14 +208,13 @@ def test_1st_order_quadrupolar_lineshape_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, - res, - info, - title="1st Order Quadrupolar Lineshape", - report=report, - ) + compile_plots( + dim, + res, + info, + title="1st Order Quadrupolar Lineshape", + report=report, + ) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=1.0) @@ -211,8 +232,7 @@ def test_j_coupling_lineshape_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots(dim, res, info, title="J-coupling Spectra", report=report) + compile_plots(dim, res, info, title="J-coupling Spectra", report=report) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=1.1) @@ -232,10 +252,7 @@ def test_dipolar_coupling_lineshape_simpson(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, res, info, title="Dipolar-coupling Spectra", report=report - ) + compile_plots(dim, res, info, title="Dipolar-coupling Spectra", report=report) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=1.5) @@ -262,15 +279,14 @@ def test_quad_csa_cross_rmnsim(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, - res, - info, - title="Quad-CSA 2nd Order Cross-Term", - report=report, - label="rmnsim", - ) + compile_plots( + dim, + res, + info, + title="Quad-CSA 2nd Order Cross-Term", + report=report, + label="rmnsim", + ) message = f"{error_message} test0{i}.json" check_all_close(res, message, rel_limit=0.3) @@ -301,15 +317,14 @@ def test_pure_shielding_static_lineshape_python_brute(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, - res, - info, - title="Shielding Static Lineshape (Brute Force)", - report=report, - label="Brute", - ) + compile_plots( + dim, + res, + info, + title="Shielding Static Lineshape (Brute Force)", + report=report, + label="Brute", + ) message = f"{error_message} test0{i}.json" check_all_close(res, message, rel_limit=2) @@ -338,15 +353,14 @@ def test_pure_quadrupolar_lineshape_python_brute(report): ) res.append([data_mrsimulator, data_source]) - if __GENERATE_REPORT__: - compile_plots( - dim, - res, - info, - title="Quad Lineshape Self-Test", - report=report, - label="self", - ) + compile_plots( + dim, + res, + info, + title="Quad Lineshape Self-Test", + report=report, + label="self", + ) message = f"{error_message} test0{i:02d}.json" check_all_close(res, message, rel_limit=1.5) diff --git a/tests/spectral_integration_tests/utils.py b/tests/spectral_integration_tests/utils.py index 66aca19a..6f57f57e 100644 --- a/tests/spectral_integration_tests/utils.py +++ b/tests/spectral_integration_tests/utils.py @@ -67,8 +67,6 @@ def get_data(filename): def get_csdm_data(source_file, quantity): """Get data from csdm object""" csdm = cp.load(source_file) - if quantity == "time": - csdm = csdm.fft() data_source = csdm.y[0].components[0] return data_source diff --git a/tests/test_orientations_and_triangle_interpolation.py b/tests/test_orientations_and_triangle_interpolation.py index 7538e296..4ea7818b 100644 --- a/tests/test_orientations_and_triangle_interpolation.py +++ b/tests/test_orientations_and_triangle_interpolation.py @@ -1,4 +1,8 @@ """Test for c functions.""" +from copy import deepcopy +from os import mkdir +from os import path + import matplotlib.pyplot as plt import mrsimulator.tests.tests as clib import numpy as np @@ -109,6 +113,30 @@ def plot_2d_raster(rep, title=None, report=None): plt.close() +def test_pdf(): + global __GENERATE_REPORT__ + temp_status = deepcopy(__GENERATE_REPORT__) + __GENERATE_REPORT__ = True + + is_present = path.isdir("_temp") + if not is_present: + mkdir("_temp") + filename = "_temp/interpolation_report_scrap.pdf" + report_file = PdfPages(filename) + amp2d = np.random.rand(100).reshape(10, 10) + pts1 = np.array([1, 3, 6]) + pts2 = np.array([5, 2, 9]) + proj_x = amp2d.sum(axis=0) + proj_y = amp2d.sum(axis=1) + rep = [amp2d, pts1, pts2, proj_x, proj_y] + plot_2d_raster([rep], title="All points within FOV", report=report_file) + report_file.close() + is_file = path.isfile(filename) + assert is_file + + __GENERATE_REPORT__ = temp_status + + def test_octahedron_averaging_setup(): nt = 64 cos_alpha_py, cos_beta_py, amp_py = cosine_of_polar_angles_and_amplitudes(nt) @@ -309,8 +337,7 @@ def test_triangle_rasterization1(report): rep.append([amp1, lst1, lst2, amp3, amp2]) - if __GENERATE_REPORT__: - plot_2d_raster(rep, title="All points within FOV", report=report) + plot_2d_raster(rep, title="All points within FOV", report=report) def test_triangle_rasterization2(report): @@ -339,10 +366,7 @@ def test_triangle_rasterization2(report): rep.append([amp1, lst1, lst2, None, amp2]) - if __GENERATE_REPORT__: - plot_2d_raster( - rep, title="One or more points outside of FOV", report=report - ) + plot_2d_raster(rep, title="One or more points outside of FOV", report=report) def test_triangle_rasterization3(report): @@ -362,10 +386,7 @@ def test_triangle_rasterization3(report): assert np.allclose(amp_x, amp1.sum(axis=0), atol=1e-2) rep.append([amp1, lst1, lst2, amp_x, None]) - if __GENERATE_REPORT__: - plot_2d_raster( - rep, title="One or more points outside of FOV", report=report - ) + plot_2d_raster(rep, title="One or more points outside of FOV", report=report) def test_triangle_rasterization4(report): @@ -385,9 +406,8 @@ def test_triangle_rasterization4(report): amp1, amp_y, amp_x, lst1, lst2 = get_amps_from_interpolation(list_, scl) rep.append([amp1, lst1, lst2, amp_x, None]) - if __GENERATE_REPORT__: - plot_2d_raster(rep, title="All points outside of FOV", report=report) - # assert np.allclose(amp_x.real, amp1.real.sum(axis=0), atol=1e-15) + plot_2d_raster(rep, title="All points outside of FOV", report=report) + # assert np.allclose(amp_x.real, amp1.real.sum(axis=0), atol=1e-15) def test_triangle_rasterization5(report): @@ -406,8 +426,7 @@ def test_triangle_rasterization5(report): assert np.allclose(amp2, amp1.sum(axis=1), atol=1e-15) rep.append([amp1, lst1, lst2, amp3, amp2]) - if __GENERATE_REPORT__: - plot_2d_raster(rep, title="Small trianges spanning 1-2 bins", report=report) + plot_2d_raster(rep, title="Small trianges spanning 1-2 bins", report=report) def get_amps_from_interpolation(list_, scl):