Skip to content

Commit

Permalink
Refactoring and addition of DataManager
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed May 5, 2018
1 parent b1dc75c commit 2c83ab8
Show file tree
Hide file tree
Showing 86 changed files with 194 additions and 30 deletions.
11 changes: 0 additions & 11 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ __pycache__/
*.py[cod]
*$py.class
*.sqlite
*.gz
*.csv
*.pkl
output.txt

Expand Down Expand Up @@ -101,12 +99,3 @@ ENV/

# Rope project settings
.ropeproject
/raw_data_persist.sqlite
/imgs/
/node_modules/
!/static/js/bundle.js
/static/js/bundle.js
/opplett/static/js/bundle.js
opplett/static/js/
!/opplett/static/js/
/opplett/data/data.csv
1 change: 0 additions & 1 deletion hemlock_highway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@

__version__ = "0.0.1dev"
2 changes: 1 addition & 1 deletion hemlock_highway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
sys.path.append(PATH)

from hemlock_highway.server import app
from hemlock_highway.config import Config
from hemlock_highway.server.config import Config

config = Config()

Expand Down
6 changes: 6 additions & 0 deletions hemlock_highway/data_manager/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
## Data Manager

Logic to handle the pre-processing of uploaded data

The resulting object, which parses raw data into a format suitable for a
`Pipeline` object, becomes part of the end model's pre-processing step.
1 change: 1 addition & 0 deletions hemlock_highway/data_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .manager import DataManager
51 changes: 51 additions & 0 deletions hemlock_highway/data_manager/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

import io
import boto3
import pandas as pd
from hemlock_highway.settings import PROJECT_CONFIG


class DataManager:
    """
    Fetch a CSV dataset from S3 or an http(s) endpoint and split it into
    a feature frame (``X``) and a target series (``y``).

    Nothing is fetched on construction; call ``load()`` to populate ``X`` and ``y``.
    """

    # Empty class-level placeholders, kept so ``DataManager.X`` / ``DataManager.y``
    # still resolve. Each instance receives its own objects in ``__init__``.
    # An explicit dtype avoids the "default dtype for empty Series" warning.
    X, y = pd.DataFrame(), pd.Series(dtype=object)

    def __new__(cls, *args, **kwargs):
        # The client is (re)bound on the class every time an instance is created.
        # NOTE(review): presumably so a client built under an active S3 mock is
        # the one that gets used - confirm before moving this to module level.
        cls.s3_client = boto3.client('s3', region_name=PROJECT_CONFIG.AWS_REGION)
        return super().__new__(cls)

    def __init__(self, data_endpoint: str, target_column: str, **read_args):
        """
        Record where the dataset lives and how to read it; no data is fetched here.

        Parameters
        ----------
        data_endpoint: str - Either an s3 location ('bucket/key' or 's3://bucket/key')
                             or an http(s) url
        target_column: str - After loading the dataset, this is designated as the target column
        read_args: dict    - Any additional pandas.read_csv() kwargs
        """
        self.data_endpoint = data_endpoint
        self.target_column = target_column
        self.read_args = read_args

        # Per-instance placeholders so instances never share the class-level objects.
        self.X = pd.DataFrame()
        self.y = pd.Series(dtype=object)

    @staticmethod
    def _split_s3_endpoint(endpoint: str):
        """
        Split an s3 endpoint into (bucket, key); an optional 's3://' prefix is ignored.
        """
        if endpoint.startswith('s3://'):
            endpoint = endpoint[len('s3://'):]
        bucket, _, key = endpoint.partition('/')
        return bucket, key

    def load(self):
        """
        Execute the load from either http or s3, then split the target column off into ``y``.

        Raises
        ------
        IOError  - S3 answered with a non-200 status code
        KeyError - ``target_column`` is not a column of the loaded dataset
        """
        if self.data_endpoint.startswith(('http://', 'https://')):
            data = pd.read_csv(filepath_or_buffer=self.data_endpoint, **self.read_args)
        else:
            bucket, key = self._split_s3_endpoint(self.data_endpoint)
            resp = self.s3_client.get_object(Bucket=bucket, Key=key)
            if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
                raise IOError(f'Error fetching dataset from S3: {resp}')
            data = pd.read_csv(io.BytesIO(resp['Body'].read()), **self.read_args)

        # Validate before touching self.X / self.y so a bad target column
        # does not leave the manager half-loaded.
        if self.target_column not in data.columns:
            raise KeyError(
                f'Target column {self.target_column!r} not found in dataset columns: {list(data.columns)}'
            )

        self.y = data[self.target_column]
        self.X = data.drop(columns=self.target_column)

    @property
    def _loaded(self):
        """True once ``X`` holds at least one row."""
        return self.X.shape[0] > 0 if hasattr(self.X, 'shape') else False
