Permalink
Browse files

Added ability to save contexts/configs by name

  • Loading branch information...
1 parent b4e0e45 commit 8b3f1c76c326c8c64e35d0cdef50dd351a1980b7 @kvh committed Nov 1, 2012
Showing with 55 additions and 35 deletions.
  1. +1 −0 .gitignore
  2. +15 −3 ramp/context.py
  3. +22 −31 ramp/models.py
  4. +17 −1 ramp/tests/test_configuration.py
View
@@ -1,4 +1,5 @@
*.py[cod]
+*.sw[no]
# C extensions
*.so
View
@@ -5,15 +5,27 @@
class DataContext(object):
- def __init__(self, store, data, train_index=None, prep_index=None):
+ def __init__(self, store, data=None, train_index=None, prep_index=None):
self.store = store
self.data = data
- self.train_index = train_index if train_index is not None else self.data.index
- self.prep_index = prep_index if prep_index is not None else self.data.index
+ self.train_index = train_index if train_index is not None else self.data.index if self.data is not None else None
+ self.prep_index = prep_index if prep_index is not None else self.data.index if self.data is not None else None
def copy(self):
return copy.copy(self)
def create_key(self):
return md5('%s--%s' % (get_np_hashable(self.train_index),
get_np_hashable(self.prep_index))).hexdigest()
+
+ def save_context(self, name, config=None):
+ ctx = {'train_index':self.train_index,
+ 'prep_index':self.prep_index,
+ 'config':config}
+ self.store.save('context__%s' % name, ctx)
+
+ def load_context(self, name):
+ ctx = self.store.load('context__%s' % name)
+ self.train_index = ctx['train_index']
+ self.prep_index = ctx['prep_index']
+ return ctx['config']
View
@@ -14,13 +14,17 @@
hard to implement (so many variables/data to key on, are they all really immutable?)
"""
-def get_xy(config, context):
+def get_x(config, context):
x = build_featureset(config.features, context)
- y = build_target(config.target, context)
-
if config.column_subset:
x = x[config.column_subset]
- return x, y
+ return x
+
+def get_y(config, context):
+ return build_target(config.target, context)
+
+def get_xy(config, context):
+ return get_x(config, context), get_y(config, context)
def get_key(config, context):
@@ -42,7 +46,7 @@ def fit(config, context):
if debug:
print train_x
if debug:
- print "Fitting model '%s'." % (config.model.__name__)
+ print "Fitting model '%s'." % (config.model)
config.model.fit(train_x.values, train_y.values)
context.store.save(get_key(config, context), config.model)
@@ -56,6 +60,8 @@ def predict(config, context, predict_index, fit_model=True):
if (context.train_index & predict_index):
print "WARNING: train and predict indices overlap..."
+ x, y = None, None
+
if fit_model:
x, y = fit(config, context)
@@ -64,19 +70,17 @@ def predict(config, context, predict_index, fit_model=True):
# rebuild just the necessary x:
ctx = context.copy()
ctx.data = context.data.ix[predict_index]
- x, y = get_xy(config, ctx)
-
- # ensure correct columns exist:
-# for col in columns_used:
-# if col not in x.columns:
-# print "WARNING: filling missing column '%s' with zeros" % col
-# x[col] = Series(np.random.randn(len(x)) / 100, index=x.index)
-# symdif = set(x.columns) ^ set(columns_used)
-# if symdif:
-# print symdif
-# raise Exception("mismatched columns between fit and predict.")
- # re-order columns
-# x = x.reindex(columns=columns_used)
+ x = get_x(config, ctx)
+ try:
+ # we may or may not have y's in predict context
+ # we get them if we can for metrics and reporting
+ y = get_y(config, ctx)
+ except KeyError:
+ pass
+
+ if debug:
+ print x.columns
+ print config.model.coef_
predict_x = x.reindex(predict_index)
@@ -116,16 +120,3 @@ def print_scores(scores):
scores.mean(), scores.std(), min(scores),
max(scores))
-
-def build_model(config, context, name=None):
- models.fit(config, context)
- context.store.save('model__%s' % name, get_key(config, context))
-
-
-def get_or_build_model(config, context, name):
- try:
- key = context.store.load('model__%s' % name)
- config.model = context.store.load(key)
- except KeyError:
- build_model(config, context, name)
-
@@ -2,18 +2,20 @@
sys.path.append('../..')
from ramp.configuration import *
from ramp.features import *
+from ramp.features.base import *
from ramp.models import *
from ramp.metrics import *
import unittest
import pandas
from sklearn import linear_model
import numpy as np
-import os, sys, random
+import os, sys, random, pickle
from pandas.util.testing import assert_almost_equal
class ConfigurationTest(unittest.TestCase):
+
def test_config_factory(self):
base = Configuration(
features=['a'],
@@ -33,6 +35,20 @@ def test_config_factory(self):
cnfs = [cnf for cnf in fact]
self.assertEqual(len(cnfs), 4)
+ def test_configuration_pickle(self):
+ c = Configuration(
+ features=['a', F('a'), Map('a', len)],
+ model=linear_model.LogisticRegression()
+ )
+ s = pickle.dumps(c)
+ c2 = pickle.loads(s)
+ self.assertEqual(repr(c), repr(c2))
+
+ # lambdas are not picklable, should fail
+ c = Configuration(
+ features=['a', F('a'), Map('a', lambda x: len(x))],
+ )
+ self.assertRaises(pickle.PicklingError, pickle.dumps, c)
if __name__ == '__main__':
unittest.main()

0 comments on commit 8b3f1c7

Please sign in to comment.