/
test_all_models.py
89 lines (71 loc) · 2.57 KB
/
test_all_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Test all models.

1. [x] List all models to test (by default every model; a model group that
   contains a `test_subset.txt` file is restricted to the models listed there).
2. Write functions to test a single model:
   - create a new conda environment
   - activate the environment
   - run `kipoi test <model> --source=kipoi` in that environment
"""
import pytest
import subprocess
import kipoi
import logging
from kipoi_conda import get_kipoi_bin, env_exists, remove_env
from kipoi.cli.env import conda_env_name
from kipoi_utils.utils import list_files_recursively, read_txt
def models_to_test(src):
    """Return the list of model names to test for a model source.

    By default every model of ``src`` is tested. When a model-group directory
    contains a ``test_subset.txt`` file, all models under that directory are
    replaced by the entries listed in the file. Each entry is interpreted
    relative to the directory that holds ``test_subset.txt``.

    Args:
      src: Model source (e.g. ``kipoi.get_source("kipoi")``). It must provide
        ``local_path``, ``get_model_descr(name)`` and ``list_models()``.

    Returns:
      list: model names. Models outside any restricted group come first,
        followed by the explicitly listed subset models.
    """
    import os

    def _in_group(model, group):
        # Match on whole path components. Group "Basset" must cover
        # "Basset" and "Basset/foo", but NOT an unrelated model such as
        # "BassetX", which a plain string-prefix test would wrongly drop.
        # An empty group (test_subset.txt at the source root) covers everything.
        if group == "" or model == group:
            return True
        return model.startswith((group + "/", group + os.sep))

    txt_files = list_files_recursively(src.local_path, "test_subset", "txt")
    exclude = []
    include = []
    for txt_file in txt_files:
        group_dir = os.path.dirname(txt_file)
        exclude.append(group_dir)
        subset_path = os.path.join(src.local_path, txt_file)
        include.extend(os.path.join(group_dir, name) for name in read_txt(subset_path))

    # Load the description of every explicitly listed model up front.
    # Presumably get_model_descr raises for an unknown name, so a typo in a
    # test_subset.txt surfaces here rather than as a silently skipped model.
    for model in include:
        src.get_model_descr(model)

    models = [m for m in src.list_models().model
              if not any(_in_group(m, group) for group in exclude)]
    return models + include
@pytest.mark.parametrize("model_name", models_to_test(kipoi.get_source("kipoi")))
def test_model(model_name, caplog):
    """Run ``kipoi test`` for a single model inside a freshly created conda env.

    Steps:
      1. Remove any leftover ``test-<env>`` conda environment for the model.
      2. Create that environment with ``kipoi env create``.
      3. Run ``kipoi test`` with the executable returned by ``get_kipoi_bin``
         for that environment.
      4. Fail if any in-process log record at WARNING level or above was emitted.

    Note: ``caplog`` only captures records logged inside this Python process
    (e.g. by ``env_exists``/``remove_env``). For the ``kipoi`` subprocesses,
    only their exit codes are checked; their log output is not inspected.
    """
    caplog.set_level(logging.INFO)
    source_name = "kipoi"
    # Models that need a smaller test batch; all other models use 4.
    batch_size = {"basenji": 2}.get(model_name, 4)

    def _check_call(cmd):
        # Run cmd and fail the test, naming the command, on a non-zero exit.
        returncode = subprocess.call(args=cmd)
        assert returncode == 0, "Command exited with {0}: {1}".format(
            returncode, " ".join(cmd))

    env_name = "test-" + conda_env_name(model_name, model_name, source_name)

    # Start from a clean environment: drop any leftover from a previous run.
    if env_exists(env_name):
        print("Removing the environment: {0}".format(env_name))
        remove_env(env_name)

    # Create the model test environment.
    _check_call(["kipoi", "env", "create",
                 "--source", source_name,
                 "--env", env_name,
                 model_name])

    # Run the tests with the kipoi executable resolved for env_name
    # (presumably the binary installed inside that conda env).
    _check_call([get_kipoi_bin(env_name), "test",
                 "--batch_size", str(batch_size),
                 "--source", source_name,
                 model_name])

    # There shouldn't be any warnings (or worse) logged in-process.
    offending = [rec for rec in caplog.records if rec.levelno >= logging.WARNING]
    assert not offending, "Unexpected log records: {0}".format(
        ["{0}: {1}".format(rec.levelname, rec.getMessage()) for rec in offending])