Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF2 porting: Enable early stopping + model save and load #739

Merged
merged 37 commits into from Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
38910b6
feat: enabled early stopping and add early stopping unit test
jimthompson5802 Jun 13, 2020
6b10058
feat: enable model save and restore functions with unit test
jimthompson5802 Jun 15, 2020
522e702
Merge remote-tracking branch 'upstream/tf2_porting' into tf2_early_st…
jimthompson5802 Jun 15, 2020
d0eaf2b
Merge branch 'tf2_early_stopping' into tf2_early_stopping_model_save_…
jimthompson5802 Jun 15, 2020
9d6627c
refactor: eliminate warning pycharm warning message
jimthompson5802 Jun 15, 2020
a70d0e5
feat: add test for saving progress weights and final model
jimthompson5802 Jun 16, 2020
e45a155
feat: update restoring function
jimthompson5802 Jun 16, 2020
5a33a45
test: change assertion test for model save/resume
jimthompson5802 Jun 17, 2020
91bc289
refactor: re-enable resume_session() and restore() methods
jimthompson5802 Jun 17, 2020
acb8bed
fix: ValueError when saving model to disk
jimthompson5802 Jun 19, 2020
0264959
fix: resolve error when restoring saved model
jimthompson5802 Jun 20, 2020
037e3b7
fix: adapt test to directory for TF2 saved model file storage
jimthompson5802 Jun 20, 2020
1444c69
fix: syntax error
jimthompson5802 Jun 20, 2020
ddc7157
fix: TEMPORARY CODE FOR DEBUGGING PURPOSES - NEED TO BE REPLACED
jimthompson5802 Jun 20, 2020
48a9f41
feat: for model save/restore support dictionary of custom objects
jimthompson5802 Jun 20, 2020
0e5d1e3
test: VERSION USED FOR DEBUGGING
jimthompson5802 Jun 22, 2020
fccd4b1
refactor: change from savedmodel to save_weights approach
jimthompson5802 Jun 27, 2020
157caa3
refactor: remove hack for initializing weights
jimthompson5802 Jun 27, 2020
c0d28f2
fix: reporting metrics in wrong order when resuming model training
jimthompson5802 Jun 28, 2020
799406d
feat: initial working LudwigModel.predict() method with TF2
jimthompson5802 Jun 28, 2020
d4035fc
feat: allow specification of optimizer
jimthompson5802 Jun 28, 2020
0fdeed6
fix: restoration of saved model weights
jimthompson5802 Jun 28, 2020
182eea2
Added save and reload test using APIs
w4nderlust Jun 30, 2020
0e88675
Fix: encoder creation in binary feature
w4nderlust Jun 30, 2020
fc809f9
Expanded the test_model_save_reload_API test
w4nderlust Jun 30, 2020
da1ef01
Fix: vector fature encoder return from dict of dict to dict
w4nderlust Jul 1, 2020
dbe7c33
Fix: bag feature_data when input is a dataframe
w4nderlust Jul 1, 2020
88df360
Fix: image feature_data when input is a dataframe
w4nderlust Jul 1, 2020
8593dd9
Fix: image feature_data cleanup
w4nderlust Jul 1, 2020
64efb3a
Fix: audio feature_data when input is a dataframe
w4nderlust Jul 1, 2020
fe0245e
Fix: most input features now work in test_model_save_reload_API
w4nderlust Jul 1, 2020
8e13c08
Fix: set output feature bugs (missing import for loss in metrics, mis…
w4nderlust Jul 1, 2020
eb13293
Fix: vector output feature bugs (missed kwargs and missing call to .n…
w4nderlust Jul 1, 2020
e66e3ef
Added additional outputs (ctegory, set, vector) to test_model_save_re…
w4nderlust Jul 1, 2020
d64953a
Merge branch 'tf2_porting' into tf2_early_stopping
w4nderlust Jul 1, 2020
fbfcb84
Added timeseries inputs to test_model_save_reload_API test
w4nderlust Jul 1, 2020
eedc838
fix: IndexError exception after model weights restore - work-in-progress
jimthompson5802 Jul 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
111 changes: 55 additions & 56 deletions ludwig/models/model.py
Expand Up @@ -379,8 +379,8 @@ def train(

# ====== Setup session =======
# todo tf2: reintroduce restoring weights
# if self.weights_save_path:
# self.restore(session, self.weights_save_path)
if self.weights_save_path:
self.restore(self.weights_save_path)

# todo tf2: reintroduce tensorboard logging
# train_writer = None
Expand Down Expand Up @@ -651,52 +651,51 @@ def train(
)
)

# if should_validate:
# should_break = self.check_progress_on_validation(
# progress_tracker,
# validation_field,
# validation_metric,
# session,
# model_weights_path,
# model_hyperparameters_path,
# reduce_learning_rate_on_plateau,
# reduce_learning_rate_on_plateau_patience,
# reduce_learning_rate_on_plateau_rate,
# increase_batch_size_on_plateau_patience,
# increase_batch_size_on_plateau,
# increase_batch_size_on_plateau_max,
# increase_batch_size_on_plateau_rate,
# early_stop,
# skip_save_model
# )
# if should_break:
# break
# else:
# # there's no validation, so we save the model at each iteration
# if is_on_master():
# if not skip_save_model:
# self.save_weights(session, model_weights_path)
# self.save_hyperparameters(
# self.hyperparameters,
# model_hyperparameters_path
# )
#
# # ========== Save training progress ==========
# if is_on_master():
# if not skip_save_progress:
# self.save_weights(session, model_weights_progress_path)
# progress_tracker.save(
# os.path.join(
# save_path,
# TRAINING_PROGRESS_FILE_NAME
# )
# )
# if skip_save_model:
# self.save_hyperparameters(
# self.hyperparameters,
# model_hyperparameters_path
# )
#
if should_validate:
should_break = self.check_progress_on_validation(
progress_tracker,
validation_field,
validation_metric,
model_weights_path,
model_hyperparameters_path,
reduce_learning_rate_on_plateau,
reduce_learning_rate_on_plateau_patience,
reduce_learning_rate_on_plateau_rate,
increase_batch_size_on_plateau_patience,
increase_batch_size_on_plateau,
increase_batch_size_on_plateau_max,
increase_batch_size_on_plateau_rate,
early_stop,
skip_save_model
)
if should_break:
break
else:
# there's no validation, so we save the model at each iteration
if is_on_master():
if not skip_save_model:
self.save_weights(model_weights_path)
self.save_hyperparameters(
self.hyperparameters,
model_hyperparameters_path
)

# ========== Save training progress ==========
if is_on_master():
if not skip_save_progress:
self.save_weights(model_weights_progress_path)
progress_tracker.save(
os.path.join(
save_path,
TRAINING_PROGRESS_FILE_NAME
)
)
if skip_save_model:
self.save_hyperparameters(
self.hyperparameters,
model_hyperparameters_path
)

# if is_on_master():
# contrib_command("train_epoch_end", progress_tracker)
# logger.info('')
Expand Down Expand Up @@ -932,7 +931,7 @@ def check_progress_on_validation(
progress_tracker,
validation_field,
validation_metric,
session, model_weights_path,
model_weights_path,
model_hyperparameters_path,
reduce_learning_rate_on_plateau,
reduce_learning_rate_on_plateau_patience,
Expand All @@ -957,7 +956,7 @@ def check_progress_on_validation(
validation_field][validation_metric][-1]
if is_on_master():
if not skip_save_model:
self.save_weights(session, model_weights_path)
self.save_weights(model_weights_path)
self.save_hyperparameters(
self.hyperparameters,
model_hyperparameters_path
Expand Down Expand Up @@ -1111,10 +1110,10 @@ def collect_weights(
# return collected_tensors
pass

def save_weights(self, session, save_path):
def save_weights(self, save_path):
# todo tf2: reintroduce functionality
# self.weights_save_path = self.saver.save(session, save_path)
pass
#self.weights_save_path = self.saver.save(save_path)
self.ecd.save_weights(save_path)

def save_hyperparameters(self, hyperparameters, save_path):
# removing pretrained embeddings paths from hyperparameters
Expand Down Expand Up @@ -1159,10 +1158,11 @@ def save_savedmodel(self, save_path):
# builder.save()
pass

def restore(self, session, weights_path):
def restore(self,weights_path):
# todo tf2: reintroduce this functionality
# self.saver.restore(session, weights_path)
pass
self.ecd.load_weights(weights_path)


@staticmethod
def load(load_path, use_horovod=False):
Expand Down Expand Up @@ -1298,7 +1298,6 @@ def initialize_batcher(

def resume_session(
self,
session,
save_path,
model_weights_path,
model_weights_progress_path
Expand Down
2 changes: 1 addition & 1 deletion ludwig/predict.py
Expand Up @@ -99,7 +99,7 @@ def full_predict(
gpu_fraction,
debug
)
model.close_session()
# model.close_session() # todo tf2 code clean -up

if is_on_master():
# setup directories and file names
Expand Down
204 changes: 204 additions & 0 deletions tests/integration_tests/test_model_training_options.py
@@ -0,0 +1,204 @@
import os.path
import json
from collections import namedtuple

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import pytest

from ludwig.experiment import full_experiment
from ludwig.predict import full_predict

GeneratedData = namedtuple('GeneratedData',
'train_df validation_df test_df')

def get_feature_definitions():
input_features = [
{'name': 'x', 'type': 'numerical'},
]
output_features = [
{'name': 'y', 'type': 'numerical', 'loss': {'type': 'mean_squared_error'},
'num_fc_layers': 5, 'fc_size': 64}
]

return input_features, output_features


@pytest.fixture
def generated_data():
# function generates simple training data that guarantee convergence
# within 30 epochs for suitable model definition
NUMBER_OBSERVATIONS = 500

# generate data
np.random.seed(43)
x = np.array(range(NUMBER_OBSERVATIONS)).reshape(-1, 1)
y = 2*x + 1 + np.random.normal(size=x.shape[0]).reshape(-1, 1)
raw_df = pd.DataFrame(np.concatenate((x, y), axis=1), columns=['x', 'y'])

# create training data
train, valid_test = train_test_split(raw_df, train_size=0.7)

# create validation and test data
validation, test = train_test_split(valid_test, train_size=0.5)

return GeneratedData(train, validation, test)

@pytest.mark.parametrize('early_stop', [3, 5])
def test_early_stopping(early_stop, generated_data, tmp_path):

input_features, output_features = get_feature_definitions()

model_definition = {
'input_features': input_features,
'output_features': output_features,
'combiner': {
'type': 'concat'
},
'training': {
'epochs': 30,
'early_stop': early_stop,
'batch_size': 16
}
}

# create sub-directory to store results
results_dir = tmp_path / 'results'
results_dir.mkdir()

# run experiment
exp_dir_name = full_experiment(
data_train_df=generated_data.train_df,
data_validation_df=generated_data.validation_df,
data_test_df=generated_data.test_df,
output_directory=str(results_dir),
model_definition=model_definition,
skip_save_processed_input=True,
skip_save_progress=True,
skip_save_unprocessed_output=True,
skip_save_model=True,
skip_save_log=True
)

# test existence of required files
train_stats_fp = os.path.join(exp_dir_name, 'training_statistics.json')
metadata_fp = os.path.join(exp_dir_name, 'description.json')
assert os.path.isfile(train_stats_fp)
assert os.path.isfile(metadata_fp)

# retrieve results so we can validate early stopping
with open(train_stats_fp,'r') as f:
train_stats = json.load(f)
with open(metadata_fp, 'r') as f:
metadata = json.load(f)

# get early stopping value
early_stop_value = metadata['model_definition']['training']['early_stop']

# retrieve validation losses
vald_losses = np.array(train_stats['validation']['combined']['loss'])
last_epoch = vald_losses.shape[0]
best_epoch = np.argmin(vald_losses)

# confirm early stopping
assert (last_epoch - best_epoch - 1) == early_stop_value

@pytest.mark.parametrize('skip_save_progress', [False, True])
@pytest.mark.parametrize('skip_save_model', [False, True])
def test_model_progress_save(
skip_save_progress,
skip_save_model,
generated_data,
tmp_path
):

input_features, output_features = get_feature_definitions()

model_definition = {
'input_features': input_features,
'output_features': output_features,
'combiner': {'type': 'concat', 'fc_size': 14},
'training': {'epochs': 10}
}

# create sub-directory to store results
results_dir = tmp_path / 'results'
results_dir.mkdir()

# run experiment
exp_dir_name = full_experiment(
data_train_df=generated_data.train_df,
data_validation_df=generated_data.validation_df,
data_test_df=generated_data.test_df,
output_directory=str(results_dir),
model_definition=model_definition,
skip_save_processed_input=True,
skip_save_progress=skip_save_progress,
skip_save_unprocessed_output=True,
skip_save_model=skip_save_model,
skip_save_log=True
)

#========== Check for required result data sets =============
if skip_save_model:
assert not os.path.isfile(
os.path.join(exp_dir_name, 'model', 'model_weights.index')
)
else:
assert os.path.isfile(
os.path.join(exp_dir_name, 'model', 'model_weights.index')
)

if skip_save_progress:
assert not os.path.isfile(
os.path.join(exp_dir_name, 'model', 'model_weights_progress.index')
)
else:
assert os.path.isfile(
os.path.join(exp_dir_name, 'model', 'model_weights_progress.index')
)


# work-in-progress
def test_model_save_resume(generated_data, tmp_path):

input_features, output_features = get_feature_definitions()
model_definition = {
'input_features': input_features,
'output_features': output_features,
'combiner': {'type': 'concat', 'fc_size': 14},
'training': {'epochs': 30, 'early_stop': 5}
}

# create sub-directory to store results
results_dir = tmp_path / 'results'
results_dir.mkdir()

exp_dir_name = full_experiment(
model_definition,
data_train_df=generated_data.train_df,
data_validation_df=generated_data.validation_df,
data_test_df=generated_data.test_df,
output_directory=results_dir
)

full_experiment(
model_definition,
data_train_df=generated_data.train_df,
model_resume_path=exp_dir_name
)

test_fp = os.path.join(str(tmp_path), 'data_to_predict.csv')
generated_data.test_df.to_csv(
test_fp,
index=False
)

full_predict(os.path.join(exp_dir_name, 'model'), data_csv=test_fp)

y_pred = np.load(os.path.join(exp_dir_name, 'y_predictions.npy'))

mse = mean_squared_error(y_pred, generated_data.test_df['y'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what we can do here is that after the first full experiment we load the numpy predictions, and after the second experiment with resume we load the numpy predictions and then we assert that they are the same with np.isclose(first_preds, second_preds)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the recommended change in the test. Here is the commit 5a33a45

I think either restore may not be saving the weights or the resume is not loading the weights correctly. The last epoch on the first full_experiment looks like this

Epoch 28
Training: 100%|██████████| 3/3 [00:00<00:00, 65.62it/s]
Evaluation train: 100%|██████████| 3/3 [00:00<00:00, 110.61it/s]
Evaluation vali : 100%|██████████| 1/1 [00:00<00:00, 117.43it/s]
Evaluation test : 100%|██████████| 1/1 [00:00<00:00, 116.25it/s]
Took 0.1004s
Took 0.1004s
╒═══════╤═════════╤═════════╤══════════════════════╤═══════════════════════╤════════╕
│ y     │    loss │   error │   mean_squared_error │   mean_absolute_error │     r2 │
╞═══════╪═════════╪═════════╪══════════════════════╪═══════════════════════╪════════╡
│ train │ 20.6874 │ -3.7530 │              20.6874 │                3.8724 │ 0.9998 │
├───────┼─────────┼─────────┼──────────────────────┼───────────────────────┼────────┤
│ vali  │ 24.5326 │ -4.2580 │              24.5326 │                4.2998 │ 0.9997 │
├───────┼─────────┼─────────┼──────────────────────┼───────────────────────┼────────┤
│ test  │ 25.1545 │ -4.3416 │              25.1545 │                4.3740 │ 0.9997 │
╘═══════╧═════════╧═════════╧══════════════════════╧═══════════════════════╧════════╛

On the second full_experiment with the resume, the first epoch report is epoch 28, which I think makes sense but the values don't look correct

Resuming training of model: /tmp/pytest-of-root/pytest-0/test_model_save_resume0/results/experiment_run/model
Resuming training of model: /tmp/pytest-of-root/pytest-0/test_model_save_resume0/results/experiment_run/model

Epoch 28

Epoch 28
Training: 100%|██████████| 3/3 [00:00<00:00, 40.29it/s]
Evaluation train: 100%|██████████| 3/3 [00:00<00:00, 110.75it/s]
Evaluation vali : 100%|██████████| 1/1 [00:00<00:00, 115.91it/s]
Evaluation test : 100%|██████████| 1/1 [00:00<00:00, 118.83it/s]
Took 0.1300s
Took 0.1300s
╒═══════╤══════════╤═════════════╤══════════════════════╤═══════════════════════╤═════════╕
│ y     │     loss │       error │   mean_squared_error │   mean_absolute_error │      r2 │
╞═══════╪══════════╪═════════════╪══════════════════════╪═══════════════════════╪═════════╡
│ train │ 458.7538 │ 287591.8750 │             460.9559 │           287591.8438 │ -2.3807 │
├───────┼──────────┼─────────────┼──────────────────────┼───────────────────────┼─────────┤
│ vali  │ 514.6791 │ 339206.7188 │             514.6791 │           339206.7188 │ -3.0783 │
├───────┼──────────┼─────────────┼──────────────────────┼───────────────────────┼─────────┤
│ test  │ 495.4767 │ 311607.8438 │             495.4767 │           311607.8438 │ -3.2109 │
╘═══════╧══════════╧═════════════╧══════════════════════╧═══════════════════════╧═════════╛