Skip to content

Commit

Permalink
fix get_crystal_sys raise ValueError on non-positive space group numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Feb 6, 2022
1 parent 16c92ab commit 9a535f7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ jobs:
uses: actions/checkout@v2

- name: Install dependencies
run: |
pip install .[test]
run: pip install .[test]

- name: Run pytest
# Only publish to PyPI if tests pass.
Expand Down
4 changes: 2 additions & 2 deletions ml_matrics/sunburst.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import plotly.express as px
from plotly.graph_objects import Figure

from ml_matrics.utils import get_crystal_system
from ml_matrics.utils import get_crystal_sys


def spacegroup_sunburst(
Expand Down Expand Up @@ -38,7 +38,7 @@ def spacegroup_sunburst(
series = pd.Series(spacegroups)

df = pd.DataFrame({"spacegroup": range(230)})
df["cryst_sys"] = [get_crystal_system(spg) for spg in range(230)]
df["cryst_sys"] = [get_crystal_sys(spg) for spg in range(1, 231)]

df["values"] = series.value_counts().reindex(range(230), fill_value=0)

Expand Down
14 changes: 8 additions & 6 deletions ml_matrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def add_mae_r2_box(
loc: str = "lower right",
prec: int = 3,
**kwargs: Any,
) -> None:
) -> AnchoredText:
"""Provide a set of x and y values of equal length and an optional Axes object
on which to print the values' mean absolute error and R^2 coefficient of
determination.
Expand All @@ -132,9 +132,14 @@ def add_mae_r2_box(
text_box = AnchoredText(mae_str + r2_str, loc=loc, frameon=frameon, **kwargs)
ax.add_artist(text_box)

return text_box


def get_crystal_system(spg: int) -> str:
def get_crystal_sys(spg: int) -> str:
"""Get the crystal system for an international space group number."""
if spg < 1 or spg > 230:
raise ValueError(f"Received invalid space group {spg}")

if 0 < spg < 3:
return "triclinic"
if spg < 16:
Expand All @@ -147,7 +152,4 @@ def get_crystal_system(spg: int) -> str:
return "trigonal"
if spg < 195:
return "hexagonal"
if spg < 231:
return "cubic"
else:
raise ValueError(f"Received invalid space group {spg}")
return "cubic"
37 changes: 37 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
from matplotlib.offsetbox import AnchoredText

from ml_matrics.utils import add_mae_r2_box, get_crystal_sys

from . import y_pred, y_true


def test_add_mae_r2_box():

text_box = add_mae_r2_box(y_pred, y_true)

assert isinstance(text_box, AnchoredText)

assert text_box.txt.get_text() == "$\\mathrm{MAE} = 0.116$\n$R^2 = 0.740$"


@pytest.mark.parametrize(
"input, expected",
[
(1, "triclinic"),
(15, "monoclinic"),
(16, "orthorhombic"),
(75, "tetragonal"),
(143, "trigonal"),
(168, "hexagonal"),
(230, "cubic"),
],
)
def test_get_crystal_sys(input, expected):
assert expected == get_crystal_sys(input)


@pytest.mark.parametrize("spg", [-1, 0, 231])
def test_get_crystal_sys_invalid(spg):
with pytest.raises(ValueError, match=f"Received invalid space group {spg}"):
get_crystal_sys(spg)

0 comments on commit 9a535f7

Please sign in to comment.