forked from MESMER-group/mesmer
/
test_auto_regression.py
92 lines (66 loc) · 2.71 KB
/
test_auto_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from unittest import mock
import numpy as np
import pytest
import xarray as xr
import mesmer.core.auto_regression
from mesmer.core.utils import _check_dataarray_form, _check_dataset_form
from .utils import trend_data_1D, trend_data_2D
@pytest.mark.parametrize("obj", [xr.Dataset(), None])
def test_auto_regression_xr_errors(obj):
    """Anything other than an ``xr.DataArray`` must be rejected with a TypeError."""
    fit_xr = mesmer.core.auto_regression._fit_auto_regression_xr
    expected_message = "Expected a `xr.DataArray`"

    with pytest.raises(TypeError, match=expected_message):
        fit_xr(obj, "dim", lags=1)
@pytest.mark.parametrize("lags", [1, 2])
def test_auto_regression_xr_1D(lags):
    """A 1D fit along ``time`` gives scalar trend/std and one coefficient per lag."""
    series = trend_data_1D()
    result = mesmer.core.auto_regression._fit_auto_regression_xr(
        series, "time", lags=lags
    )

    # Expected form of each output variable, checked in this order.
    expected_forms = {
        "trend": dict(ndim=0, shape=()),
        "coeffs": dict(ndim=1, required_dims={"lags"}, shape=(lags,)),
        "standard_deviation": dict(ndim=0, shape=()),
    }

    _check_dataset_form(
        result,
        "_fit_auto_regression_result",
        required_vars=list(expected_forms),
    )
    for var_name, form in expected_forms.items():
        _check_dataarray_form(result[var_name], var_name, **form)
@pytest.mark.parametrize("lags", [1, 2])
def test_auto_regression_xr_2D(lags):
    """A 2D fit along ``time`` gives per-cell trend/std and a (cells, lags) coeffs array."""
    fit_xr = mesmer.core.auto_regression._fit_auto_regression_xr

    grid = trend_data_2D()
    result = fit_xr(grid, "time", lags=lags)
    [n_cells] = grid.cells.shape

    per_cell = dict(ndim=1, shape=(n_cells,))
    # Expected form of each output variable, checked in this order.
    expected_forms = {
        "trend": per_cell,
        "coeffs": dict(
            ndim=2, required_dims={"cells", "lags"}, shape=(n_cells, lags)
        ),
        "standard_deviation": per_cell,
    }

    _check_dataset_form(
        result,
        "_fit_auto_regression_result",
        required_vars=list(expected_forms),
    )
    for var_name, form in expected_forms.items():
        _check_dataarray_form(result[var_name], var_name, **form)
@pytest.mark.parametrize("lags", [1, 2])
def test_auto_regression_np(lags):
    """``_fit_auto_regression_np`` builds one ``AutoReg`` model and fits it once.

    statsmodels is mocked out, so only the call contract is verified: the
    model is constructed with ``(data, lags=lags, old_names=False)`` and
    ``fit()`` is called once with no arguments.
    """
    data = np.array([0, 1, 3.14])

    # Stand-in for the fitted results object returned by ``AutoReg(...).fit()``.
    # statsmodels' AutoReg (default trend="c") lays out params as
    # [const, L1, ..., Lp], so provide one intercept plus ``lags`` coefficients.
    mock_auto_regressor = mock.Mock()
    mock_auto_regressor.params = np.array([0.1] + [0.25] * lags)
    mock_auto_regressor.sigma2 = 3.14

    # NOTE(review): patching the statsmodels module attribute assumes the
    # implementation looks up ``AutoReg`` at call time (e.g. a function-level
    # import) - confirm if this test ever stops intercepting the call.
    with mock.patch("statsmodels.tsa.ar_model.AutoReg") as mocked_auto_regression:
        # ``AutoReg(...)`` returns the model; its ``fit()`` must return the
        # configured results mock. Setting ``return_value`` on the model itself
        # (as before) configured ``model()``, not ``model.fit()``, so the
        # params/sigma2 above were never seen by the code under test.
        mocked_model = mocked_auto_regression.return_value
        mocked_model.fit.return_value = mock_auto_regressor

        mesmer.core.auto_regression._fit_auto_regression_np(data, lags=lags)

    mocked_auto_regression.assert_called_once_with(data, lags=lags, old_names=False)
    mocked_model.fit.assert_called_once_with()