-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
69 lines (55 loc) · 2.06 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from logging import INFO, DEBUG, StreamHandler, getLogger
from sys import stdout
from cv2 import resize
from dlr import DLRModel
from numpy import argsort, frombuffer, fromstring, load, uint8
from os import path
from ast import literal_eval
import os
# Configure the root logger (getLogger() with no name) to emit INFO and above
# on stdout. Because the level is INFO, the logger.debug(...) calls in this
# module are suppressed unless the level is lowered.
logger = getLogger()
logger.setLevel(INFO)
logging_handler = StreamHandler(stdout)
logger.addHandler(logging_handler)

# A top-ranked class is reported only if its score is at least this value.
SCORE_THRESHOLD = 0.3
# At most this many of the highest-scoring classes are considered per image.
MAX_NO_OF_RESULTS = 5
# Target size handed to cv2.resize before the image is passed to the model.
SHAPE = (224, 224)

# Directory holding the compiled DLR model artifacts and synset.txt.
# By default this is the unarchived Greengrass v2 artifact path of the
# ModelStore component. Set the DLR_MODEL_DIR environment variable to use a
# different location, e.g. f"{os.getcwd()}/model" when running locally.
DEFAULT_MODEL_DIR = '/greengrass/v2/packages/artifacts-unarchived/variant.DLR.ImageClassification.ModelStore/2.1.9/DLR-resnet50-x86_64-cpu-ImageClassification'
MODEL_DIR = os.environ.get('DLR_MODEL_DIR', DEFAULT_MODEL_DIR)
logger.info('MODEL_DIR: %s', MODEL_DIR)
# Read the synset file: the class-index -> label lookup that
# predict_from_image uses to name each predicted class. The file must contain
# a single Python literal. It is parsed with ast.literal_eval, which accepts
# only literals and never executes code. The encoding is pinned to UTF-8 so
# labels decode the same way regardless of the host's locale.
LABELS = path.join(MODEL_DIR, "synset.txt")
with open(LABELS, "r", encoding="utf-8") as f:
    synset = literal_eval(f.read())
def load_model(model_dir):
    """Build and return a DLRModel from the compiled artifacts in *model_dir*.

    The model targets the CPU device and is created with
    ``use_default_dlr=False``. A debug message is logged only after
    construction succeeds.
    """
    dlr_model = DLRModel(model_dir, use_default_dlr=False, dev_type='cpu')
    logger.debug('MODEL was loaded')
    return dlr_model
def predict_from_image(model, image_data):
    """Run inference on one image and return its best-scoring labels.

    Calls ``model.run(image_data)`` and ranks the class scores found in the
    first output's first row from highest to lowest. Only the top
    ``MAX_NO_OF_RESULTS`` classes are examined. Of those, only classes whose
    score is at least ``SCORE_THRESHOLD`` are kept.

    Args:
        model: Loaded model exposing ``run(input)``. In this module it is the
            ``DLRModel`` returned by ``load_model``.
        image_data: Passed unchanged to ``model.run``. In this module it is
            the image resized by ``handler``. NOTE(review): confirm the
            shape/dtype the compiled model expects.

    Returns:
        A list of ``{"Label": str, "Score": str}`` dicts in descending score
        order, possibly empty. Returns ``None`` if any step raised; the error
        is logged together with its traceback.
    """
    try:
        # Run DLR to perform inference with the DLC-optimized model.
        model_output = model.run(image_data)
        # First output, first row: assumes a batch containing one image.
        probabilities = model_output[0][0]
        # argsort is ascending; reverse it so the best class comes first.
        ranked_classes = argsort(probabilities)[::-1]
        # Filter instead of stopping at the first low score. NumPy sorts NaN
        # to the end, so after reversal a NaN score comes first and must not
        # hide the valid classes behind it.
        result = [
            {"Label": str(synset[i]), "Score": str(probabilities[i])}
            for i in ranked_classes[:MAX_NO_OF_RESULTS]
            if probabilities[i] >= SCORE_THRESHOLD
        ]
    except Exception as e:
        # Best-effort contract: callers receive None rather than an exception.
        # logger.exception also records the traceback, which the bare
        # message alone would lose.
        logger.exception("Exception occurred during prediction: %s", e)
        return None
    logger.debug("result: %s", result)
    return result
# Load the model once when the module is imported. handler() reads this
# module-level instance on every invocation instead of reloading it.
model = load_model(model_dir=MODEL_DIR)
def handler(event, context):
    """Classify the image carried in ``event['body']``.

    The image is resized to ``SHAPE`` with ``cv2.resize``. The result is then
    passed to ``predict_from_image`` together with the module-level
    ``model``.

    Args:
        event: Invocation payload. ``event['body']`` must hold the image.
            NOTE(review): cv2.resize needs a decoded image array, not raw or
            encoded bytes; confirm what the caller actually sends.
        context: Runtime context object; not used here.

    Returns:
        One of three responses:

        - ``{'statusCode': 200, 'body': predictions}`` on success, where
          ``predictions`` is the list returned by ``predict_from_image``.
        - ``{'statusCode': 400, 'body': None}`` when the event carries no
          image.
        - ``{'statusCode': 500, 'body': None}`` when inference failed. The
          cause has already been logged by ``predict_from_image``.

    Raises:
        KeyError: If ``event`` has no ``'body'`` key.
    """
    logger.debug('event: %s', event)
    image_data = event['body']
    # cv2.resize raises on a None input instead of returning None, so the
    # missing-image check must happen before resizing, not after.
    cvimage = None if image_data is None else resize(image_data, SHAPE)
    if cvimage is None:
        # Return an error response instead of calling exit(). exit() would
        # end the whole process, including the model loaded at import time,
        # because of a single bad request.
        logger.error("No image data found in the event body")
        return {'statusCode': 400, 'body': None}
    result = predict_from_image(model, cvimage)
    logger.debug('result: %s', result)
    if result is None:
        # predict_from_image returns None only when inference raised.
        return {'statusCode': 500, 'body': None}
    return {'statusCode': 200, 'body': result}