Skip to content
This repository has been archived by the owner on Mar 1, 2018. It is now read-only.

Commit

Permalink
Fixing load supervised model
Browse files Browse the repository at this point in the history
PEP8 over the changes
  • Loading branch information
silviodc committed Aug 12, 2017
1 parent 9ce573f commit 4060063
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import unicodedata
import tempfile
import shutil
from io import BytesIO
from urllib.request import urlopen
Expand All @@ -16,8 +15,7 @@
from PIL import Image as pil_image
from sklearn.base import TransformerMixin
from wand.image import Image
from PIL import Image as pil_image
from io import BytesIO


class MealGeneralizationClassifier(TransformerMixin):
"""
Expand All @@ -44,7 +42,6 @@ class MealGeneralizationClassifier(TransformerMixin):
# Dimensions of our images.
img_width, img_height = 300, 300


# It defines how many iterations will run to find the best model during traaining
epochs = 20
# It influences the speed of your learning (Execution)
Expand Down Expand Up @@ -93,6 +90,7 @@ def train(self, train_data_dir, validation_data_dir, save_dir):
# Inspired by biological visual cortex and tailored for computer vision tasks.
# Authour: Yann LeCun in early 1990s.
# See http://deeplearning.net/tutorial/lenet.html for introduction.
# Or this simplified version: https://www.youtube.com/watch?v=JiN9p5vWHDY

# This is the augmentation configuration we will use for training
model.compile(loss='binary_crossentropy',
Expand Down Expand Up @@ -124,7 +122,7 @@ def train(self, train_data_dir, validation_data_dir, save_dir):

# It allow us to save only the best model between the iterations
checkpointer = ModelCheckpoint(
filepath=os.path.join(save_dir,"weights.hdf5"),
filepath=os.path.join(save_dir, "weights.hdf5"),
verbose=1, save_best_only=True)

# We set it as a parameter to save only the best model
Expand All @@ -138,7 +136,7 @@ def train(self, train_data_dir, validation_data_dir, save_dir):

def fit(self, X):
# Load an existent Keras model
if (not os.path.isfile(X) and (isinstance(X, str) and ( 'https' in X or 'http' in X))):
if (not os.path.isfile(X) and (isinstance(X, str) and ('https' in X or 'http' in X))):
response = urlopen(X)
with open('weights.hdf5', 'wb') as fp:
shutil.copyfileobj(response, fp)
Expand Down Expand Up @@ -174,13 +172,13 @@ def predict(self, X):
preds = self.keras_model.predict_classes(x, verbose=0)
# Get the probability of prediciton
prob = self.keras_model.predict_proba(x, verbose=0)
# Keep the predictions with more than 80% of accuracy and the class 1 (suspicious)
# Keep the predictions = (suspicious)
if(prob >= 0.8 and preds == 1):
result.append(True)
else:
result.append(False)
else:
# Case the reimbursement can not be downloaded or convert it is classified as False
# Case the reimbursement can not be convereted to png
result.append(False)

self._X['y'] = result
Expand All @@ -189,7 +187,8 @@ def predict(self, X):
def __applicable_rows(self, X):
return (X['category'] == 'Meal')

""" Creates a new column 'links' containing an url for the files in the chamber of deputies website
""" Creates a new column 'links' containing an url
for the files in the chamber of deputies website
Return updated Dataframe
arguments:
Expand Down
4 changes: 2 additions & 2 deletions rosie/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def __call__(self):
def load_trained_model(self, classifier):
filename = '{}.pkl'.format(classifier.__name__.lower())
path = os.path.join(self.data_path, filename)

keys = self.settings.SUPERVISED_MODEL
# palliative: this outputs a model too large for joblib
if classifier.__name__ == 'MonthlySubquotaLimitClassifier':
model = classifier()
model.fit(self.dataset)
elif classifier.__name__ in self.settings.SUPERVISED_MODEL:
elif (keys is not None and classifier.__name__ in keys):
model = classifier()
model.fit(self.settings.SUPERVISED_MODEL[classifier.__name__])
else:
Expand Down
3 changes: 2 additions & 1 deletion rosie/core/tests/test_core_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_load_trained_model_for_meal_generalization(self):
core = Core(settings, self.adapter)
core.load_trained_model(ClassifierClass)

classifier_instance.fit.assert_called_once_with(settings.SUPERVISED_MODEL['MealGeneralizationClassifier'])
classifier_instance.fit.assert_called_once_with(
settings.SUPERVISED_MODEL['MealGeneralizationClassifier'])

def test_predict(self):
model = MagicMock()
Expand Down

0 comments on commit 4060063

Please sign in to comment.