Skip to content

Commit

Permalink
test: Adjust tests to changes
Browse files Browse the repository at this point in the history
  • Loading branch information
thorbjoernl committed Apr 26, 2024
1 parent ef93e41 commit aa40ca3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 40 deletions.
44 changes: 22 additions & 22 deletions pyaerocom/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ def _get_default_statistic_config() -> dict[str, StatisticsCalculator]:
which calculates all implemented statistics. Can be used as a starting
point for adding additional stats using `dict.update()`
"""
return {
"refdata_mean": lambda x, y, w: np.nanmean(y),
"refdata_std": lambda x, y, w: np.nanstd(y),
"data_mean": lambda x, y, w: np.nanmean(x),
"data_std": lambda x, y, w: np.nanstd(x),
"rms": stat_rms,
"nmb": stat_nmb,
"mnmb": stat_mnmb,
"mb": stat_mb,
"mab": stat_mab,
"fge": stat_fge,
"R": stat_R,
"R_spearman": stat_R_spearman,
"R_kendall": stat_R_kendall,
}


_stats_configuration = _get_default_statistic_config()
return dict(
refdata_mean=lambda x, y, w: np.nanmean(y),
refdata_std=lambda x, y, w: np.nanstd(y),
data_mean=lambda x, y, w: np.nanmean(x),
data_std=lambda x, y, w: np.nanstd(x),
rms=stat_rms,
nmb=stat_nmb,
mnmb=stat_mnmb,
mb=stat_mb,
mab=stat_mab,
fge=stat_fge,
R=stat_R,
R_spearman=stat_R_spearman,
R_kendall=stat_R_kendall,
)


stats_configuration = _get_default_statistic_config()


def register_custom_statistic(name: str, fun: StatisticsCalculator) -> None:
Expand All @@ -155,10 +155,10 @@ def register_custom_statistic(name: str, fun: StatisticsCalculator) -> None:
ValueError:
if name has already been registered, or is otherwise invalid.
"""
if (name in _stats_configuration) or (name in ["totnum", "weighted"]):
raise ValueError(f"Name {name} is already registered in _stats_configuration.")
if (name in stats_configuration) or (name in ["totnum", "weighted"]):
raise ValueError(f"Name {name} is already registered in stats_configuration.")

Check warning on line 159 in pyaerocom/stats/stats.py

View check run for this annotation

Codecov / codecov/patch

pyaerocom/stats/stats.py#L159

Added line #L159 was not covered by tests

_stats_configuration[name] = fun
stats_configuration[name] = fun


def calculate_statistics(
Expand Down Expand Up @@ -246,7 +246,7 @@ def calculate_statistics(
if len(weights) != len(data):
raise ValueError("Invalid input. Length of weights must match length of data.")

statistics = _stats_configuration
statistics = stats_configuration

# Set defaults
data_filters = [FilterNaN()]
Expand Down
37 changes: 19 additions & 18 deletions tests/stats/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
import pytest

import pyaerocom.stats.stats
from pyaerocom.stats.implementations import *
from pyaerocom.stats.stats import (
_get_default_statistic_config,
_stats_configuration,
calculate_statistics,
register_custom_statistic,
)
Expand Down Expand Up @@ -52,27 +52,28 @@ def test_calc_stats_keys():
"refdata_std",
]
)
assert len(stats.keys()) == len(expected_keys)

assert set(stats.keys()) == expected_keys

@pytest.fixture
def clean_stats_configuration():
yield
pyaerocom.stats.stats.stats_configuration = _get_default_statistic_config()

class TestCustomStats:
def setup_method(self):
register_custom_statistic("test_statsitic1", lambda x, y, w: 5)
register_custom_statistic("test_statistic2", lambda x, y, w: 3)

def teardown_method(self):
_stats_configuration = _get_default_statistic_config()
@pytest.mark.usefixtures("clean_stats_configuration")
def test_customstats():
register_custom_statistic("test_statistic1", fun=lambda x, y, w: 5)
register_custom_statistic("test_statistic2", fun=lambda x, y, w: 3)

def test_customstats(self):
stats = calculate_statistics(
[1, 2, 3, 4],
[1, 2, 3, 4],
)
assert "test_statistic1" in stats.keys()
assert "test_statistic2" in stats.keys()
assert stats["test_statistic1"] == 5
assert stats["test_statistic2"] == 3
stats = calculate_statistics(
[1, 2, 3, 4],
[1, 2, 3, 4],
)
assert "test_statistic1" in stats.keys()
assert "test_statistic2" in stats.keys()
assert stats["test_statistic1"] == 5
assert stats["test_statistic2"] == 3


perfect_stats_num1_mean1 = {
Expand Down Expand Up @@ -177,7 +178,7 @@ def test_customstats(self):
),
],
)
def test_calc_statistics(data, ref_data, expected):
def test_calc_statistics(data, ref_data, expected: dict):
stats = calculate_statistics(data, ref_data)
assert isinstance(stats, dict)
assert len(stats) == len(expected)
Expand Down

0 comments on commit aa40ca3

Please sign in to comment.