Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code unit test coverage #336

Merged
merged 4 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
62 changes: 37 additions & 25 deletions src/mrsimulator/simulator/sampling_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,34 @@ def generate_custom_sampling(alpha, beta, weight, triangle_mesh=False):
return sampling


def step_averaging(N_alpha: int, N_beta: int, triangle_mesh=True):
def check_triangulation(triangle_mesh: bool, integration_volume: str):
"""Check if triangulation can be applied"""
if triangle_mesh and integration_volume != "sphere":
raise NotImplementedError(
"Triangulation of non sphere geometry is not implemented."
)


def step_averaging(
N_alpha: int, N_beta: int, triangle_mesh=True, integration_volume="sphere"
):
"""Generate STEP averaging samples.

Args:
N_alpha: number of points along the alpha dimension.
beta: number of points along the beta dimension.
triangle_mesh: generate ConvexHull triangulation of points.
"""
sphereType = "sphere"
check_triangulation(triangle_mesh, integration_volume)

norm_step = 0.0
if sphereType == "sphere":
if integration_volume == "sphere":
a = [1.0, 0.0, 1.0]
b = 1.0
if sphereType == "hemisphere":
if integration_volume == "hemisphere":
a = [1.0, 0.0, 1.0]
b = 2.0
if sphereType == "octant":
if integration_volume == "octant":
a = [2.0, 1.0, 8.0]
b = 2.0

Expand All @@ -64,29 +75,30 @@ def step_averaging(N_alpha: int, N_beta: int, triangle_mesh=True):
def get_zcw_number(M):
"""ZCW number"""
# returns the number of ZCW angles for the given integer M=2,3,4,...
gM = 5
gMminus1 = 3
sum = 5
g_m = 5
g_minus1 = 3
local_sum = 5
for _ in range(M + 1):
sum = gM + gMminus1
gMminus1 = gM
gM = sum
return sum
local_sum = g_m + g_minus1
g_minus1 = g_m
g_m = local_sum
return local_sum


def zcw_averaging(M: int, triangle_mesh=True):
def zcw_averaging(M: int, triangle_mesh=True, integration_volume="sphere"):
"""Generate ZCW averaging samples.

Args:
M: ZCW point generation factor.
triangle_mesh: generate ConvexHull triangulation of points.
"""
sphereType = "sphere"
if sphereType == "sphere":
check_triangulation(triangle_mesh, integration_volume)

if integration_volume == "sphere":
c = [1.0, 2.0, 1.0]
if sphereType == "hemisphere":
if integration_volume == "hemisphere":
c = [-1.0, 1.0, 1.0]
if sphereType == "octant":
if integration_volume == "octant":
c = [-1.0, 1.0, 4.0]

N = get_zcw_number(M)
Expand All @@ -99,11 +111,11 @@ def zcw_averaging(M: int, triangle_mesh=True):
return generate_custom_sampling(alpha, beta, weight, triangle_mesh)


if __name__ == "__main__":
sampling = zcw_averaging(M=21)
rad2deg = 180.0 / np.pi
nd_array = np.array(
[sampling.alpha * rad2deg, sampling.beta * rad2deg, sampling.weight]
).T
size = sampling.alpha.size
np.savetxt(f"zcw{size}.cry", nd_array, header=str(size), fmt="%.6e")
# if __name__ == "__main__":
# sampling = zcw_averaging(M=21)
# rad2deg = 180.0 / np.pi
# nd_array = np.array(
# [sampling.alpha * rad2deg, sampling.beta * rad2deg, sampling.weight]
# ).T
# size = sampling.alpha.size
# np.savetxt(f"zcw{size}.cry", nd_array, header=str(size), fmt="%.6e")
25 changes: 25 additions & 0 deletions tests/spectral_integration_tests/test_custom_sampling_lineshape.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
from mrsimulator import Simulator
from mrsimulator import Site
from mrsimulator import SpinSystem
Expand Down Expand Up @@ -40,10 +41,12 @@ def test_one_d():
spec_asg = sim.methods[0].simulation.y[0].components[0].real

sim.config.custom_sampling = zcw_averaging(M=12)
assert sim.config.get_orientations_count() == 2584
sim.run()
spec_zcw = sim.methods[0].simulation.y[0].components[0].real

sim.config.custom_sampling = step_averaging(N_alpha=51, N_beta=51)
assert sim.config.get_orientations_count() == 2601
sim.run()
spec_step = sim.methods[0].simulation.y[0].components[0].real

Expand Down Expand Up @@ -115,3 +118,25 @@ def test_internal_external_averaging_spectrum():
np.testing.assert_almost_equal(spec_step_interp, spec_step_bin, decimal=2)
np.testing.assert_almost_equal(spec_asg_interp, spec_step_interp, decimal=3)
np.testing.assert_almost_equal(spec_asg_interp, spec_zcw_interp, decimal=3)


def test_sampling_triangulation():
error = "Triangulation of non sphere geometry is not implemented."
for vol in ["octant", "hemisphere"]:
with pytest.raises(NotImplementedError, match=error):
_ = step_averaging(
N_alpha=50, N_beta=50, triangle_mesh=True, integration_volume=vol
)

with pytest.raises(NotImplementedError, match=error):
_ = zcw_averaging(M=5, triangle_mesh=True, integration_volume=vol)


