-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
105 lines (80 loc) · 2.62 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import os

import pandas as pd
from flask import Flask, request, jsonify
from peewee import (
    SqliteDatabase, PostgresqlDatabase, Model, IntegerField,
    FloatField, TextField, IntegrityError
)
from playhouse.db_url import connect

# Custom imports
import utils.modelling as md
import data_processing.processing as pc
########################################
# Begin database stuff
#
# The connection string is read from the DATABASE_URL environment variable.
# When that variable is unset or empty, fall back to a local SQLite file
# named predictions.db.
_database_url = os.environ.get('DATABASE_URL') or 'sqlite:///predictions.db'
DB = connect(_database_url)
class Prediction(Model):
    """A single scored observation, plus its ground-truth label once known.

    Rows are created by the /should_search/ endpoint. ``true_class`` is
    filled in afterwards by the /search_result/ endpoint.
    """

    # Client-supplied identifier. It is unique, so saving a second row with
    # the same ID raises IntegrityError.
    observation_id = TextField(unique=True)
    # The request payload, stored as text.
    observation = TextField()
    # Model probability for class 1 (column 1 of predict_proba).
    proba = FloatField()
    # Class label predicted by the model.
    predict = IntegerField()
    # Reported ground-truth outcome; NULL until it is reported.
    true_class = IntegerField(null=True)

    class Meta:
        database = DB
# Create the predictions table at startup. safe=True skips creation when the
# table already exists, so restarting the service keeps existing rows.
DB.create_tables([Prediction], safe=True)
# End database stuff
########################################
# Load the model
# md.load_model() returns three objects:
# - the model pipeline used by the endpoints below;
# - the input column list used to build the request DataFrame;
# - dtypes (not referenced in the code visible here; presumably column
#   dtypes, TODO confirm).
(pipeline, columns, dtypes) = md.load_model()
########################################
# Begin webserver stuff
app = Flask(__name__)
@app.route('/should_search/', methods=['POST'])
def predict():
    """Score one observation, store the prediction and return the outcome.

    The JSON body must be an object that contains ``observation_id`` plus the
    feature fields listed in ``columns`` (from ``md.load_model()``). Keys not
    in ``columns`` are ignored when the one-row DataFrame is built.

    Returns:
        ``{"outcome": <bool>}`` on success.
        ``{"error": ...}`` with HTTP 400 when the body is not a JSON object or
        has no ``observation_id``.
        ``{"error": ...}`` (HTTP 200) when the ID has already been scored.
    """
    obs_dict = request.get_json()
    print("Req received: {}".format(obs_dict))
    # Without this guard, a missing body or ID surfaced as an unhandled
    # TypeError/KeyError (HTTP 500).
    if not isinstance(obs_dict, dict) or "observation_id" not in obs_dict:
        error_msg = "ERROR: request body must be a JSON object with an 'observation_id'"
        print(error_msg)
        return jsonify({"error": error_msg}), 400
    _id = obs_dict["observation_id"]

    # Build a one-row frame in the model's column order, then run the
    # project's feature-engineering helpers before scoring.
    obs = pd.DataFrame([obs_dict], columns=columns)
    obs_processed = pc.create_time_features(obs)
    obs_processed = pc.build_features(obs_processed)
    prediction = pipeline.predict(obs_processed)[0]
    proba = pipeline.predict_proba(obs_processed)[0, 1]

    # Store the observation as JSON. Passing the dict directly would make
    # TextField store its Python repr, which cannot be parsed back as JSON.
    p = Prediction(
        observation_id=_id,
        proba=proba,
        predict=prediction,
        observation=json.dumps(obs_dict),
    )
    try:
        # atomic() rolls back the failed INSERT itself, replacing the manual
        # DB.rollback() and keeping the connection usable for later requests.
        with DB.atomic():
            p.save()
    except IntegrityError as e:
        # observation_id is unique: this observation was already scored.
        error_msg = "ERROR: Observation ID: '{}' already exists".format(_id)
        print(e)
        return jsonify({"error": error_msg})

    response = {'outcome': bool(prediction)}
    print(response)
    return jsonify(response)
@app.route('/search_result/', methods=['POST'])
def update():
    """Record the true outcome for a previously scored observation.

    The JSON body must be an object containing ``observation_id`` and
    ``outcome``.

    Returns:
        ``{"observation_id", "predicted_outcome", "outcome"}`` on success.
        ``{"error": ...}`` with HTTP 400 when a required field is missing.
        ``{"error": ...}`` (HTTP 200) when the ID was never scored.
    """
    obs = request.get_json()
    print("Req received: {}".format(obs))
    # Without this guard, missing fields raised an unhandled KeyError/TypeError
    # (HTTP 500), including inside the DoesNotExist handler.
    if not isinstance(obs, dict) or 'observation_id' not in obs or 'outcome' not in obs:
        error_msg = "ERROR: request body must be a JSON object with 'observation_id' and 'outcome'"
        print(error_msg)
        return jsonify({'error': error_msg}), 400

    # Keep the try narrow so only the lookup's "not found" case is handled here.
    try:
        p = Prediction.get(Prediction.observation_id == obs['observation_id'])
    except Prediction.DoesNotExist:
        error_msg = 'Observation ID: "{}" does not exist'.format(obs['observation_id'])
        print(error_msg)
        return jsonify({'error': error_msg})

    p.true_class = obs['outcome']
    p.save()
    response = {
        "observation_id": p.observation_id,
        "predicted_outcome": bool(p.predict),
        "outcome": bool(p.true_class)
    }
    print(response)
    return jsonify(response)
if __name__ == "__main__":
    # debug=True turns on the Werkzeug interactive debugger, which lets anyone
    # who can reach the server run arbitrary Python. Enable it only when
    # explicitly requested via FLASK_DEBUG=1 (or "true").
    app.run(debug=os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true"))