-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_sagemaker.py
132 lines (106 loc) · 5.13 KB
/
test_sagemaker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import importlib
import inspect
import json
import os
import shutil
import unittest

import mxnet as mx
# Import the user's training entry module and its sample-request module; both
# module names come from config.json (keys files.entry and files.entry_io).
with open('config.json', encoding='utf-8') as jfile:
    config = json.load(jfile)
# importlib.import_module returns the module that was actually named, whereas
# __import__('pkg.mod') would hand back the top-level package 'pkg' instead.
model = importlib.import_module(config["files"]['entry'])
# NOTE: this name shadows the standard-library ``io`` module within this file;
# the test methods refer to it as ``io`` so the name is kept.
io = importlib.import_module(config["files"]['entry_io'])
class MXNetModelTest(unittest.TestCase):
"""Testcase to check a mxnet model before handing it to sagemaker."""
@classmethod
def setUpClass(cls):
    """Build a local fake-bucket layout and run the user's ``train`` once.

    Reads ``config.json`` from the current working directory, (re)creates
    ``bucket/`` together with its ``model`` and ``data`` sub-directories, then
    calls ``model.train`` with the hyperparameters, channel directories,
    output/model directories, resource counts, hosts and extra kwargs taken
    from the config. The returned object is stored on ``cls.model`` for the
    save and hosting tests.
    """
    with open('config.json', encoding='utf-8') as jfile:
        config = json.load(jfile)
    hyperparameters = config['hyperparameters']
    num_gpus = config['num_gpus']
    num_cpus = config['num_cpus']
    hosts = config['hosts']
    kwargs = config['kwargs']

    # Directory layout handed to train().
    cls.bucket_path = "bucket"
    cls.channel_input_dirs = {'train': config['dataset']['train'],
                              'eval': config['dataset']['eval']}
    cls.output_data_dir = os.path.join(cls.bucket_path, 'data')
    cls.model_dir = os.path.join(cls.bucket_path, 'model')

    # Start from an empty bucket so artifacts from a previous run cannot leak in.
    try:
        os.makedirs(cls.bucket_path)
    except FileExistsError:
        print("Removing old bucket at {}.".format(os.path.abspath(cls.bucket_path)))
        shutil.rmtree(cls.bucket_path)
        os.makedirs(cls.bucket_path)
    os.makedirs(cls.model_dir)
    # train() receives output_data_dir as a destination path, so create it up
    # front just like model_dir (presumably the entry point writes there, as it
    # would inside a SageMaker container — confirm against the entry module).
    os.makedirs(cls.output_data_dir)

    cls.model = model.train(hyperparameters=hyperparameters,
                            channel_input_dirs=cls.channel_input_dirs,
                            output_data_dir=cls.output_data_dir,
                            model_dir=cls.model_dir,
                            num_gpus=num_gpus,
                            num_cpus=num_cpus,
                            hosts=hosts,
                            **kwargs)
@classmethod
def tearDownClass(cls):
    """Remove the whole bucket directory tree created for this test class."""
    shutil.rmtree(cls.bucket_path)
# ---------------------------------------------------------------------------- #
# Test signatures #
# ---------------------------------------------------------------------------- #
def test_1_signature(self):
    """Check the signature of every SageMaker hook the entry module defines.

    Hooks the module does not define are not checked; each defined hook's
    ``inspect.signature`` string must match the expected parameter list.
    """
    # (hook name, expected signature string), checked in this order.
    expected_signatures = (
        ('save', '(model, model_dir)'),
        ('model_fn', '(model_dir)'),
        ('transform_fn', '(model, input_data, content_type, accept)'),
        ('input_fn', '(input_data, content_type)'),
        ('predict_fn', '(block, array)'),
        ('output_fn', '(ndarray, accept)'),
    )
    defined_names = dir(model)
    for hook_name, signature in expected_signatures:
        if hook_name in defined_names:
            hook = getattr(model, hook_name)
            self.assertEqual(str(inspect.signature(hook)), signature)
# ---------------------------------------------------------------------------- #
# Training functions #
# ---------------------------------------------------------------------------- #
def test_2_save(self):
    """[Optional] Exercise the entry module's ``save`` hook.

    Does nothing (and therefore passes) when the module defines no ``save``.
    Otherwise the object returned by ``train`` in setUpClass must not be
    None, and it is passed to ``save`` along with the bucket's model directory.
    """
    if 'save' not in dir(model):
        return
    self.assertIsNotNone(self.model, "save function defined but not used")
    model.save(model=self.model, model_dir=self.model_dir)
# ---------------------------------------------------------------------------- #
# Hosting functions #
# ---------------------------------------------------------------------------- #
def test_3_transform_fn(self):
    """[Optional] Send the sample request from ``io`` through ``transform_fn``.

    Does nothing when the entry module defines no ``transform_fn``. Otherwise
    the network is loaded with ``model_fn`` (which must therefore exist), one
    request is transformed, and the returned accept type is compared with
    ``io.accept``. The response body is compared with ``io.response_body``
    when ``io`` provides one; if not, it is printed for manual inspection.
    """
    if 'transform_fn' not in dir(model):
        return
    self.assertTrue('model_fn' in dir(model), 'Currently requires model_fn to test transform_fn')
    net = model.model_fn(self.model_dir)
    request = dict(input_data=io.input_data,
                   content_type=io.content_type,
                   accept=io.accept)
    response_body, accept = model.transform_fn(model=net, **request)
    self.assertEqual(accept, io.accept)
    if 'response_body' not in dir(io):
        # No expected body supplied: show it so a human can check it.
        print("\nResponse: ", response_body)
        return
    self.assertEqual(response_body, io.response_body)
# ---------------------------------------------------------------------------- #
# Request handlers for Gluon models #
# ---------------------------------------------------------------------------- #
def test_4_input_predict_output_fn(self):
"""
[Optional]
Test input_fn function.
"""
# if 'input_fn' in dir(model):
# self.assertTrue('model_fn' in dir(model), 'Currently requires model_fn to test transform_fn')
# net = model.model_fn(self.model_dir)