Skip to content

Commit

Permalink
Merge pull request #228 from lintusj1/model-saving-improvement
Browse files Browse the repository at this point in the history
Easier model saving and loading
  • Loading branch information
Jarno Lintusaari committed Sep 5, 2017
2 parents fc6eb5d + f17e1a4 commit af0a29f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 25 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changelog
=========

dev
---

- Easier saving and loading of ElfiModel
- Renamed elfi.set_current_model to elfi.set_default_model
- Renamed elfi.get_current_model to elfi.get_default_model

0.6.1 (2017-07-21)
------------------

Expand Down
5 changes: 3 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ Below is the API for creating generative models.

.. autosummary::
elfi.new_model
elfi.get_current_model
elfi.set_current_model
elfi.load_model
elfi.get_default_model
elfi.set_default_model

.. currentmodule:: elfi.visualization.visualization

Expand Down
114 changes: 93 additions & 21 deletions elfi/model/elfi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import inspect
import logging
import os
import pickle
import re
import uuid
from functools import partial
Expand All @@ -23,55 +25,86 @@

__all__ = [
'ElfiModel', 'ComputationContext', 'NodeReference', 'Constant', 'Operation', 'RandomVariable',
'Prior', 'Simulator', 'Summary', 'Discrepancy', 'Distance', 'get_current_model',
'set_current_model', 'new_model'
'Prior', 'Simulator', 'Summary', 'Discrepancy', 'Distance', 'get_default_model',
'set_default_model', 'new_model', 'load_model'
]

logger = logging.getLogger(__name__)
_current_model = None
_default_model = None


def get_current_model():
"""Return the current default `elfi.ElfiModel` instance.
def get_default_model():
"""Return the current default ``ElfiModel`` instance.
New nodes will be added to this model by default.
"""
global _current_model
if _current_model is None:
_current_model = ElfiModel()
return _current_model
global _default_model
if _default_model is None:
_default_model = ElfiModel()
return _default_model


def set_current_model(model=None):
"""Set the current default `elfi.ElfiModel` instance.
def set_default_model(model=None):
"""Set the current default ``ElfiModel`` instance.
New nodes will be placed the given model by default.
Parameters
----------
model : ElfiModel, optional
If None, creates a new ElfiModel.
If None, creates a new ``ElfiModel``.
"""
global _current_model
global _default_model
if model is None:
model = ElfiModel()
if not isinstance(model, ElfiModel):
raise ValueError('{} is not an instance of ElfiModel'.format(ElfiModel))
_current_model = model
_default_model = model


def new_model(name=None, set_default=True):
"""Create a new ``ElfiModel`` instance.
def new_model(name=None, set_current=True):
"""Create a new ElfiModel.
In addition to making a new ElfiModel instance, this method sets the new instance as
the default for new nodes.
Parameters
----------
name : str, optional
set_current : bool, optional
Whether to set this ElfiModel as the current (default) one.
set_default : bool, optional
Whether to set the newly created model as the current model.
"""
model = ElfiModel(name=name)
if set_current:
set_current_model(model)
if set_default:
set_default_model(model)
return model


def load_model(name, prefix=None, set_default=True):
"""Load the pickled ElfiModel.
Assumes there exists a file "name.pkl" in the current directory. Also sets the loaded
model as the default model for new nodes.
Parameters
----------
name : str
Name of the model file to load (without the .pkl extension).
prefix : str
Path to directory where the model file is located, optional.
set_default : bool, optional
Set the loaded model as the default model. Default is True.
Returns
-------
ElfiModel
"""
model = ElfiModel.load(name, prefix=prefix)
if set_default:
set_default_model(model)
return model


Expand Down Expand Up @@ -353,6 +386,45 @@ def copy(self):
kopy.name = "{}_copy_{}".format(self.name, random_name())
return kopy

def save(self, prefix=None):
"""Save the current model to pickled file.
Parameters
----------
prefix : str, optional
Path to the directory under which to save the model. Default is the current working
directory.
"""
path = self.name + '.pkl'
if prefix is not None:
os.makedirs(prefix, exist_ok=True)
path = os.path.join(prefix, path)
pickle.dump(self, open(path, "wb"))

@classmethod
def load(cls, name, prefix):
"""Load the pickled ElfiModel.
Assumes there exists a file "name.pkl" in the current directory.
Parameters
----------
name : str
Name of the model file to load (without the .pkl extension).
prefix : str
Path to directory where the model file is located, optional.
Returns
-------
ElfiModel
"""
path = name + '.pkl'
if prefix is not None:
path = os.path.join(prefix, path)
return pickle.load(open(path, "rb"))

def __getitem__(self, node_name):
"""Return a new reference object for a node in the model.
Expand Down Expand Up @@ -454,7 +526,7 @@ def _determine_model(self, model, parents):
raise ValueError('Parents are from different models!')

if model is None:
model = get_current_model()
model = get_default_model()

return model

Expand Down
19 changes: 17 additions & 2 deletions tests/unit/test_elfi_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest

Expand Down Expand Up @@ -69,17 +71,30 @@ def test_remove_node(self, ma2):

assert 'MA2' not in ma2.observed

def test_save_load(self, ma2):
name = ma2.name
ma2.save()
ma2 = elfi.load_model(name)
os.remove(name + '.pkl')

# Same with a prefix
prefix = 'models_dir'
ma2.save(prefix)
ma2 = elfi.load_model(name, prefix)
os.remove(os.path.join(prefix, name + '.pkl'))
os.removedirs(prefix)


class TestNodeReference:
def test_name_argument(self):
# This is important because it is used when passing NodeReferences as
# InferenceMethod arguments
em.set_current_model()
em.set_default_model()
ref = em.NodeReference(name='test')
assert str(ref) == 'test'

def test_name_determination(self):
em.set_current_model()
em.set_default_model()
node = em.NodeReference()
assert node.name == 'node'

Expand Down

0 comments on commit af0a29f

Please sign in to comment.