diff --git a/lightfm/lightfm.py b/lightfm/lightfm.py index f71404a1..aa1ad11f 100644 --- a/lightfm/lightfm.py +++ b/lightfm/lightfm.py @@ -20,10 +20,20 @@ CYTHON_DTYPE = np.float32 -model_weights = {'user_embeddings', 'user_biases', 'item_embeddings', 'item_biases', - 'item_bias_momentum', 'item_bias_gradients', 'item_embedding_momentum', - 'item_embedding_gradients', 'user_bias_momentum', 'user_bias_gradients', - 'user_embedding_momentum', 'user_embedding_gradients'} +model_weights = { + "user_embeddings", + "user_biases", + "item_embeddings", + "item_biases", + "item_bias_momentum", + "item_bias_gradients", + "item_embedding_momentum", + "item_embedding_gradients", + "user_bias_momentum", + "user_bias_gradients", + "user_embedding_momentum", + "user_embedding_gradients", +} class LightFM(object): @@ -513,7 +523,9 @@ def load(self, path): for value in [x for x in numpy_model if x in model_weights]: setattr(self, value, numpy_model[value]) - self.set_params(**{k: v for k, v in numpy_model.items() if k not in model_weights}) + self.set_params( + **{k: v for k, v in numpy_model.items() if k not in model_weights} + ) def fit( self,