-
Notifications
You must be signed in to change notification settings - Fork 19
/
fast_predict.py
41 lines (31 loc) · 1.44 KB
/
fast_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
Speeds up estimator.predict by preventing it from reloading the graph on each call to predict.
It does this by creating a python generator to keep the predict call open.
Usage: Just wrap your estimator in a FastPredict, e.g.
classifier = FastPredict(learn.Estimator(model_fn=model_params.model_fn, model_dir=model_params.model_dir))
NOTE!! There's a new version of fast predict which supports tf 1.4 and later. See fast_predict2.py.
Author: Marc Stogaitis
"""
class FastPredict:
    """Wrap an estimator so repeated predict calls reuse one open predict loop.

    The first call to ``predict`` starts ``estimator.predict`` on a generator
    that keeps yielding the most recently supplied features, so the underlying
    predict call stays open between calls instead of being restarted each time.
    Every call must pass the same number of examples as the first call.
    """

    def _createGenerator(self):
        """Yield the latest features passed to ``predict`` until ``close`` is called."""
        # The estimator pulls from this generator; each pull returns whatever
        # batch the most recent ``predict`` call stored in ``next_features``.
        while not self.closed:
            yield self.next_features

    def __init__(self, estimator):
        """Store *estimator*; its predict loop is started lazily on first use.

        Args:
            estimator: object exposing ``predict(x=..., batch_size=...)`` that
                returns an iterator of per-example predictions (presumably a
                ``tf.contrib.learn`` Estimator -- confirm against callers).
        """
        self.estimator = estimator
        self.first_run = True
        self.closed = False
        # Batch size fixed by the first predict() call; None until then.
        self.batch_size = None
        # Iterator returned by estimator.predict; None until the first predict().
        self.predictions = None

    def predict(self, features):
        """Return predictions for *features*, one per example.

        Args:
            features: a sized batch of examples; ``len(features)`` is taken as
                the batch size.

        Returns:
            A list of ``len(features)`` predictions, in the order the
            estimator yields them.

        Raises:
            ValueError: if the batch size differs from the first call's.
            RuntimeError: if called after ``close``.
        """
        if self.closed:
            raise RuntimeError("FastPredict.predict called after close()")
        # Validate before storing, so a rejected batch never becomes the
        # features the generator hands to the estimator.
        if not self.first_run and self.batch_size != len(features):
            raise ValueError(
                "All batches must be of the same size. First-batch:"
                + str(self.batch_size)
                + " This-batch:"
                + str(len(features))
            )
        # Set before starting the estimator, in case it pulls from the
        # generator as soon as predict() is called.
        self.next_features = features
        if self.first_run:
            self.batch_size = len(features)
            self.predictions = self.estimator.predict(
                x=self._createGenerator(), batch_size=None
            )
            self.first_run = False
        # One prediction per example in the batch just supplied.
        return [next(self.predictions) for _ in range(self.batch_size)]

    def close(self):
        """Stop the feature generator and let the estimator's predict loop finish.

        Safe to call before any ``predict`` and safe to call more than once.
        """
        self.closed = True
        if self.predictions is None:
            # predict() was never called, so there is no loop to drain.
            return
        # Advance once so the estimator pulls from the now-finished generator
        # and its predict loop can end; exhaustion here is expected, so use a
        # default instead of letting StopIteration escape to the caller.
        next(self.predictions, None)