2 changes: 2 additions & 0 deletions hemlock_highway/ml/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This package contains models suitable for use
by the `model_runner` and `server` resources.
10 changes: 6 additions & 4 deletions hemlock_highway/ml/models/abc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import zlib
import pickle

from hemlock_highway.settings import PROJECT_CONFIG


class AbcHemlockModel:

def __new__(cls, *args, **kwargs):
cls.client = boto3.client('s3', region_name='us-east-1')
cls.s3_client = boto3.client('s3', region_name=PROJECT_CONFIG.AWS_REGION)
return super().__new__(cls)

@staticmethod
Expand All @@ -29,8 +31,8 @@ def dump(self, bucket: str, key: str, name: str):
Dump a model to s3
"""
model_out = zlib.compress(pickle.dumps(self))
self.client.create_bucket(Bucket=bucket)
resp = self.client.put_object(Bucket=bucket, Key=f'{key}/{name}', Body=model_out)
self.s3_client.create_bucket(Bucket=bucket)
resp = self.s3_client.put_object(Bucket=bucket, Key=f'{key}/{name}', Body=model_out)
if resp['ResponseMetadata']['HTTPStatusCode'] == 200:
return True
else:
Expand All @@ -42,7 +44,7 @@ def load(cls, bucket: str, key: str, name: str):
"""
Load a model from S3
"""
model = cls().client.get_object(Bucket=bucket, Key=f'{key}/{name}')['Body'].read()
model = cls().s3_client.get_object(Bucket=bucket, Key=f'{key}/{name}')['Body'].read()
model = pickle.loads(zlib.decompress(model))
return model

Expand Down
3 changes: 3 additions & 0 deletions hemlock_highway/model_runner/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Model Runner

API service to take requests to train a model or make predictions
File renamed without changes.
1 change: 1 addition & 0 deletions hemlock_highway/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .server import app
1 change: 1 addition & 0 deletions hemlock_highway/server/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .v1 import api_v1_blueprint
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 7 additions & 7 deletions hemlock_highway/server.py → hemlock_highway/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from flask import Flask, request

from hemlock_highway.config import Config
from hemlock_highway.server.config import Config

# blueprints
from hemlock_highway.api.v1 import api_v1_blueprint
from hemlock_highway.ui import ui_blueprint
from hemlock_highway.user_mgmt import google_auth_blueprint
from hemlock_highway.server.api import api_v1_blueprint
from hemlock_highway.server.ui import ui_blueprint
from hemlock_highway.server.user_mgmt import google_auth_blueprint


app = Flask(__name__)
Expand All @@ -19,6 +19,6 @@
app.register_blueprint(google_auth_blueprint, url_prefix='/google-login')


@app.route('/echo')
def echo():
return request.args.get('word')
@app.route('/health-check')
def health_check():
return 'ahola', 200
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import redirect, url_for, current_app, render_template
from flask.blueprints import Blueprint
from flask_dance.contrib.google import google
from hemlock_highway.user_mgmt.user import User
from hemlock_highway.server.user_mgmt.user import User

MODULE_PATH = os.path.dirname(os.path.abspath(__file__))

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-

from flask_dance.contrib.google import make_google_blueprint
from hemlock_highway.config import Config
from hemlock_highway.server.config import Config

config = Config()

Expand Down
File renamed without changes.
17 changes: 17 additions & 0 deletions hemlock_highway/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-

import os


class ProjectConfig:
    """
    Project-wide constants: repository directory layout and the AWS region.
    """

    # Directories - each one is derived from the repository root, which is
    # expressed as "<directory of this file>/.." (left un-normalised).
    REPO_ROOT_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
    )
    HEMLOCK_HIGHWAY_MODULE_ROOT_DIR = os.path.join(REPO_ROOT_DIR, 'hemlock_highway')
    TEST_ROOT_DIR = os.path.join(REPO_ROOT_DIR, 'tests')
    TEST_DATA_DIR = os.path.join(REPO_ROOT_DIR, 'tests', 'data')

    # Region handed to boto3 clients
    AWS_REGION = 'eu-west-1'


# Shared instance imported by the rest of the project
PROJECT_CONFIG = ProjectConfig()
8 changes: 8 additions & 0 deletions tests/data/basic_integer_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
col1,col2,target
1,2,3
2,3,5
3,4,7
4,5,9
5,6,11
6,7,13
7,8,15
Empty file.
58 changes: 58 additions & 0 deletions tests/test_data_manager/test_data_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-

import os
import unittest
import moto

import botocore.exceptions
from .utils import fake_data_on_s3



class DataManagerTestCase(unittest.TestCase):
    """
    Exercise DataManager loading from a mocked S3 bucket and from an http endpoint.
    """

    def setUp(self):

        # Test data directory
        self.DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')

    @fake_data_on_s3(local_dataset='basic_integer_data.csv', bucket='test', key='data/basic.csv')
    def test_basic_integer_data_load(self):
        """
        Test the basic loading of a dataset.
        """
        from hemlock_highway.data_manager import DataManager
        dm = DataManager(data_endpoint='test/data/basic.csv', target_column='target')
        self.assertFalse(dm._loaded, msg='DataManger should not load data on initialization! Reports it is loaded!')
        dm.load()
        self.assertTrue(dm._loaded, msg='After asking to load data, DataManager is reporting it is not loaded!')
        self.assertGreater(dm.X.shape[0], 0,
                           msg='DataManager reports it is loaded, but does not have any data in X!')

        # load() must move the target column out of X and into y, row for row.
        self.assertNotIn('target', dm.X.columns,
                         msg='Target column should have been removed from X after loading!')
        self.assertEqual(len(dm.y), dm.X.shape[0],
                         msg='X and y should hold the same number of rows after loading!')

    @moto.mock_s3
    def test_load_non_existant_data(self):
        """
        Loading a dataset which does not exist on S3 should surface botocore's ClientError.
        """
        from hemlock_highway.data_manager import DataManager
        dm = DataManager(data_endpoint='test/data/basic.csv', target_column='target')

        # Should raise an exception when trying to load a dataset that doesn't exist.
        with self.assertRaises(botocore.exceptions.ClientError):
            dm.load()

    def test_load_from_http(self):
        """
        Ensure dataloader can load a dataset via http

        NOTE(review): this fetches a public file over the network, so it needs
        connectivity to pass.
        """
        from hemlock_highway.data_manager import DataManager
        dm = DataManager(
            data_endpoint='https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv',
            target_column='species'
        )
        dm.load()
        self.assertIn('petal_length', dm.X.columns,
                      msg=f'Expected "petal_length" to be in X, but found {dm.X.columns}')


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/test_data_manager/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-

import os
import boto3
import moto
from contextlib import contextmanager
from hemlock_highway.settings import PROJECT_CONFIG


@contextmanager
def fake_data_on_s3(local_dataset, bucket, key):
    """
    Put a local test dataset on a mocked S3 for the duration of the context.

    Usable as a context manager or as a decorator on a test method.

    Parameters
    ----------
    local_dataset: str - File name inside PROJECT_CONFIG.TEST_DATA_DIR to upload
    bucket: str        - Bucket to create on the mocked S3
    key: str           - Key under which the dataset is stored
    """
    region = PROJECT_CONFIG.AWS_REGION

    with moto.mock_s3():
        s3 = boto3.client('s3', region_name=region)

        # S3 only accepts a create_bucket call without a LocationConstraint in
        # us-east-1; any other region has to be named explicitly.
        create_args = {'Bucket': bucket}
        if region != 'us-east-1':
            create_args['CreateBucketConfiguration'] = {'LocationConstraint': region}
        s3.create_bucket(**create_args)

        with open(os.path.join(PROJECT_CONFIG.TEST_DATA_DIR, local_dataset), 'rb') as f:
            s3.put_object(Bucket=bucket, Key=key, Body=f.read())

        yield
Empty file added tests/test_ml/__init__.py
Empty file.
File renamed without changes.
Empty file.
Empty file added tests/test_server/__init__.py
Empty file.
13 changes: 10 additions & 3 deletions tests/test_api_v1.py → tests/test_server/test_api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@ def setUp(self):
self.app = app.test_client()

def test_sanity(self):
resp = self.app.get('/echo?word=hello')
self.assertTrue(b'hello' in resp.data)
resp = self.app.get('/health-check')
self.assertTrue(resp.status_code == 200)

def test_available_models(self):
resp = self.app.get('/api/v1/available-models')
self.assertTrue(b'HemlockRandomForestClassifier' in resp.data)
models = json.loads(resp.data)

# Should return a list of strings, each the name of some valid model
self.assertTrue(isinstance(models, list))
self.assertTrue(isinstance(models[-1], str))

# Verify at least this known implemented model is in there.
self.assertTrue('HemlockRandomForestClassifier' in models)

@moto.mock_s3
def test_server_model_dump_load(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import unittest
from hemlock_highway.server import app
from tests.utils import fake_google_authenticated_user
from tests.test_server.utils import fake_google_authenticated_user


class UserMgmtTestCase(unittest.TestCase):
Expand Down
File renamed without changes.

0 comments on commit 2c83ab8

Please sign in to comment.