Skip to content

Commit

Permalink
Merge c00f114 into 8e78a05
Browse files Browse the repository at this point in the history
  • Loading branch information
rbharath committed Oct 15, 2020
2 parents 8e78a05 + c00f114 commit 6113698
Showing 1 changed file with 121 additions and 0 deletions.
121 changes: 121 additions & 0 deletions deepchem/models/tests/test_reload.py
Expand Up @@ -278,6 +278,65 @@ def test_robust_multitask_classification_reload():
assert scores[classification_metric.name] > .9


# TODO(rbharath): Re-enable once RobustMultitaskRegressor reload is fixed;
# this test currently fails when restoring the trained model.
#def test_robust_multitask_regressor_reload():
# """Test that RobustMultitaskRegressor can be reloaded correctly."""
# n_tasks = 10
# n_samples = 10
# n_features = 3
#
# # Generate dummy dataset
# np.random.seed(123)
# ids = np.arange(n_samples)
# X = np.random.rand(n_samples, n_features)
# y = np.random.rand(n_samples, n_tasks)
# w = np.ones((n_samples, n_tasks))
#
# dataset = dc.data.NumpyDataset(X, y, w, ids)
# regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error)
#
# model_dir = tempfile.mkdtemp()
# model = dc.models.RobustMultitaskRegressor(
# n_tasks,
# n_features,
# layer_sizes=[50],
# bypass_layer_sizes=[10],
# dropouts=[0.],
# learning_rate=0.003,
# weight_init_stddevs=[.1],
# batch_size=n_samples)
#
# # Fit trained model
# model.fit(dataset, nb_epoch=100)
#
# # Eval model on train
# scores = model.evaluate(dataset, [regression_metric])
# assert scores[regression_metric.name] < .1
#
# # Reload trained model
# reloaded_model = dc.models.RobustMultitaskRegressor(
# n_tasks,
# n_features,
# layer_sizes=[50],
# bypass_layer_sizes=[10],
# dropouts=[0.],
# learning_rate=0.003,
# weight_init_stddevs=[.1],
# batch_size=n_samples)
# reloaded_model.restore()
#
# # Check predictions match on random sample
# Xpred = np.random.rand(n_samples, n_features)
# predset = dc.data.NumpyDataset(Xpred)
# origpred = model.predict(predset)
# reloadpred = reloaded_model.predict(predset)
# assert np.all(origpred == reloadpred)
#
# # Eval model on train
# scores = reloaded_model.evaluate(dataset, [regression_metric])
# assert scores[regression_metric.name] < 0.1


def test_IRV_multitask_classification_reload():
"""Test IRV classifier can be reloaded."""
n_tasks = 5
Expand Down Expand Up @@ -398,6 +457,68 @@ def test_progressive_classification_reload():
assert scores[classification_metric.name] > .9


def test_progressivemultitaskregressor_reload():
    """Test that ProgressiveMultitaskRegressor can be reloaded correctly.

    Trains a small regressor on a random dummy dataset, restores a second
    instance from the same model directory, and verifies both that the
    reloaded model's predictions match the original exactly and that its
    training-set score is equally good.
    """
    n_samples = 10
    n_features = 3
    n_tasks = 1

    # Generate dummy dataset (fixed seed so the fit is reproducible).
    np.random.seed(123)
    ids = np.arange(n_samples)
    X = np.random.rand(n_samples, n_features)
    y = np.random.rand(n_samples, n_tasks)
    w = np.ones((n_samples, n_tasks))

    dataset = dc.data.NumpyDataset(X, y, w, ids)
    regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error)

    # Use a context-managed temporary directory so checkpoint files are
    # cleaned up when the test finishes; bare mkdtemp() leaks the
    # directory on every run.
    with tempfile.TemporaryDirectory() as model_dir:
        model = dc.models.ProgressiveMultitaskRegressor(
            n_tasks,
            n_features,
            layer_sizes=[50],
            bypass_layer_sizes=[10],
            dropouts=[0.],
            learning_rate=0.001,
            weight_init_stddevs=[.1],
            alpha_init_stddevs=[.02],
            batch_size=n_samples,
            model_dir=model_dir)

        # Fit trained model
        model.fit(dataset, nb_epoch=100)

        # Eval model on train
        scores = model.evaluate(dataset, [regression_metric])
        assert scores[regression_metric.name] < .1

        # Reload trained model: the constructor arguments must match the
        # original so the restored weights fit the rebuilt graph.
        reloaded_model = dc.models.ProgressiveMultitaskRegressor(
            n_tasks,
            n_features,
            layer_sizes=[50],
            bypass_layer_sizes=[10],
            dropouts=[0.],
            learning_rate=0.001,
            weight_init_stddevs=[.1],
            alpha_init_stddevs=[.02],
            batch_size=n_samples,
            model_dir=model_dir)
        reloaded_model.restore()

        # Check predictions match exactly on a fresh random sample.
        Xpred = np.random.rand(n_samples, n_features)
        predset = dc.data.NumpyDataset(Xpred)
        origpred = model.predict(predset)
        reloadpred = reloaded_model.predict(predset)
        assert np.all(origpred == reloadpred)

        # Eval reloaded model on train; should score as well as the original.
        scores = reloaded_model.evaluate(dataset, [regression_metric])
        assert scores[regression_metric.name] < 0.1


# TODO(rbharath): Re-enable once DAG regressor reload is fixed; this test
# currently fails.
#def test_DAG_regression_reload():
# """Test DAG regressor reloads."""
Expand Down

0 comments on commit 6113698

Please sign in to comment.