-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_models.py
36 lines (30 loc) · 1.06 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pytest
import xarray as xr
import xsimlab as xs
import numpy as np
import logging
from episimlab.models import MarkovToy, NineComptV1, PartitionV1
@pytest.mark.parametrize('model_type', [
    # NOTE(review): these models are disabled without a recorded reason;
    # consider pytest.param(Model, marks=pytest.mark.skip(reason=...)) so the
    # exclusion is visible in test reports instead of hidden in comments.
    # MarkovToy,
    # NineComptV1,
    PartitionV1,
])
def test_model_sanity(model_type):
    """Run ``model_type`` with its default inputs and sanity-check the output.

    Checks performed on the result of ``model.run_with_defaults()``:

    * it is an ``xr.Dataset``;
    * no coordinate contains null values;
    * ``compt_model__state`` contains no null values;
    * the population summed over ``compt_model__state`` is the same at the
      first and last ``step``, within a tolerance scaled by population size;
    * the summed ``S`` compartment differs between the first and last ``step``.
    """
    model = model_type()
    result = model.run_with_defaults()
    assert isinstance(result, xr.Dataset)
    state = result['compt_model__state']

    # ensure that no coords are null
    for name, coord in result.coords.items():
        assert not coord.isnull().any(), f"coordinate {name!r} contains nulls"

    # xarray's .sum() skips NaN by default, so NaNs in the state could
    # otherwise slip through the conservation check below unnoticed.
    assert not state.isnull().any(), "compt_model__state contains nulls"

    # positional selection of the first and last timepoints along `step`
    first = state[dict(step=0)]
    last = state[dict(step=-1)]

    # ensure that the total population has not changed between first and
    # last timepoints. Summing element-wise differences (rather than taking
    # sum(last) - sum(first)) avoids cancellation between two large totals.
    # The tolerance is relative to the initial total, with 1e-8 as a floor,
    # so large populations don't fail on accumulated float round-off.
    net_change = float((first - last).sum())
    total_init = float(first.sum())
    tol = 1e-8 * max(1.0, abs(total_init))
    assert abs(net_change) <= tol, (
        f"total population changed by {net_change} (tolerance {tol})"
    )

    # ensure that S compt has changed between first and last timesteps
    S_init = first.loc[dict(compt="S")]
    S_final = last.loc[dict(compt="S")]
    S_change = float((S_final - S_init).sum())
    assert abs(S_change) > 1e-8, (
        f"S compartment did not change (net change {S_change})"
    )