# model_parameters.py (executable file; 144 lines, 122 loc, 5.43 KB)
from typing import Union
from django.db.models.base import Model
import paramtools as pt
from webapp.apps.comp.models import ModelConfig
from webapp.apps.comp.compute import Compute, SyncCompute, JobFailError
from webapp.apps.comp import actions
from webapp.apps.comp.exceptions import AppError, NotReady, Stale
import os
import json
# Absolute path of the "inputs.json" file located in this module's directory.
INPUTS = os.path.join(
    os.path.abspath(os.path.dirname(__file__)),
    "inputs.json",
)
def pt_factory(classname, defaults):
    """
    Create a new ``pt.Parameters`` subclass named ``classname``.

    The generated class carries ``defaults`` as its ``defaults`` class
    attribute; the class itself (not an instance) is returned.
    """
    namespace = dict(defaults=defaults)
    return type(classname, (pt.Parameters,), namespace)
class ModelParameters:
    """
    Handles logic for getting cached model parameters and updating the cache.

    Inputs for ``project`` are cached in ``ModelConfig`` rows keyed on the
    project, ``str(project.latest_tag)`` and the meta parameter values. On a
    cache miss an ``actions.INPUTS`` job is submitted through ``self.compute``.
    """

    def __init__(
        self,
        project: "Project",
        compute: "Union[SyncCompute, Compute, None]" = None,
    ):
        """
        Args:
            project: the project whose inputs are managed.
            compute: job-submission backend. When omitted, ``SyncCompute()``
                is used if ``project.cluster.version`` is "v0", otherwise
                ``Compute()``.
        """
        self.project = project
        if compute is not None:
            self.compute = compute
        elif self.project.cluster.version == "v0":
            self.compute = SyncCompute()
        else:
            self.compute = Compute()
        # Last ModelConfig looked up or created by get_inputs(); None initially.
        self.config = None

    def defaults(self, init_meta_parameters=None):
        """
        Return default meta parameters and the model parameters they imply.

        Args:
            init_meta_parameters: optional adjustment applied to the default
                meta parameters before model parameters are fetched.

        Returns:
            dict with "model_parameters" (result of model_parameters_parser
            for the adjusted meta parameter specification) and
            "meta_parameters" (the dumped, adjusted meta parameters).
        """
        # Get Parameters class for meta parameters and adjust its values.
        meta_param_parser = self.meta_parameters_parser()
        meta_param_parser.adjust(init_meta_parameters or {})
        meta_parameters = meta_param_parser.dump()
        return {
            "model_parameters": self.model_parameters_parser(
                meta_param_parser.specification(meta_data=False, serializable=True)
            ),
            "meta_parameters": meta_parameters,
        }

    def meta_parameters_parser(self) -> "pt.Parameters":
        """Return a Parameters instance built from the default meta parameters."""
        res = self.get_inputs()
        return pt_factory("MetaParametersParser", res["meta_parameters"])()

    def model_parameters_parser(self, meta_parameters_values=None):
        """
        Return the cached ``model_parameters`` for ``meta_parameters_values``.

        Despite the name, this returns the raw data from get_inputs(), not
        Parameters parser instances.
        """
        # TODO: decide whether to also return per-section Parameters parsers.
        return self.get_inputs(meta_parameters_values)["model_parameters"]

    def cleanup_meta_parameters(self, meta_parameters_values, meta_parameters):
        """
        Normalize meta parameter values before they are saved.

        Returns {} when no values are given. Otherwise the values are applied
        to a Parameters class built from ``meta_parameters`` and its
        specification (no meta data, serializable) is returned.
        """
        if not meta_parameters_values:
            return {}
        mp = pt_factory("MP", meta_parameters)()
        mp.adjust(meta_parameters_values)
        return mp.specification(meta_data=False, serializable=True)

    def get_inputs(self, meta_parameters_values=None):
        """
        Get cached version of inputs or retrieve new version.

        Args:
            meta_parameters_values: meta parameter values the inputs depend
                on; a falsy value is treated as ``{}``.

        Returns:
            dict with "meta_parameters" and "model_parameters" taken from the
            matching (or newly created) ModelConfig.

        Raises:
            NotReady: a matching ModelConfig exists but its status is not
                "SUCCESS", or (v1 clusters) an inputs job was just submitted.
            AppError: (non-v1 clusters) the submitted job reported failure.
        """
        meta_parameters_values = meta_parameters_values or {}
        self.config = None
        try:
            self.config = ModelConfig.objects.get(
                project=self.project,
                model_version=str(self.project.latest_tag),
                meta_parameters_values=meta_parameters_values,
            )
            if self.config.status != "SUCCESS":
                raise NotReady(self.config)
        except (ModelConfig.DoesNotExist, Stale) as e:
            # NOTE: the staleness check (``config.is_stale()``) is disabled, so
            # Stale is only handled here if something in the try raises it.
            is_v1 = self.project.cluster.version == "v1"
            response = self.compute.submit_job(
                project=self.project,
                task_name=actions.INPUTS,
                task_kwargs={"meta_param_dict": meta_parameters_values},
                path_prefix="/api/v1/jobs" if is_v1 else "",
            )
            if is_v1:
                # v1: ``response`` is stored as the job id; the config is
                # marked PENDING and the caller is told it is not ready yet.
                if isinstance(e, ModelConfig.DoesNotExist):
                    self.config = ModelConfig.objects.create(
                        project=self.project,
                        model_version=str(self.project.latest_tag),
                        meta_parameters_values=meta_parameters_values,
                        inputs_version="v1",
                        job_id=response,
                        status="PENDING",
                    )
                else:
                    # Stale: reuse the existing row for the new job.
                    self.config.model_version = str(self.project.latest_tag)
                    self.config.job_id = response
                    self.config.status = "PENDING"
                    self.config.save()
                raise NotReady(self.config)

            # Other cluster versions: ``response`` is a (success, result) pair.
            # NOTE(review): a Stale config is not updated on this path; a new
            # row is created instead — confirm that is intended.
            success, result = response
            if not success:
                raise AppError(meta_parameters_values, result["traceback"])
            save_vals = self.cleanup_meta_parameters(
                meta_parameters_values, result["meta_parameters"]
            )
            self.config = ModelConfig.objects.create(
                project=self.project,
                model_version=str(self.project.latest_tag),
                meta_parameters_values=save_vals,
                meta_parameters=result["meta_parameters"],
                model_parameters=result["model_parameters"],
                inputs_version="v1",
                status="SUCCESS",
            )
        return {
            "meta_parameters": self.config.meta_parameters,
            "model_parameters": self.config.model_parameters,
        }