-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_end_to_end.py
76 lines (55 loc) · 2.06 KB
/
test_end_to_end.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""An end to end example and test."""
import os
import pickle
from typing import Callable

from sklearn import datasets, svm
from sklearn.model_selection import cross_validate

import kotsu
# Models under test: two linear-kernel SVC configurations that differ only in
# the regularisation parameter C.
model_registry = kotsu.registration.ModelRegistry()

for _model_id, _svc_c in (("SVC-v1", 1), ("SVC-v2", 0.5)):
    model_registry.register(
        id=_model_id,
        entry_point=svm.SVC,
        kwargs={"kernel": "linear", "C": _svc_c, "random_state": 1},
    )


# Validations are registered further down, once their factory is defined.
validation_registry = kotsu.registration.ValidationRegistry()
def factory_iris_cross_val(folds: int) -> Callable:
    """Build an iris cross-validation function that uses ``folds`` folds.

    Args:
        folds: Passed as ``cv`` to ``cross_validate``.

    Returns:
        A validation callable that takes a model (plus optional artefact
        directories) and returns a dict of test scores.
    """

    def iris_cross_val(model, validation_artefacts_dir=None, model_artefacts_dir=None) -> dict:
        """Cross validate ``model`` on the iris classification dataset.

        Args:
            model: Estimator handed to ``cross_validate``.
            validation_artefacts_dir: Accepted for interface compatibility;
                not used by this validation.
            model_artefacts_dir: If truthy, directory in which the fitted
                estimator of each fold is pickled as
                ``model_from_fold_<fold_idx>.pk``.

        Returns:
            Dict with one ``fold_<i>_score`` entry per fold, plus
            ``mean_score`` and ``std_score`` over the fold test scores.
        """
        X, y = datasets.load_iris(return_X_y=True)
        scores = cross_validate(model, X, y, cv=folds, return_estimator=True)

        if model_artefacts_dir:
            # Save the trained models from each fold. os.path.join (rather than
            # "+") keeps the path correct whether or not the directory string
            # ends with a separator. The loop variable is deliberately not
            # called `model`, so the parameter is not shadowed.
            for fold_idx, fold_model in enumerate(scores["estimator"]):
                artefact_path = os.path.join(
                    model_artefacts_dir, f"model_from_fold_{fold_idx}.pk"
                )
                with open(artefact_path, "wb") as f:
                    pickle.dump(fold_model, f)

        test_scores = scores["test_score"]
        results = {f"fold_{i}_score": score for i, score in enumerate(test_scores)}
        results["mean_score"] = test_scores.mean()
        results["std_score"] = test_scores.std()
        return results

    return iris_cross_val
# Two variants of the iris cross validation, differing only in fold count.
for _validation_id, _n_folds in (("iris_cross_val-v1", 5), ("iris_cross_val-v2", 10)):
    validation_registry.register(
        id=_validation_id,
        entry_point=factory_iris_cross_val,
        kwargs={"folds": _n_folds},
    )
def test_run(tmpdir):
    """Run kotsu over the registered models and validations.

    ``results_path`` points at a CSV file inside the pytest ``tmpdir``.
    """
    results_csv = str(tmpdir) + "/validation_results.csv"
    kotsu.run.run(model_registry, validation_registry, results_path=results_csv)
def test_run_save_models(tmpdir):
    """Run kotsu with an artefacts store directory set to ``tmpdir``.

    NOTE(review): presumably this makes kotsu pass ``model_artefacts_dir`` to
    the validation so the per-fold pickling branch runs - confirm against
    kotsu's ``run`` implementation.
    """
    store_dir = str(tmpdir)
    kotsu.run.run(
        model_registry,
        validation_registry,
        results_path=store_dir + "/validation_results.csv",
        artefacts_store_dir=store_dir,
    )