Skip to content

Commit

Permalink
Merge pull request #2258 from ncfrey/nf-saving-reloading
Browse files Browse the repository at this point in the history
Reload tests for normalizing flow models
  • Loading branch information
ncfrey committed Oct 31, 2020
2 parents 12b35a1 + 1b99686 commit 63af0c9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
9 changes: 9 additions & 0 deletions deepchem/models/normalizing_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from deepchem.models.keras_model import KerasModel
from deepchem.models.optimizers import Optimizer, Adam
from deepchem.utils.typing import OneOrMany
from deepchem.utils.data_utils import load_from_disk, save_to_disk

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -183,6 +184,14 @@ def create_nll(self, input: OneOrMany[tf.Tensor]) -> tf.Tensor:

return -tf.reduce_mean(self.flow.log_prob(input, training=True))

def save(self):
"""Saves model to disk using joblib."""
save_to_disk(self.model, self.get_model_filename(self.model_dir))

def reload(self):
"""Loads model from joblib file on disk."""
self.model = load_from_disk(self.get_model_filename(self.model_dir))

def _create_gradient_fn(self,
variables: Optional[List[tf.Variable]]) -> Callable:
"""Create a function that computes gradients and applies them to the model.
Expand Down
44 changes: 44 additions & 0 deletions deepchem/models/tests/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,50 @@ def test_robust_multitask_classification_reload():
assert scores[classification_metric.name] > .9


def test_normalizing_flow_model_reload():
"""Test that RobustMultitaskRegressor can be reloaded correctly."""
from deepchem.models.normalizing_flows import NormalizingFlow, NormalizingFlowModel
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfk = tf.keras
tfk.backend.set_floatx('float64')

model_dir = tempfile.mkdtemp()

Made = tfb.AutoregressiveNetwork(
params=2, hidden_units=[512, 512], activation='relu')

flow_layers = [tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=Made)]
# 3D Multivariate Gaussian base distribution
nf = NormalizingFlow(
base_distribution=tfd.MultivariateNormalDiag(
loc=np.zeros(2), scale_diag=np.ones(2)),
flow_layers=flow_layers)

nfm = NormalizingFlowModel(nf, model_dir=model_dir)

target_distribution = tfd.MultivariateNormalDiag(loc=np.array([1., 0.]))
dataset = dc.data.NumpyDataset(X=target_distribution.sample(96))
final = nfm.fit(dataset, nb_epoch=1)

x = np.zeros(2)
lp1 = nfm.flow.log_prob(x).numpy()

assert nfm.flow.sample().numpy().shape == (2,)

reloaded_model = NormalizingFlowModel(nf, model_dir=model_dir)
reloaded_model.restore()

# Check that reloaded model can sample from the distribution
assert reloaded_model.flow.sample().numpy().shape == (2,)

lp2 = reloaded_model.flow.log_prob(x).numpy()

# Check that density estimation is same for reloaded model
assert np.all(lp1 == lp2)


def test_robust_multitask_regressor_reload():
"""Test that RobustMultitaskRegressor can be reloaded correctly."""
n_tasks = 10
Expand Down

0 comments on commit 63af0c9

Please sign in to comment.