
How to save Scikit-Learn-Keras Model into a Persistence File (pickle/hd5/json/yaml) #4274

Closed
gundalav opened this issue Nov 3, 2016 · 17 comments



gundalav commented Nov 3, 2016

I have the following code, using Keras Scikit-Learn Wrapper:

from keras.models import Sequential
from sklearn import datasets
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn import preprocessing
import pickle
import numpy as np
import json

def classifier(X, y):
    """
    Description of classifier
    """
    NOF_ROW, NOF_COL =  X.shape

    def create_model():
        # create model
        model = Sequential()
        model.add(Dense(12, input_dim=NOF_COL, init='uniform', activation='relu'))
        model.add(Dense(6, init='uniform', activation='relu'))
        model.add(Dense(1, init='uniform', activation='sigmoid'))
        # Compile model
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    # fix the random seed for reproducibility
    seed = 7
    np.random.seed(seed)
    model = KerasClassifier(build_fn=create_model, nb_epoch=150, batch_size=10, verbose=0)
    return model
    

def main():
    """
    Description of main
    """

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    X = preprocessing.scale(X)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
    model_tt = classifier(X_train, y_train)
    model_tt.fit(X_train,y_train)

    #--------------------------------------------------
    # This fail
    #-------------------------------------------------- 
    filename = 'finalized_model.sav'
    pickle.dump(model_tt, open(filename, 'wb'))
    # load the model from disk
    loaded_model = pickle.load(open(filename, 'rb'))
    result = loaded_model.score(X_test, y_test)
    print(result)
    
    #--------------------------------------------------
    # This also fail
    #--------------------------------------------------
    # from keras.models import load_model       
    # model_tt.save('test_model.h5')
    

    #--------------------------------------------------
    # This works OK 
    #-------------------------------------------------- 
    # print(model_tt.score(X_test, y_test))
    # print(model_tt.predict_proba(X_test))
    # print(model_tt.predict(X_test))


if __name__ == '__main__':
    main()

As noted in the code above, it fails at this line:

pickle.dump(model_tt, open(filename, 'wb'))

With this error:

pickle.PicklingError: Can't pickle <function create_model at 0x101c09320>: it's not found as __main__.create_model

How can I get around it?

@krishnateja614

Can you try running model_tt.model.save("test_model.h5")? I don't think we can call save directly on the scikit-learn wrapper, but the line above should hopefully do what you want. Let me know.
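For completeness, a minimal sketch of the save-and-reload round trip (model_tt is the fitted wrapper from the question; note that load_model returns the plain Keras model, not the scikit-learn wrapper):

from keras.models import load_model

# Save the Keras model wrapped inside the KerasClassifier
model_tt.model.save('test_model.h5')

# Reload it as a plain Keras model and use it directly
restored = load_model('test_model.h5')
print(restored.predict(X_test))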


MaxPowerWasTaken commented Feb 9, 2017

Thanks @krishnateja614, this worked for me when I had the same issue. I wish this were the accepted answer on Stack Overflow.

@cbrummitt

If sklearn.model_selection.GridSearchCV is wrapped around a KerasClassifier or KerasRegressor, then that GridSearchCV object (call it gscv) cannot be pickled. Instead, it looks like we can only save the best estimator using:

gscv.best_estimator_.model.save('filename.h5')

Is there a way to save the whole GridSearchCV object?

I guess I could write a function save_grid_search_cv(model, filename) that

  • pickles everything in model.__dict__ that's not a KerasRegressor or KerasClassifier, and
  • calls save on all KerasRegressor and KerasClassifier objects,

and then write a corresponding load_grid_search_cv(filename) function. I'm comparing Keras models with sklearn models, so I'd like to save both kinds of models (in GridSearchCV objects) using one function. A rough sketch of the save side is below.
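A minimal, untested sketch of that idea, with assumptions: each Keras wrapper sits at the top level of model.__dict__, fitted wrappers carry a .model attribute, and the derived file names are made up for illustration:

import pickle

from keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor

def save_grid_search_cv(model, filename):
    # Split the GridSearchCV state into picklable attributes and Keras wrappers.
    picklable, keras_attrs = {}, {}
    for key, value in model.__dict__.items():
        if isinstance(value, (KerasClassifier, KerasRegressor)):
            if hasattr(value, 'model'):
                # Persist the wrapped Keras model to its own HDF5 file.
                value.model.save('{}.{}.h5'.format(filename, key))
            # Keep only the wrapper's constructor kwargs; build_fn itself
            # is what pickle chokes on, so it is deliberately dropped.
            keras_attrs[key] = value.sk_params
        else:
            picklable[key] = value
    with open(filename + '.pkl', 'wb') as f:
        pickle.dump((picklable, keras_attrs), f)

A corresponding load_grid_search_cv would unpickle the state and rebuild each wrapper with a build_fn that calls load_model on the matching HDF5 file.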


crbrinton commented Apr 13, 2017

When trying to persist a KerasClassifier (or KerasRegressor) object, the KerasClassifier itself does not have a save method. It is the keras model that is wrapped by the KerasClassifier that can be saved using the save method. However, if you want to end up with a KerasClassifier after re-loading the persisted model, the re-loaded model must be wrapped anew in the KerasClassifier. This can be done by creating a new KerasClassifier object with a build_fn that actually calls the load_model method, such as:

from keras.models import load_model

def build_by_loading(self):
    model = load_model('nn_model.h5')
    return model

So the KerasClassifier to be re-instantiated from a persisted file would be created as follows (for example):

    nn_model = KerasClassifier(build_fn=self.build_by_loading, nb_epoch=10, batch_size=5, verbose=1)

Unfortunately, the KerasClassifier code does not call the build_fn until the 'fit' method of the KerasClassifier is called. This would defeat the purpose of persisting the model.

I created a 'build_only' method in KerasClassifier that only calls the build_fn but does not fit the model (a sketch of this idea follows). This worked for me. I recommend that some means of instantiating a KerasClassifier from a persisted Keras model, similar to this, be included in the next release.
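A minimal sketch of that idea (the subclass name LoadableKerasClassifier and its build_only method are illustrative, not part of the released wrapper; a classifier restored this way may also need its classes_ attribute set before predict works):

from keras.models import load_model
from keras.wrappers.scikit_learn import KerasClassifier

class LoadableKerasClassifier(KerasClassifier):

    def build_only(self):
        # Call build_fn right away instead of waiting for fit(),
        # so a persisted model can be used without re-training.
        self.model = self.build_fn()
        return self

def build_by_loading():
    return load_model('nn_model.h5')

nn_model = LoadableKerasClassifier(build_fn=build_by_loading,
                                   nb_epoch=10, batch_size=5, verbose=1)
nn_model.build_only()  # nn_model.model is now the persisted network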


stale bot commented Jul 13, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.


michaelTUM commented Jul 22, 2017

@crbrinton Hey, after looking for a solution for many days I've found your comment, and it seems to make sense! However, I am not sure how to implement it for my problem.

I am trying to apply sklearn's BaggingClassifier on the KerasClassifier. If I run it on one core there is no problem at all, but it fails if I want to use multiple cores. I pinned the problem down to the BaggingClassifier trying to pickle the KerasClassifier at some point (which is not possible out of the box, as you pointed out).

I tried to adapt this solution, http://zachmoshe.com/2017/04/03/pickling-keras-models.html, to my needs, but without success. Do you have an idea? Thanks

I've opened a Stack Overflow question about this problem: https://stackoverflow.com/questions/45231354/sklearn-baggingclassifier-n-jobs-1-keras

stale bot closed this as completed Nov 19, 2017
@ChaymaZatout

@cbrummitt This gscv.best_estimator_.model.save('filename.h5') works for me :)

@shaoeChen

@ChimMeya Thank you very much, I used best_estimator_ to get the model information.

@ikinzMartin

Hey guys, I'm a bit late to the party, but there is a quick and dirty fix for this issue.

In my case, I'm using a KerasClassifier inside a Pipeline, like so:

model = Pipeline([
   ('cleaner', TextCleaner()),
   ('encoder', KerasEncoder(...)),
   ('lstm', KerasClassifier(build_fn=..., ...))
])

To serialize it, I use the following function:

import os
import pickle

def save_pipeline_keras(model, folder_name="model"):
    os.makedirs(folder_name, exist_ok=True)
    pickle.dump(model.named_steps['cleaner'], open(folder_name+'/'+'cleaner.pkl', 'wb'))
    pickle.dump(model.named_steps['encoder'], open(folder_name+'/'+'encoder.pkl', 'wb'))
    pickle.dump(model.named_steps['lstm'].classes_, open(folder_name+'/'+'classes.pkl', 'wb'))
    model.named_steps['lstm'].model.save(folder_name+'/lstm.h5')

This saves all the individual components of the Pipeline into a model/ folder.

To load it, I use the following:

from keras.models import load_model
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.pipeline import Pipeline

def load_pipeline_keras(cleaner, encoder, model, classes, folder_name="model"):
    cleaner = pickle.load(open(folder_name+'/'+cleaner, 'rb'))
    encoder = pickle.load(open(folder_name+'/'+encoder, 'rb'))
    build_model = lambda: load_model(folder_name+'/'+model)
    classifier = KerasClassifier(build_fn=build_model, epochs=1, batch_size=10, verbose=1)
    classifier.classes_ = pickle.load(open(folder_name+'/'+classes, 'rb'))
    classifier.model = build_model()
    return Pipeline([
        ('cleaner', cleaner),
        ('encoder', encoder),
        ('lstm', classifier)
    ])

The quick and dirty part of this is that the KerasClassifier object only calls build_fn when the Pipeline's fit method is called. To get around this, you can manually set the KerasClassifier.model and KerasClassifier.classes_ attributes to their corresponding values, as the load function above does.

Hope this helps someone!

@weizhu365

I have run into the same issue: I could not restore the pickled GridSearchCV object. But I figured out a simple solution: use dill rather than pickle.

So you may simply change one line of code:
import dill as pickle

The problem should be fixed.
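For illustration (gscv here stands for a fitted GridSearchCV wrapping a KerasClassifier, as in the earlier comments):

import dill as pickle  # dill can serialize the nested build_fn that plain pickle rejects

with open('grid_search.pkl', 'wb') as f:
    pickle.dump(gscv, f)

with open('grid_search.pkl', 'rb') as f:
    gscv_restored = pickle.load(f)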

Cheers,

Wei

@Permafacture

I have a workaround for the sklearn Pipeline on Python 2.7 over at #13168. It would be easy to generalize if someone needs it.

@Permafacture

A general solution, at least for TensorFlow models. Pickleable with joblib:

import io
from copy import deepcopy

import h5py
from keras.models import load_model
from keras.wrappers.scikit_learn import KerasClassifier

class PickleableKerasClassifier(KerasClassifier):

    def __getstate__(self):
        state = self.__dict__
        model = state['model']
        # Serialize the Keras model into an in-memory HDF5 buffer.
        bio = io.BytesIO()
        with h5py.File(bio, mode='w') as f:
            model.save(f)
        state['model'] = bio
        return_state = deepcopy(state)
        state['model'] = model  # restore the live model on self
        return return_state

    def __setstate__(self, state):
        # Rebuild the Keras model from the HDF5 buffer.
        with h5py.File(state['model'], mode='r') as f:
            state['model'] = load_model(f)
        self.__dict__ = state
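Usage would then be roughly as follows (clf being a fitted PickleableKerasClassifier; joblib.dump triggers __getstate__ and joblib.load triggers __setstate__):

import joblib

joblib.dump(clf, 'clf.joblib')            # the model is swapped for an in-memory HDF5 buffer
clf_restored = joblib.load('clf.joblib')  # the Keras model is rebuilt from that buffer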

@nunoachenriques

nunoachenriques commented Aug 16, 2019

(Quoting @Permafacture's PickleableKerasClassifier workaround from above.)

Thanks for sharing! I've added handling for the case when "model" isn't yet available, i.e. before build_fn has been called by fit(). I'm using scikit-optimize's BayesSearchCV(). Now I can use pickle.dumps(estimator_search.best_estimator_):

import copy
import io

import dill as pickle  # Using dill instead of Python 3.7+ pickle
import h5py
import tensorflow as tf

...

class KerasClassifier(tf.keras.wrappers.scikit_learn.KerasClassifier):
    """
    TensorFlow Keras API neural network classifier.

    Workaround the tf.keras.wrappers.scikit_learn.KerasClassifier serialization
    issue using BytesIO and HDF5 in order to enable pickle dumps.

    Adapted from: https://github.com/keras-team/keras/issues/4274#issuecomment-519226139
    """

    def __getstate__(self):
        state = self.__dict__
        if "model" in state:
            model = state["model"]
            model_hdf5_bio = io.BytesIO()
            with h5py.File(model_hdf5_bio, mode="w") as file:  # file-like objects need h5py >= 2.9
                model.save(file)
            state["model"] = model_hdf5_bio
            state_copy = copy.deepcopy(state)
            state["model"] = model
            return state_copy
        else:
            return state

    def __setstate__(self, state):
        if "model" in state:
            model_hdf5_bio = state["model"]
            with h5py.File(model_hdf5_bio, mode="r") as file:
                state["model"] = tf.keras.models.load_model(file)
        self.__dict__ = state
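A round trip then looks like this (estimator_search is the fitted BayesSearchCV; X_test stands in for whatever held-out data you have):

payload = pickle.dumps(estimator_search.best_estimator_)  # pickle is dill here
best_clf = pickle.loads(payload)
print(best_clf.predict(X_test))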

@cwindolf

Thanks @Permafacture and @nunoachenriques!

I think the patch in the previous comment should be merged into KerasClassifier, or maybe into its base class BaseWrapper so that the change applies to KerasRegressor as well, since sklearn models are so often pickled, typically by joblib for use in parallel contexts. To me, pickling (or at least cloudpickle-ing) is part of the unofficial interface expected of sklearn estimators. This class seems like the most robust fix, and it might fix the underlying cause of several related issues.

It's anecdotal, but for me this class fixed issues with using Keras classifiers in pipelines combining custom models and off-the-shelf sklearn models, both in a joblib.Parallel context and when persisting to disk with joblib.dump. (I'm using the TF backend on Linux.)

@sama2689

#4274 (comment)
This works for saving the model but how do I then load it afterwards?

@IFFranciscoME

(Quoting @nunoachenriques's KerasClassifier workaround from above.)

Thanks for this! I had been searching for a workaround to use the sklearn wrapper with tf.keras and still save with pickle. Thanks again!
