Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross Validation Added #407

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
41 changes: 39 additions & 2 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.model_selection import cross_val_predict, cross_val_score
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line appears twice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extra line is unnecessary.

from ast import literal_eval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be required.


op = os.path
MLCLASS_MODULES = [
Expand All @@ -40,6 +44,8 @@
'nums': [],
'cats': [],
'target_col': None,
'CV': True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it lowercase.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to support three cases for the cv option:

  • If the user sets cv: false - then no cross validation happens
  • If the user sets cv: 4 (or some other integer) pass it straight to cross_val_score
  • The default should be cv: None, and in this case, the user should not have to write anything in gramex.yaml

'CVargs': []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have a single argument, cv, which can take any value, i.e in gramex.yaml, users should be able to write any of the following.

cv: false   # disable cross validation
cv: 5        # Use 5 folds
cv:
  cv: 8   # Use 8 folds
  n_jobs: -1  # with an optional other parameter.

}
ACTIONS = ['predict', 'score', 'append', 'train', 'retrain']
DEFAULT_TEMPLATE = op.join(op.dirname(__file__), '..', 'apps', 'mlhandler', 'template.html')
Expand Down Expand Up @@ -103,7 +109,6 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):

cls.set_opt('class', model.get('class'))
cls.set_opt('params', model.get('params', {}))

if op.exists(cls.model_path): # If the pkl exists, load it
cls.model = joblib.load(cls.model_path)
elif data is not None:
Expand All @@ -112,14 +117,38 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):
data = cls._filtercols(data)
data = cls._filterrows(data)
cls.model = cls._assemble_pipeline(data, mclass=mclass, params=params)

# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
print('yayyy we are here')
cls.CrossValidation(train,target)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it lowercase.

print('should have printed')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the prints.

gramex.service.threadpool.submit(
_fit, cls.model, train, target, cls.model_path, cls.name)
cls.config_store.flush()

@classmethod
def modelFunction(cls, mclass = ''):
model_kwargs = cls.config_store.load('model', {})
mclass = model_kwargs.get('class', False)
if mclass:
model = search_modelclass(mclass)(**model_kwargs.get('params', {}))
return model
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not required.


@classmethod
def CrossValidation(cls,train,target):
mod = cls.modelFunction()
CV = cls.get_opt('CV') #can edit to make CV true/false etc.
if CV:
CVargs = cls.get_opt('CVargs')
if CVargs:
CVscore = cross_val_score(mod, X=train, y=target, **literal_eval(json.dumps(CVargs)))
else:
CVscore = cross_val_score(mod, train, target)
CV = sum(CVscore)/len(CVscore)
print('CV score: ', CV)

@classmethod
def load_data(cls, default=pd.DataFrame()):
try:
Expand Down Expand Up @@ -268,6 +297,10 @@ def _predict(self, data=None, score_col=''):
self.model = cache.open(self.model_path, joblib.load)
try:
target = data.pop(score_col)
metric = self.get_argument('_metric', False)
if metric:
scorer = get_scorer(metric)
return scorer(self.model, data, target)
return self.model.score(data, target)
except KeyError:
# Set data in the same order as the transformer requests
Expand Down Expand Up @@ -347,6 +380,8 @@ def _train(self, data=None):
target = data[target_col]
train = data[[c for c in data if c != target_col]]
self.model = self._assemble_pipeline(data, force=True)
print('IN TRAIN')
self.CrossValidation(train,target)
_fit(self.model, train, target, self.model_path)
return {'score': self.model.score(train, target)}

Expand All @@ -357,6 +392,8 @@ def _score(self):
self._check_model_path()
data = self._parse_data(False)
target_col = self.get_argument('target_col', self.get_opt('target_col'))
print('IN _SCORE')
#self.CrossValidation(data,target_col)
self.set_opt('target_col', target_col)
return {'score': self._predict(data, target_col)}

Expand Down
5 changes: 5 additions & 0 deletions tests/test_mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def test_get_bulk_score(self):
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)
resp = self.get(
'/mlhandler?_action=score&_metric=f1_weighted', method='post',
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)

def test_get_cache(self):
df = pd.DataFrame.from_records(self.get('/mlhandler?_cache=true').json())
Expand Down