Skip to content

Commit

Permalink
Add ModelRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed May 7, 2018
1 parent 08049d3 commit 7b2c629
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 22 deletions.
1 change: 1 addition & 0 deletions hemlock_highway/data_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def load(self, n_bytes=None):
# Separate X and y
self.y = self.X[self.target_column]
self.X = self.X[[col for col in self.X.columns if col != self.target_column]]
return self.X, self.y

def _load_from_http(self, n_bytes: int=None) -> io.BytesIO:
# Stream from source writing by chunk size of 1mb or n_bytes
Expand Down
4 changes: 2 additions & 2 deletions hemlock_highway/ml/models/RandomForest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestClassifier
from hemlock_highway.ml.models import AbcHemlockModel
from hemlock_highway.ml.models import HemlockModelBase


class HemlockRandomForestClassifier(RandomForestClassifier, AbcHemlockModel):
class HemlockRandomForestClassifier(RandomForestClassifier, HemlockModelBase):

@staticmethod
def configurable_parameters():
Expand Down
4 changes: 2 additions & 2 deletions hemlock_highway/ml/models/abc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hemlock_highway.data_manager import DataManager


class AbcHemlockModel:
class HemlockModelBase:

# Each model should have a DataManager to manage the handling of the IO/parsing of data for the model.
data_manager = DataManager('', '')
Expand Down Expand Up @@ -54,5 +54,5 @@ def load(cls, bucket: str, key: str):


__all__ = [
'AbcHemlockModel'
'HemlockModelBase'
]
1 change: 1 addition & 0 deletions hemlock_highway/model_runner/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .app import app
from .runner import ModelRunner
13 changes: 5 additions & 8 deletions hemlock_highway/model_runner/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from flask import Flask, request, jsonify
from flask.views import MethodView
from hemlock_highway.settings import PROJECT_CONFIG
from hemlock_highway.ml.models import AbcHemlockModel
from hemlock_highway.model_runner.runner import ModelRunner

app = Flask(__name__,
root_path=os.path.dirname(__file__))

S3_CLIENT = boto3.client('s3', region_name=PROJECT_CONFIG.AWS_REGION)


class ModelRunner(MethodView):
class ModelRunnerAPI(MethodView):

methods = ['POST']

Expand All @@ -24,17 +24,14 @@ def post(self):
model_location = request.form.get('model-location')
bucket = model_location.split('/')[0]
key = '/'.join(model_location.split('/')[1:])
model = AbcHemlockModel.load(bucket=bucket, key=key) # type: AbcHemlockModel
if model.data_manager is None or not model.data_manager.data_endpoint:
runner = ModelRunner(bucket=bucket, key=key) # type: ModelRunner
if runner.model.data_manager is None or not runner.model.data_manager.data_endpoint:
return jsonify({'success': False,
'message': f'Loaded model at {model_location} does not have a valid DataManager'})
model.data_manager.load()
X, y = model.data_manager.X, model.data_manager.y
model.fit(X, y)
return jsonify({'success': True, 'job-id': str(hashlib.md5(str(datetime.now()).encode('utf-8')))})


app.add_url_rule(rule='/train-model', view_func=ModelRunner.as_view('model_runner'))
app.add_url_rule(rule='/train-model', view_func=ModelRunnerAPI.as_view('model_runner'))


if __name__ == '__main__':
Expand Down
29 changes: 29 additions & 0 deletions hemlock_highway/model_runner/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-

from hemlock_highway.ml import HemlockModelBase


class ModelRunner:
"""
Responsible for running a model in either training/testing/prediction modes
"""

def __init__(self, model: HemlockModelBase=None, bucket: str=None, key: str=None):
"""
Load the model, but not the data using either directly passing model or pointing to s3 location
"""
if model is None and all((m is None for m in (bucket, key))):
raise ValueError('Must specify either a loaded model or an s3 location specifying both bucket and key!')

self.model = model if model is not None else HemlockModelBase.load(bucket, key)

def fit(self):
"""Fit the underlying model to the data from it's DataManager"""
X, y = self.model.data_manager.load()
self.model.fit(X, y)
return True

def predict(self):
"""Predict using model from DataManager data"""
X, y = self.model.data_manager.load()
return self.model.predict(X)
12 changes: 6 additions & 6 deletions hemlock_highway/server/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import inspect
from flask import Blueprint, jsonify, request, make_response, Response
from hemlock_highway.ml.models import AbcHemlockModel
from hemlock_highway.ml.models import HemlockModelBase
from hemlock_highway.ml import models
from typing import Union

