Skip to content

Commit

Permalink
Merge pull request #14 from dssg/yield_trained_models
Browse files Browse the repository at this point in the history
Allow yielding of trained models [Resolves #11]
  • Loading branch information
k1aus committed Mar 3, 2017
2 parents 9bddc72 + 7c7c717 commit a3ca23b
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 13 deletions.
8 changes: 8 additions & 0 deletions tests/test_model_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,11 @@ def test_model_trainer():
engine.execute('select * from results.feature_importances')
]
assert len(records) == 4 * 3 # maybe exclude entity_id?

# 7. that the generator interface works the same way
new_model_ids = trainer.generate_trained_models(
grid_config=grid_config,
misc_db_parameters=dict()
)
assert expected_model_ids == \
sorted([model_id for model_id in new_model_ids])
4 changes: 4 additions & 0 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,7 @@ def test_predictor():
join results.models using (model_id)''')
]
assert len(records) == 4

# 6. That we can delete the model when done predicting on it
predictor.delete_model(model_id)
assert predictor.load_model(model_id) == None
18 changes: 16 additions & 2 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from triage.storage import S3Store, FSStore
from triage.storage import S3Store, FSStore, MemoryStore
from moto import mock_s3
import boto3
import os
Expand All @@ -19,6 +19,8 @@ def test_S3Store():
assert store.exists()
newVal = store.load()
assert newVal.val == 'val'
store.delete()
assert not store.exists()


def test_FSStore():
Expand All @@ -30,4 +32,16 @@ def test_FSStore():
assert store.exists()
newVal = store.load()
assert newVal.val == 'val'
os.remove('tmpfile')
store.delete()
assert not store.exists()


def test_MemoryStore():
    """Exercise the MemoryStore lifecycle: empty, write, load, delete."""
    memory_store = MemoryStore(None)
    # Nothing has been written yet, so the store reports itself as empty.
    assert not memory_store.exists()

    memory_store.write(SomeClass('val'))
    assert memory_store.exists()

    # The loaded object carries the value that was written.
    loaded = memory_store.load()
    assert loaded.val == 'val'

    # Deleting puts the store back into its empty state.
    memory_store.delete()
    assert not memory_store.exists()
43 changes: 35 additions & 8 deletions triage/model_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ def _get_model_group_id(
logging.debug('Model_group_id = {}'.format(model_group_id))
return model_group_id

def train_models(
def generate_trained_models(
self,
grid_config,
misc_db_parameters,
replace=False
):
"""Train and store configured models
"""Train and store configured models, yielding the ids one by one
Args:
grid_config (dict) of format {classpath: hyperparameter dicts}
Expand All @@ -286,10 +286,8 @@ def train_models(
misc_db_parameters (dict) params to pass through to the database
replace (optional, False): whether to replace already cached models
Returns:
(list) of model ids
Yields: (int) model ids
"""
model_ids = []
misc_db_parameters = copy.deepcopy(misc_db_parameters)
misc_db_parameters['batch_run_time'] = datetime.datetime.now().isoformat()
for class_path, parameters in self._generate_model_configs(grid_config):
Expand All @@ -304,7 +302,7 @@ def train_models(
model_store,
misc_db_parameters
)
model_ids.append(model_id)
yield model_id
else:
logging.info('Skipping %s/%s', class_path, parameters)
session = self.sessionmaker()
Expand All @@ -327,6 +325,35 @@ def train_models(
)
else:
model_id = saved.model_id
model_ids.append(model_id)
yield model_id

def train_models(
    self,
    grid_config,
    misc_db_parameters,
    replace=False
):
    """Train and store configured models, returning all ids at once

    Eager counterpart of ``generate_trained_models``: the generator is
    consumed completely before this method returns.

    Args:
        grid_config (dict) of format {classpath: hyperparameter dicts}
            example: { 'sklearn.ensemble.RandomForestClassifier': {
                'n_estimators': [1,10,100,1000,10000],
                'max_depth': [1,5,10,20,50,100],
                'max_features': ['sqrt','log2'],
                'min_samples_split': [2,5,10]
            } }
        misc_db_parameters (dict) params to pass through to the database
        replace (optional, False): whether to replace already cached models

    Returns:
        (list) of model ids
    """
    # list() drains the generator directly; an identity comprehension
    # would only re-bind every id before appending it.
    return list(self.generate_trained_models(
        grid_config,
        misc_db_parameters,
        replace
    ))

return model_ids
23 changes: 20 additions & 3 deletions triage/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import logging


class ModelNotFoundError(ValueError):
    """Raised when the model for a requested model id cannot be loaded."""


class Predictor(object):
def __init__(self, project_path, model_storage_engine, db_engine):
"""Encapsulates the task of generating predictions on an arbitrary
Expand All @@ -20,7 +24,7 @@ def __init__(self, project_path, model_storage_engine, db_engine):
if self.db_engine:
self.sessionmaker = sessionmaker(bind=self.db_engine)

def _load_model(self, model_id):
def load_model(self, model_id):
"""Downloads the cached model associated with a given model id
Args:
Expand All @@ -31,7 +35,18 @@ def _load_model(self, model_id):
"""
model_hash = self.sessionmaker().query(Model).get(model_id).model_hash
model_store = self.model_storage_engine.get_store(model_hash)
return model_store.load()
if model_store.exists():
return model_store.load()

def delete_model(self, model_id):
    """Delete the cached model that belongs to a given model id

    The model's hash is looked up in the database and used to obtain the
    matching store from the model storage engine; that store is then deleted.

    Args:
        model_id (int) The id of a given model in the database
    """
    session = self.sessionmaker()
    model_record = session.query(Model).get(model_id)
    model_hash = model_record.model_hash
    store_for_model = self.model_storage_engine.get_store(model_hash)
    store_for_model.delete()

def _write_to_db(self, model_id, as_of_date, entity_ids, predictions, labels, misc_db_parameters):
"""Writes given predictions to database
Expand Down Expand Up @@ -84,7 +99,9 @@ def predict(self, model_id, matrix_store, misc_db_parameters):
Returns:
(numpy.Array) the generated prediction values
"""
model = self._load_model(model_id)
model = self.load_model(model_id)
if not model:
raise ModelNotFoundError('Model id {} not found'.format(model_id))
labels = matrix_store.labels()
as_of_date = matrix_store.metadata['end_time']
predictions = model.predict(matrix_store.matrix)
Expand Down
9 changes: 9 additions & 0 deletions triage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def write(self, obj):
def load(self):
return download_object(self.path)

def delete(self):
    """Delete the stored object by calling ``delete()`` on ``self.path``.

    NOTE(review): ``self.path`` is presumably a remote (S3) object handle
    rather than a plain string -- confirm against the constructor, which
    is not visible here.
    """
    target = self.path
    target.delete()


class FSStore(Store):
def exists(self):
Expand All @@ -43,6 +46,9 @@ def load(self):
with open(self.path, 'rb') as f:
return pickle.load(f)

def delete(self):
    """Remove the file located at ``self.path`` from the filesystem.

    Any error raised by ``os.remove`` propagates to the caller.
    """
    file_path = self.path
    os.remove(file_path)


class MemoryStore(Store):
store = None
Expand All @@ -56,6 +62,9 @@ def write(self, obj):
def load(self):
return self.store

def delete(self):
    """Forget the in-memory object by resetting ``self.store`` to None."""
    self.store = None


class ModelStorageEngine(object):
def __init__(self, project_path):
Expand Down

0 comments on commit a3ca23b

Please sign in to comment.