In [22]:
import numpy as np
from collections import defaultdict


feature_list = ['goout', 'school', 'schoolsup', 'failures', 'higher', 'subject', 'intercept']
    
class Model(object):

    def __init__(self, model_path='baseline_weights.npy'):
        self.load_weights(model_path)

    def load_weights(self, path):
        with open(path, 'rb') as f:
            model_weights = np.load(f)
        self.check_weights(model_weights)
        self.weights = model_weights
    
    def check_weights(self, weights):
        assert weights.shape == (3, len(feature_list))

    def predict(self, X):
        assert(isinstance(X, Student))
        prediction = np.dot(self.weights, X)
        return {'G1':prediction[0], 'G2':prediction[1], 'G3':prediction[2]}

class Student(list):
    
    def __init__(self, **kwargs):
        student_ = defaultdict(lambda : 0)
        for feature_name, feature_value in kwargs.items():
            assert(feature_name in feature_list)
            student_[feature_name] = float(feature_value[0])
        for feature_name in feature_list[:-1]:
            self.append(student_[feature_name])
        self.append(1)
    


In [23]:
from flask import Flask, request
from flask_restful import Resource, Api

from sqlalchemy import create_engine
from json import dumps
from flask.ext.jsonpify import jsonify

app = Flask(__name__)
api = Api(app)

model = Model()

class Predict(Resource):
    
    def __init__(self):
        print(request.args)
        self.student = Student(**request.args)

    def get(self, **kwargs):
        print(kwargs)
        result = model.predict(self.student)
        return jsonify(result)
        

class WeightUpdate(Resource):
    def get(self, weights):
        model.load_weights(weights)

api.add_resource(Predict, '/predict/') # Route_3

In [None]:
app.run(port='5002')
     

 * Running on http://127.0.0.1:5002/ (Press CTRL+C to quit)
127.0.0.1 - - [16/Jan/2018 23:48:33] "GET /predict/?goout=3&failures=4 HTTP/1.1" 200 -


ImmutableMultiDict([('goout', '3'), ('failures', '4')])
{}


127.0.0.1 - - [16/Jan/2018 23:48:40] "GET /predict/?goout=3&failures=3 HTTP/1.1" 200 -


ImmutableMultiDict([('goout', '3'), ('failures', '3')])
{}
