-
Notifications
You must be signed in to change notification settings - Fork 3
/
package.py
75 lines (62 loc) · 1.83 KB
/
package.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import responses
from server import MetaBanditClassifier
from lib.meta_bandit import MetaBandit
from lib.arm_control import ArmControl
import importlib
import json
from config import Config
def get_policy(config: Config, args):
    """Instantiate the bandit policy class selected on the command line.

    The class named ``args.polity_cls`` is looked up in the module
    ``args.polity_module`` (imported dynamically). It is constructed with
    ``config=config`` plus any keyword arguments stored under the
    ``bandit_policy_params`` key of the raw configuration.

    Args:
        config: Loaded configuration. It is passed to the policy as ``config``
            and read for ``bandit_policy_params``.
        args: Parsed CLI namespace providing ``polity_module`` (a dotted
            module path such as ``policy.e_greedy``) and ``polity_cls``
            (a class name such as ``EGreedyPolicy``).

    Returns:
        A new instance of the requested policy class.

    Raises:
        ModuleNotFoundError: if ``args.polity_module`` cannot be imported.
        AttributeError: if that module has no attribute ``args.polity_cls``.
    """
    policy_module = importlib.import_module(args.polity_module)
    policy_class = getattr(policy_module, args.polity_cls)

    # The params section is optional. A missing key, or a key whose value is
    # None (e.g. an empty section in the config file), means "no extra
    # constructor arguments" instead of failing with KeyError or
    # "argument after ** must be a mapping".
    # NOTE(review): this reaches into Config's private ``_config``; a public
    # accessor on Config would be preferable.
    try:
        extra_params = config._config['bandit_policy_params']
    except KeyError:
        extra_params = None
    if extra_params is None:
        extra_params = {}

    # Keys from extra_params win, so a 'config' entry there would replace the
    # Config object passed in (same precedence as before).
    return policy_class(**{"config": config, **extra_params})
@responses.activate
def test_server(config: Config, server: MetaBanditClassifier):
    """Smoke-test a packed MetaBanditClassifier against mocked arm endpoints.

    Every value in ``config.arms`` is registered with ``responses`` so that a
    POST to it returns an empty JSON object instead of reaching the network.
    Mocks only apply while this function runs, because of
    ``@responses.activate``. A fixed sample payload is then sent through
    ``server.predict``.

    Args:
        config: Loaded configuration. The values of its ``arms`` mapping are
            mocked as URLs (presumably the arm service endpoints; confirm
            against Config).
        server: The packed service whose ``predict`` method is exercised.

    Returns:
        Whatever ``server.predict`` returns for the sample payload. The result
        is also printed, as before, so callers can now inspect it too.
    """
    # Mock every arm endpoint so predict() never makes a real HTTP call.
    for arm_url in config.arms.values():
        responses.add(
            responses.POST,
            arm_url,
            json.dumps({}),
            headers={'content-type': 'application/json'},
        )

    # Sample request: context features plus a user and a candidate item list.
    payload = {
        "context": {
            "f1": 1,
            "f2": 0,
        },
        "input": {
            "user": 1,
            "items": [0, 1, 3, 7, 4, 6, 5, 2],
        },
    }
    prediction = server.predict(payload)
    print(prediction)
    return prediction
# python package.py --config-path config.yml --polity-module policy.e_greedy --polity-cls EGreedyPolicy
# bentoml serve MetaBanditClassifier:latest
if __name__ == "__main__":
    # CLI entry point: build a MetaBandit from the config plus a dynamically
    # selected policy, package it into a MetaBanditClassifier service, save it,
    # and run a mocked smoke test against it.
    parser = argparse.ArgumentParser(
        description='Build a MetaBandit model, package it as a '
                    'MetaBanditClassifier service, save it, and smoke-test it.')
    parser.add_argument(
        '--config-path',
        help='Path to the configuration file passed to Config (e.g. config.yml).')
    # The misspelled "--polity-*" flags are kept for existing invocations.
    # The "--policy-*" spellings are aliases that share the same dest, so
    # get_policy() still reads args.polity_module / args.polity_cls.
    # Both are required: without them, import_module(None) /
    # getattr(module, None) fail with obscure errors later on.
    parser.add_argument(
        '--polity-module', '--policy-module', dest='polity_module', required=True,
        help='Dotted import path of the module defining the bandit policy '
             'class (e.g. policy.e_greedy).')
    parser.add_argument(
        '--polity-cls', '--policy-cls', dest='polity_cls', required=True,
        help='Name of the policy class inside that module (e.g. EGreedyPolicy).')
    args = parser.parse_args()
    print(args)

    # Build Model
    config = Config(args.config_path)
    arm_control = ArmControl(config)
    policy_control = get_policy(config, args)
    meta_bandit = MetaBandit(config, policy_control, arm_control)

    # Package Model
    meta_bandit_server = MetaBanditClassifier()
    meta_bandit_server.pack("model", meta_bandit)
    meta_bandit_server.save()

    # Test Model
    test_server(config, meta_bandit_server)