Expand All @@ -17,7 +17,7 @@ def available_models():
"""
_models = [
model for model in filter(lambda m: m.startswith('Hemlock'), dir(models))
if issubclass(getattr(models, model), AbcHemlockModel)
if issubclass(getattr(models, model), HemlockModelBase)
]
return jsonify(_models)

Expand All @@ -31,10 +31,10 @@ def dump_model():

Model = get_model_by_name(model_name)

if inspect.isclass(Model) and issubclass(Model, AbcHemlockModel):
if inspect.isclass(Model) and issubclass(Model, HemlockModelBase):
# Initialize the model and dump it to the s3 bucket
# TODO: Parameterize the dumping location.
model = Model(**model_conf) # type: AbcHemlockModel
model = Model(**model_conf) # type: HemlockModelBase
model.dump(bucket='hemlock-highway-test', key='tests/model.pkl')
return jsonify({'success': True})
else:
Expand All @@ -48,13 +48,13 @@ def model_parameters():

Model = get_model_by_name(model_name)

if inspect.isclass(Model) and issubclass(Model, AbcHemlockModel):
if inspect.isclass(Model) and issubclass(Model, HemlockModelBase):
return jsonify({'success': True, 'parameters': Model.configurable_parameters()})
else:
return Model


def get_model_by_name(model_name: str) -> Union[AbcHemlockModel, Response]:
def get_model_by_name(model_name: str) -> Union[HemlockModelBase, Response]:

# If the name is None, we can't proceed
if model_name is None:
Expand Down
2 changes: 2 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ coveralls==1.3.0
Flask-Dance==0.14.0
SQLAlchemy==1.2.7
SQLAlchemy-Utils==0.33.3
fakeredis==0.10.2
redis==2.10.6
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ docker-pycreds==0.2.3 # via docker
docker==3.3.0 # via moto
docopt==0.6.2 # via coveralls
docutils==0.14 # via botocore
fakeredis==0.10.2
first==2.0.1 # via pip-tools
flask-dance==0.14.0
flask==1.0.1
Expand Down Expand Up @@ -49,6 +50,7 @@ pytest==3.5.1
python-dateutil==2.6.1 # via botocore, moto, pandas
pytz==2018.4 # via moto, pandas
pyyaml==3.12 # via pyaml
redis==2.10.6
requests-oauthlib==0.8.0 # via flask-dance
requests==2.18.4 # via aws-xray-sdk, coveralls, docker, flask-dance, moto, requests-oauthlib, responses
responses==0.9.0 # via moto
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_manager/test_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import moto

from hemlock_highway.data_manager import DataManager
from .utils import fake_data_on_s3
from tests.utils import fake_data_on_s3


class DataManagerTestCase(unittest.TestCase):
Expand Down
55 changes: 52 additions & 3 deletions tests/test_model_runner/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,73 @@
import json
import unittest
import responses
from sklearn.exceptions import NotFittedError
from hemlock_highway.data_manager import DataManager
from hemlock_highway.model_runner import app
from hemlock_highway.ml.models import HemlockRandomForestClassifier
from hemlock_highway.model_runner import app, ModelRunner
from hemlock_highway.ml.models import HemlockRandomForestClassifier, HemlockModelBase
import moto

from tests.utils import fake_data_on_s3

class ModelRunnerTestCase(unittest.TestCase):

def setUp(self):
app.testing = True
self.app = app.test_client()

@moto.mock_s3
def test_model_runner_loader(self):
"""
Model dumped to s3, should be able to be instantiated through the ModelRunner class
"""
# Test dumping model to s3 and then loading it back via ModelRunner
clf1 = HemlockRandomForestClassifier()
clf1.dump(bucket='test', key='mymodel.pkl')
clf2 = ModelRunner(bucket='test', key='mymodel.pkl')
self.assertIsInstance(clf2.model, HemlockModelBase)

with self.assertRaises(ValueError):
ModelRunner() # Fail if not passing an existing model or an s3 location

# Test passing model directly to ModelRunner
clf2 = ModelRunner(model=clf1)
self.assertIsInstance(clf2.model, HemlockModelBase)

@fake_data_on_s3(local_dataset='iris.csv', bucket='test', key='iris.csv')
def test_model_runner_process(self):
"""
Test core process of loading data, fitting and making predictions using underlying model
"""
# Define some model with it's data manager
clf = HemlockRandomForestClassifier()
clf.data_manager = DataManager(data_endpoint='test/iris.csv', target_column='species')
clf.data_manager.load()

# Pass model to ModelRunner
runner = ModelRunner(clf)

# Model isn't fitted, so it shouldn't be able to predict anything
with self.assertRaises(NotFittedError, msg="Model isn't fitted, so it shouldn't be able to predict anything!"):
runner.predict()

# Fit & predict, ensuring that the orignal, runner, and predicted data sizes match
runner.fit()
data = runner.predict()
original_size = clf.data_manager.X.shape[0]
runner_size = runner.model.data_manager.X.shape[0]
predicted_size = data.shape[0]
self.assertTrue(
original_size == runner_size == predicted_size,
f'Expected: '
f' original data ({original_size}) == runner data ({runner_size}) == predicted data ({predicted_size})'
)

@moto.mock_s3
def test_load_model(self):
"""
Test that a dumped model can be loaded from ModelRunner server
"""
responses.add_passthru('https://')
responses.add_passthru('https://') # mock_s3 breaks requests
clf = HemlockRandomForestClassifier()
clf.data_manager = DataManager(
data_endpoint='https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv',
Expand Down
File renamed without changes.

0 comments on commit 7b2c629

Please sign in to comment.