def test_sampling_non_sphere():
for vol in ["octant", "hemisphere"]:
step_s = step_averaging(
N_alpha=50, N_beta=50, triangle_mesh=False, integration_volume=vol
)
assert step_s.alpha.size == 2500
zcw_s = zcw_averaging(M=5, triangle_mesh=False, integration_volume=vol)
assert zcw_s.alpha.size == 89
122 changes: 68 additions & 54 deletions tests/spectral_integration_tests/test_lineshapes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Lineshape Test."""
from copy import deepcopy
from os import mkdir
from os import path
from pprint import pformat

Expand Down Expand Up @@ -43,6 +45,9 @@ def report_pdf(report):

def compile_plots(dim, rep, info, title=None, report=None, label="simpson"):
"""Test report plot"""
if not __GENERATE_REPORT__:
return

_, ax = plt.subplots(1, 2, figsize=(9, 4), gridspec_kw={"width_ratios": [1, 1]})

for i, res in enumerate(rep[:1]):
Expand Down Expand Up @@ -79,6 +84,27 @@ def compile_plots(dim, rep, info, title=None, report=None, label="simpson"):
plt.close()


def test_pdf():
global __GENERATE_REPORT__
temp_status = deepcopy(__GENERATE_REPORT__)
__GENERATE_REPORT__ = True

is_present = path.isdir("_temp")
if not is_present:
mkdir("_temp")
filename = "_temp/lineshapes_report_scrap.pdf"
report_file = PdfPages(filename)
dim = np.arange(10)
res = [np.arange(10), np.arange(10)]
info = {"config": 1, "spin_systems": 2, "methods": 3}
compile_plots(dim, [res], info, title="Shielding Sidebands", report=report_file)
report_file.close()
is_file = path.isfile(filename)
assert is_file

__GENERATE_REPORT__ = temp_status


def check_all_close(res, message, rel_limit):
"""Check if the vectos in res are all close within relative limits"""
for item in res:
Expand Down Expand Up @@ -110,8 +136,7 @@ def test_pure_shielding_sideband_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(dim, res, info, title="Shielding Sidebands", report=report)
compile_plots(dim, res, info, title="Shielding Sidebands", report=report)

message = f"{error_message} test0{i}.json"
check_all_close(res, message, rel_limit=2.5)
Expand All @@ -137,8 +162,7 @@ def test_pure_quadrupolar_sidebands_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(dim, res, info, title="Quad Sidebands", report=report)
compile_plots(dim, res, info, title="Quad Sidebands", report=report)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=1.5)
Expand All @@ -150,7 +174,6 @@ def test_csa_plus_quadrupolar_lineshape_simpson(report):
)
path_ = path.join(SIMPSON_TEST_PATH, "csa_quad")
for i in range(10):
print(i)
filename = path.join(path_, f"test{i:02d}", f"test{i:02d}.json")
res = []
for volume in VOLUMES:
Expand All @@ -159,14 +182,13 @@ def test_csa_plus_quadrupolar_lineshape_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim,
res,
info,
title="Quad + Shielding Sidebands",
report=report,
)
compile_plots(
dim,
res,
info,
title="Quad + Shielding Sidebands",
report=report,
)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=0.9)
Expand All @@ -186,14 +208,13 @@ def test_1st_order_quadrupolar_lineshape_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim,
res,
info,
title="1st Order Quadrupolar Lineshape",
report=report,
)
compile_plots(
dim,
res,
info,
title="1st Order Quadrupolar Lineshape",
report=report,
)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=1.0)
Expand All @@ -211,8 +232,7 @@ def test_j_coupling_lineshape_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(dim, res, info, title="J-coupling Spectra", report=report)
compile_plots(dim, res, info, title="J-coupling Spectra", report=report)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=1.1)
Expand All @@ -232,10 +252,7 @@ def test_dipolar_coupling_lineshape_simpson(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim, res, info, title="Dipolar-coupling Spectra", report=report
)
compile_plots(dim, res, info, title="Dipolar-coupling Spectra", report=report)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=1.5)
Expand All @@ -262,15 +279,14 @@ def test_quad_csa_cross_rmnsim(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim,
res,
info,
title="Quad-CSA 2nd Order Cross-Term",
report=report,
label="rmnsim",
)
compile_plots(
dim,
res,
info,
title="Quad-CSA 2nd Order Cross-Term",
report=report,
label="rmnsim",
)

message = f"{error_message} test0{i}.json"
check_all_close(res, message, rel_limit=0.3)
Expand Down Expand Up @@ -301,15 +317,14 @@ def test_pure_shielding_static_lineshape_python_brute(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim,
res,
info,
title="Shielding Static Lineshape (Brute Force)",
report=report,
label="Brute",
)
compile_plots(
dim,
res,
info,
title="Shielding Static Lineshape (Brute Force)",
report=report,
label="Brute",
)

message = f"{error_message} test0{i}.json"
check_all_close(res, message, rel_limit=2)
Expand Down Expand Up @@ -338,15 +353,14 @@ def test_pure_quadrupolar_lineshape_python_brute(report):
)
res.append([data_mrsimulator, data_source])

if __GENERATE_REPORT__:
compile_plots(
dim,
res,
info,
title="Quad Lineshape Self-Test",
report=report,
label="self",
)
compile_plots(
dim,
res,
info,
title="Quad Lineshape Self-Test",
report=report,
label="self",
)

message = f"{error_message} test0{i:02d}.json"
check_all_close(res, message, rel_limit=1.5)
2 changes: 0 additions & 2 deletions tests/spectral_integration_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def get_data(filename):
def get_csdm_data(source_file, quantity):
"""Get data from csdm object"""
csdm = cp.load(source_file)
if quantity == "time":
csdm = csdm.fft()
data_source = csdm.y[0].components[0]
return data_source

Expand Down