Skip to content

Commit

Permalink
type test functions, update commit hooks incl. mypy, setup.cfg set im…
Browse files Browse the repository at this point in the history
…plicit_optional=true which defaults to false in newer mypys

github.com/python/mypy/pull/13401
  • Loading branch information
janosh committed Dec 24, 2022
1 parent 396bdf4 commit 9303f7d
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 41 deletions.
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.11.4
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.8.0
rev: 22.12.0
hooks:
- id: black-jupyter

- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear]

- repo: https://github.com/asottile/pyupgrade
rev: v2.38.2
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.981
rev: v0.991
hooks:
- id: mypy
additional_dependencies: [types-requests]
Expand All @@ -40,7 +40,7 @@ repos:
- id: format-ipy-cells

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-case-conflict
- id: check-symlinks
Expand All @@ -52,14 +52,14 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
rev: v2.2.2
hooks:
- id: codespell
stages: [commit, commit-msg]
exclude_types: [csv, svg, html, yaml, jupyter]

- repo: https://github.com/PyCQA/autoflake
rev: v1.6.1
rev: v2.0.0
hooks:
- id: autoflake

Expand All @@ -70,12 +70,12 @@ repos:
exclude: tests

- repo: https://github.com/PyCQA/docformatter
rev: v1.5.0
rev: v1.5.1
hooks:
- id: docformatter

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.5.2
rev: 1.5.3
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def hist_elemental_prevalence(
log: bool = False,
keep_top: int = None,
ax: plt.Axes = None,
bar_values: Literal["percent", "count", None] = "percent",
bar_values: Literal["percent", "count"] | None = "percent",
h_offset: int = 0,
v_offset: int = 10,
rotation: int = 45,
Expand Down
6 changes: 3 additions & 3 deletions pymatviz/ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def ptable_heatmap(
zero_color: str = "#DDD", # light gray
infty_color: str = "lightskyblue",
na_color: str = "white",
heat_mode: Literal["value", "fraction", "percent", None] = "value",
heat_mode: Literal["value", "fraction", "percent"] | None = "value",
precision: str = None,
text_color: str | tuple[str, str] = "auto",
exclude_elements: Sequence[str] = (),
Expand Down Expand Up @@ -367,7 +367,7 @@ def ptable_heatmap_plotly(
count_mode: CountMode = "element_composition",
colorscale: str | Sequence[str] | Sequence[tuple[float, str]] | None = None,
showscale: bool = True,
heat_mode: Literal["value", "fraction", "percent", None] = "value",
heat_mode: Literal["value", "fraction", "percent"] | None = "value",
precision: str = None,
hover_props: Sequence[str] | dict[str, str] | None = None,
hover_data: dict[str, str | int | float] | pd.Series | None = None,
Expand Down Expand Up @@ -418,7 +418,7 @@ def ptable_heatmap_plotly(
hover_data (dict[str, str | int | float] | pd.Series): Map from element symbols
to additional data to display in the hover tooltip. {"Fe": "this shows up in
the hover tooltip on a new line below the element name"}. Defaults to None.
font_colors (list[str]): One or two color strings [min_color, max_color].
font_colors (list[str]): One color name or two for [min_color, max_color].
min_color is applied to annotations for heatmap values
< (max_val - min_val) / 2. Defaults to ["black"].
gap (float): Gap in pixels between tiles of the periodic table. Defaults to 5.
Expand Down
6 changes: 3 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ A toolkit for visualizations in materials informatics.

[![Tests](https://github.com/janosh/pymatviz/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/pymatviz/actions/workflows/test.yml)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/janosh/pymatviz/main.svg)](https://results.pre-commit.ci/latest/github/janosh/pymatviz/main)
[![This project supports Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg?logo=python)](https://python.org/downloads)
[![PyPI](https://img.shields.io/pypi/v/pymatviz?logo=PyPI)](https://pypi.org/project/pymatviz)
[![PyPI Downloads](https://img.shields.io/pypi/dm/pymatviz)](https://pypistats.org/packages/pymatviz)
[![This project supports Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
[![PyPI](https://img.shields.io/pypi/v/pymatviz?logo=pypi&logoColor=white)](https://pypi.org/project/pymatviz)
[![PyPI Downloads](https://img.shields.io/pypi/dm/pymatviz?logo=icloud&logoColor=white)](https://pypistats.org/packages/pymatviz)

</h4>

Expand Down
4 changes: 1 addition & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ disallow_untyped_defs = true
warn_redundant_casts = true
warn_unused_ignores = true
show_error_codes = true

[mypy-tests.*]
disallow_untyped_defs = false
implicit_optional = true

[codespell]
ignore-words-list = hist
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_cumulative_error(alpha: float) -> None:
assert isinstance(ax, plt.Axes)


def test_cumulative_residual():
def test_cumulative_residual() -> None:
ax = cumulative_residual(abs(y_true - y_pred))
assert isinstance(ax, plt.Axes)
assert len(ax.lines) == 3
Expand Down
55 changes: 37 additions & 18 deletions tests/test_ptable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Sequence

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
Expand All @@ -17,6 +19,10 @@
from pymatviz.utils import df_ptable


if TYPE_CHECKING:
from pymatviz.ptable import CountMode


@pytest.fixture
def glass_formulas() -> list[str]:
"""Output of:
Expand All @@ -26,10 +32,10 @@ def glass_formulas() -> list[str]:
load_dataset("matbench_glass").composition.head(20)
"""
return (
"Al,Al(NiB)2,Al10Co21B19,Al10Co23B17,Al10Co27B13,Al10Co29B11,Al10Co31B9,"
"Al10Co33B7,Al10Cr3Si7,Al10Fe23B17,Al10Fe27B13,Al10Fe31B9,Al10Fe33B7,"
"Al10Ni23B17,Al10Ni27B13,Al10Ni29B11,Al10Ni31B9,Al10Ni33B7,Al11(CrSi2)3"
).split(",")
"Al Al(NiB)2 Al10Co21B19 Al10Co23B17 Al10Co27B13 Al10Co29B11 Al10Co31B9 "
"Al10Co33B7 Al10Cr3Si7 Al10Fe23B17 Al10Fe27B13 Al10Fe31B9 Al10Fe33B7 "
"Al10Ni23B17 Al10Ni27B13 Al10Ni29B11 Al10Ni31B9 Al10Ni33B7 Al11(CrSi2)3"
).split()


@pytest.fixture
Expand Down Expand Up @@ -64,13 +70,13 @@ def steel_elem_counts(steel_formulas: pd.Series[Composition]) -> pd.Series[int]:
("reduced_composition", {"Fe": 13, "O": 27, "P": 3}),
],
)
def test_count_elements(count_mode, counts):
def test_count_elements(count_mode: CountMode, counts: dict[str, float]) -> None:
series = count_elements(["Fe2 O3"] * 5 + ["Fe4 P4 O16"] * 3, count_mode=count_mode)
expected = pd.Series(counts, index=df_ptable.index, name="count").fillna(0)
assert series.equals(expected)


def test_count_elements_by_atomic_nums():
def test_count_elements_by_atomic_nums() -> None:
series_in = pd.Series(1, index=range(1, 119))
el_cts = count_elements(series_in)
expected = pd.Series(1, index=df_ptable.index, name="count")
Expand All @@ -79,7 +85,7 @@ def test_count_elements_by_atomic_nums():


@pytest.mark.parametrize("range_limits", [(-1, 10), (100, 200)])
def test_count_elements_bad_atomic_nums(range_limits):
def test_count_elements_bad_atomic_nums(range_limits: tuple[int, int]) -> None:
with pytest.raises(ValueError, match="assumed to represent atomic numbers"):
count_elements({idx: 0 for idx in range(*range_limits)})

Expand All @@ -88,7 +94,7 @@ def test_count_elements_bad_atomic_nums(range_limits):
count_elements({str(idx): 0 for idx in range(*range_limits)})


def test_hist_elemental_prevalence(glass_formulas):
def test_hist_elemental_prevalence(glass_formulas: list[str]) -> None:
ax = hist_elemental_prevalence(glass_formulas)
assert isinstance(ax, plt.Axes)

Expand All @@ -99,7 +105,9 @@ def test_hist_elemental_prevalence(glass_formulas):
hist_elemental_prevalence(glass_formulas, keep_top=10, bar_values="count")


def test_ptable_heatmap(glass_formulas, glass_elem_counts):
def test_ptable_heatmap(
glass_formulas: list[str], glass_elem_counts: pd.Series[int]
) -> None:
ax = ptable_heatmap(glass_formulas)
assert isinstance(ax, plt.Axes)

Expand Down Expand Up @@ -139,8 +147,11 @@ def test_ptable_heatmap(glass_formulas, glass_elem_counts):


def test_ptable_heatmap_ratio(
steel_formulas, glass_formulas, steel_elem_counts, glass_elem_counts
):
steel_formulas: list[str],
glass_formulas: list[str],
steel_elem_counts: pd.Series[int],
glass_elem_counts: pd.Series[int],
) -> None:
# composition strings
ax = ptable_heatmap_ratio(glass_formulas, steel_formulas)
assert isinstance(ax, plt.Axes)
Expand All @@ -153,7 +164,7 @@ def test_ptable_heatmap_ratio(
ptable_heatmap_ratio(glass_elem_counts, steel_formulas)


def test_ptable_heatmap_plotly(glass_formulas):
def test_ptable_heatmap_plotly(glass_formulas: list[str]) -> None:
fig = ptable_heatmap_plotly(glass_formulas)
assert isinstance(fig, go.Figure)
assert (
Expand Down Expand Up @@ -194,10 +205,16 @@ def test_ptable_heatmap_plotly(glass_formulas):
)
@pytest.mark.parametrize("showscale", [False, True])
@pytest.mark.parametrize("font_size", [None, 14])
@pytest.mark.parametrize("font_colors", [None, ("black", "white")])
@pytest.mark.parametrize("font_colors", [["red"], ("black", "white")])
def test_ptable_heatmap_plotly_kwarg_combos(
glass_formulas, exclude_elements, heat_mode, showscale, font_size, font_colors, log
):
glass_formulas: list[str],
exclude_elements: Sequence[str],
heat_mode: Literal["value", "fraction", "percent"] | None,
showscale: bool,
font_size: int,
font_colors: tuple[str] | tuple[str, str],
log: bool,
) -> None:
fig = ptable_heatmap_plotly(
glass_formulas,
exclude_elements=exclude_elements,
Expand All @@ -211,8 +228,10 @@ def test_ptable_heatmap_plotly_kwarg_combos(


@pytest.mark.parametrize(
"clr_scl", ["YlGn", ["blue", "red"], [(0, "blue"), (1, "red")]]
"colorscale", ["YlGn", ["blue", "red"], [(0, "blue"), (1, "red")]]
)
def test_ptable_heatmap_plotly_colorscale(glass_formulas, clr_scl):
fig = ptable_heatmap_plotly(glass_formulas, colorscale=clr_scl)
def test_ptable_heatmap_plotly_colorscale(
glass_formulas: list[str], colorscale: str | list[tuple[float, str]] | list[str]
) -> None:
fig = ptable_heatmap_plotly(glass_formulas, colorscale=colorscale)
assert isinstance(fig, go.Figure)
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tests.conftest import y_pred, y_true


def test_add_mae_r2_box():
def test_add_mae_r2_box() -> None:
text_box = add_mae_r2_box(y_pred, y_true)

assert isinstance(text_box, AnchoredText)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_add_identity_line(
assert line["line"]["color"] == line_kwds["color"] if line_kwds else "gray"


def test_df_to_arrays():
def test_df_to_arrays() -> None:
df = pd.DataFrame([y_true, y_pred]).T
x1, y1 = df_to_arrays(None, y_true, y_pred)
x_col, y_col = df.columns[:2]
Expand Down

0 comments on commit 9303f7d

Please sign in to comment.