TF2 porting: Enable early stopping + model save and load #739

Merged
merged 37 commits into ludwig-ai:tf2_porting from tf2_early_stopping on Jul 2, 2020

Conversation

jimthompson5802
Collaborator

Code Pull Requests

Re-introduced the early stopping option. From what I can tell there is no existing unit test for early stopping, so I added such a test. The test uses two values (3 and 5) for the early_stop option and confirms training stops per the early stop specification.

pytest -v test_model_training_options.py
======================================================= test session starts ========================================================
platform linux -- Python 3.6.9, pytest-5.4.3, py-1.8.1, pluggy-0.13.1 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /opt/project
plugins: pycharm-0.6.0, typeguard-2.8.0
collected 2 items

test_model_training_options.py::test_early_stopping[3] PASSED                                                                [ 50%]
test_model_training_options.py::test_early_stopping[5] PASSED                                                                [100%]

========================================================= warnings summary =========================================================
/usr/local/lib/python3.6/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py:15
  /usr/local/lib/python3.6/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py:15: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
    import imp

-- Docs: https://docs.pytest.org/en/latest/warnings.html
================================================== 2 passed, 1 warning in 12.85s ===================================================

The test_model_training_options.py module can be used as a foundation for other unit tests around model training options, e.g. whether or not to save various logs.

This PR re-enables only early stopping; nothing else was re-enabled.

If this looks OK, I can start re-enabling other model training options.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Jun 13, 2020

From what I can see, the flip side of early stopping is saving model weights:
https://github.com/jimthompson5802/ludwig/blob/38910b6876ef43669573c6af3eb6004e66c82ad7/ludwig/models/model.py#L673-L682

In researching this aspect, I believe this is relevant to Ludwig because Ludwig's implementation is built on subclassing model and layers:
https://www.tensorflow.org/guide/keras/save_and_serialize#custom_objects

Is this a correct interpretation?

@w4nderlust
Collaborator

w4nderlust commented Jun 14, 2020

This looks good.

Regarding your question, I believe the most relevant part is: https://www.tensorflow.org/guide/keras/save_and_serialize#apis_for_saving_weights_to_disk_and_loading_them_back

What I'm imagining is also an improvement over the previous implementation in a couple of ways:

  1. we can keep in memory the model that performed best on validation, and have an optional parameter in train that specifies whether the returned model is the last one or the best performing one
  2. before, the logic was: if there's validation, save a model every time the validation performance improves (unless skip_save_model); if no validation is specified, save every epoch (unless skip_save_model). Maybe we can improve on this and allow the user to specify the cadence of saving, like every K epochs (that applies with validation too: we would keep in memory the weights of the best model so far and save them only after K epochs have passed since the last save, saving every K epochs the best-performing weights at that moment in time). The reason for this is that saving the model to disk is expensive and slows down training substantially.
  3. another substantial improvement would be to spawn a separate thread that saves the model. Maybe there is a synchronized queue of models to be saved that the main training thread adds to and another thread that saves the models from the synchronized queue. The reason is that, as I said before, saving a model is a time-consuming task and we can decouple it from the main training loop for substantially improved speed (a rough sketch of this idea appears at the end of this comment).

Let me know what you think about these 3 points.
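To make point 3 concrete, here is a rough sketch of what such a background saver could look like (a hypothetical helper, not existing Ludwig code; it assumes a Keras-style get_weights() and snapshots the weights as numpy arrays so later training steps cannot race with the write):

import queue
import threading

import numpy as np

class AsyncWeightSaver:
    def __init__(self):
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        # background thread: pull weight snapshots off the queue and write them
        while True:
            item = self._queue.get()
            if item is None:  # sentinel enqueued by close()
                break
            weights, path = item
            np.savez(path, *weights)
            self._queue.task_done()

    def save(self, model, path):
        # called from the training loop; get_weights() returns numpy copies,
        # so this returns immediately and training is not blocked by disk I/O
        self._queue.put((model.get_weights(), path))

    def close(self):
        self._queue.put(None)
        self._worker.join()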

@jimthompson5802
Collaborator Author

All three suggestions make sense.

Do you want me to continue with the work to save model weights on this PR or open a new PR?

@w4nderlust
Collaborator

Let me add another explanation of a piece of logic that maybe is not super obvious: progress.
By default Ludwig uses a progress tracking object and saves the model every epoch with a _progress suffix. The reason is that if you do that you can resume training exactly where you left off if you want, while resuming from the best validation model so far is not the same thing, because the best validation model may be the model from 5 epochs before training stopped, for instance.
So you should keep this in mind when looking into validation too.
This is true by default in v0.2, but I was thinking about making it false by default, as it takes a lot of time to save the model every epoch and in most cases you don't actually need to.

@jimthompson5802
Collaborator Author

This re-enables the functionality for saving weights during training when there is an improvement and at the end. It just restores the current TF1 behavior; I wanted to establish a working baseline and create some unit tests before working on the improvements. These are the new unit tests:

  • test_early_stopping tests early stopping
  • test_model_progress_save tests saving weights during training and at the end
  • test_model_save_resume tests saving weights and resuming training
pytest -v test_model_training_options.py
========================================= test session starts =========================================
platform linux -- Python 3.6.9, pytest-5.4.3, py-1.8.1, pluggy-0.13.1 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /opt/project
plugins: pycharm-0.6.0, typeguard-2.8.0
collected 7 items

test_model_training_options.py::test_early_stopping[3] PASSED                                   [ 14%]
test_model_training_options.py::test_early_stopping[5] PASSED                                   [ 28%]
test_model_training_options.py::test_model_progress_save[False-False] PASSED                    [ 42%]
test_model_training_options.py::test_model_progress_save[False-True] PASSED                     [ 57%]
test_model_training_options.py::test_model_progress_save[True-False] PASSED                     [ 71%]
test_model_training_options.py::test_model_progress_save[True-True] PASSED                      [ 85%]
test_model_training_options.py::test_model_save_resume PASSED                                   [100%]

========================================== warnings summary ===========================================
/usr/local/lib/python3.6/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py:15
  /usr/local/lib/python3.6/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py:15: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
    import imp

-- Docs: https://docs.pytest.org/en/latest/warnings.html
==================================== 7 passed, 1 warning in 33.40s ====================================
root@44c51e57dc88:/opt/project/tests/integration_tests#

If this looks like a good starting point, I'll start working on these changes:

we can keep in memory the model that performed best on validation, and have an optional parameter in train that specifies whether the returned model is the last one or the best performing one

before, the logic was: if there's validation, save a model every time the validation performance improves (unless skip_save_model); if no validation is specified, save every epoch (unless skip_save_model). Maybe we can improve on this and allow the user to specify the cadence of saving, like every K epochs (that applies with validation too: we would keep in memory the weights of the best model so far and save them only after K epochs have passed since the last save, saving every K epochs the best-performing weights at that moment in time). The reason for this is that saving the model to disk is expensive and slows down training substantially.

another substantial improvement would be to spawn a separate thread that saves the model. Maybe there is a synchronized queue of models to be saved that the main training thread adds to and another thread that saves the models from the synchronized queue. The reason is that, as I said before, saving a model is a time-consuming task and we can decouple it from the main training loop for substantially improved speed.


y_pred = np.load(os.path.join(exp_dir_name, 'y_predictions.npy'))

mse = mean_squared_error(y_pred, generated_data.test_df['y'])
Collaborator

what we can do here is that after the first full experiment we load the numpy predictions, and after the second experiment with resume we load the numpy predictions and then we assert that they are the same with np.isclose(first_preds, second_preds)
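A minimal sketch of that assertion (the directory names here are placeholders for the two experiment output directories):

import os

import numpy as np

# hypothetical output directories of the original run and the resumed run
first_exp_dir = 'results/experiment_run'
second_exp_dir = 'results/experiment_run_resume'

first_preds = np.load(os.path.join(first_exp_dir, 'y_predictions.npy'))
second_preds = np.load(os.path.join(second_exp_dir, 'y_predictions.npy'))

# the resumed experiment should reproduce the original predictions
assert np.all(np.isclose(first_preds, second_preds))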

Collaborator Author

I made the recommended change in the test. Here is the commit 5a33a45

I think either the weights may not be getting saved or the resume is not loading them correctly. The last epoch of the first full_experiment looks like this:

Epoch 28
Training: 100%|██████████| 3/3 [00:00<00:00, 65.62it/s]
Evaluation train: 100%|██████████| 3/3 [00:00<00:00, 110.61it/s]
Evaluation vali : 100%|██████████| 1/1 [00:00<00:00, 117.43it/s]
Evaluation test : 100%|██████████| 1/1 [00:00<00:00, 116.25it/s]
Took 0.1004s
Took 0.1004s
╒═══════╤═════════╤═════════╤══════════════════════╤═══════════════════════╤════════╕
│ y     │    loss │   error │   mean_squared_error │   mean_absolute_error │     r2 │
╞═══════╪═════════╪═════════╪══════════════════════╪═══════════════════════╪════════╡
│ train │ 20.6874 │ -3.7530 │              20.6874 │                3.8724 │ 0.9998 │
├───────┼─────────┼─────────┼──────────────────────┼───────────────────────┼────────┤
│ vali  │ 24.5326 │ -4.2580 │              24.5326 │                4.2998 │ 0.9997 │
├───────┼─────────┼─────────┼──────────────────────┼───────────────────────┼────────┤
│ test  │ 25.1545 │ -4.3416 │              25.1545 │                4.3740 │ 0.9997 │
╘═══════╧═════════╧═════════╧══════════════════════╧═══════════════════════╧════════╛

On the second full_experiment with the resume, the first epoch reported is epoch 28, which I think makes sense, but the values don't look correct:

Resuming training of model: /tmp/pytest-of-root/pytest-0/test_model_save_resume0/results/experiment_run/model
Resuming training of model: /tmp/pytest-of-root/pytest-0/test_model_save_resume0/results/experiment_run/model

Epoch 28

Epoch 28
Training: 100%|██████████| 3/3 [00:00<00:00, 40.29it/s]
Evaluation train: 100%|██████████| 3/3 [00:00<00:00, 110.75it/s]
Evaluation vali : 100%|██████████| 1/1 [00:00<00:00, 115.91it/s]
Evaluation test : 100%|██████████| 1/1 [00:00<00:00, 118.83it/s]
Took 0.1300s
Took 0.1300s
╒═══════╤══════════╤═════════════╤══════════════════════╤═══════════════════════╤═════════╕
│ y     │     loss │       error │   mean_squared_error │   mean_absolute_error │      r2 │
╞═══════╪══════════╪═════════════╪══════════════════════╪═══════════════════════╪═════════╡
│ train │ 458.7538 │ 287591.8750 │             460.9559 │           287591.8438 │ -2.3807 │
├───────┼──────────┼─────────────┼──────────────────────┼───────────────────────┼─────────┤
│ vali  │ 514.6791 │ 339206.7188 │             514.6791 │           339206.7188 │ -3.0783 │
├───────┼──────────┼─────────────┼──────────────────────┼───────────────────────┼─────────┤
│ test  │ 495.4767 │ 311607.8438 │             495.4767 │           311607.8438 │ -3.2109 │
╘═══════╧══════════╧═════════════╧══════════════════════╧═══════════════════════╧═════════╛

@w4nderlust
Collaborator

Great work so far! It's also great that for loading and saving weights it is as simple as 1 function call on the ecd object :)

@jimthompson5802
Collaborator Author

re:

Great work so far! It's also great that for loading and saving weights it is as simple as 1 function call on the ecd object :)

Actually, I was thinking this was too easy of a change. This means I must have missed something. :-)

Thank you for the suggestion on how to test model reloading. Once I confirm loading is working, I'll start on the new functions.

re:

we can keep in memory the model that performed the best on validation,

Since saving weights only involves the ecd object, I'm thinking the way to implement this is to create a new attribute in ludwig.models.model.Model like self.best_ecd, e.g.,

self.ecd = ECD(input_features, combiner, output_features)
self.best_ecd = None

and when we want to keep "the model that performed the best on validation" do

self.best_ecd = self.ecd.deepcopy()

Any thoughts on this approach?

@w4nderlust
Collaborator

Any thoughts on this approach?

Sounds good to me. My only doubt is that we probably don't need to assign it to self, because that ecd copy object will be needed only inside the train function, and at the end of it will be either returned or discarded (depending on the logic), so we can probably get away without assigning it to self.
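A minimal sketch of that local-variable approach (hypothetical names; a small Keras model stands in for the ECD object, and get_weights()/set_weights() are the standard Keras calls; a deep copy of the ECD object would work similarly):

import numpy as np
import tensorflow as tf

# hypothetical stand-in for the ECD model; any tf.keras.Model behaves the same
ecd = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
ecd.compile(optimizer='sgd', loss='mse')

x = np.random.rand(64, 4).astype(np.float32)
y = np.random.rand(64, 1).astype(np.float32)

return_best_model = True       # the proposed optional train() parameter
best_loss = float('inf')
best_weights = None            # local variable, no self.best_ecd needed

for epoch in range(10):
    ecd.fit(x, y, epochs=1, verbose=0)
    val_loss = ecd.evaluate(x, y, verbose=0)   # stands in for the validation metric
    if val_loss < best_loss:
        best_loss = val_loss
        best_weights = ecd.get_weights()       # numpy copies held in memory only

# return either the last model or the best-performing one
if return_best_model and best_weights is not None:
    ecd.set_weights(best_weights)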

@jimthompson5802
Collaborator Author

@w4nderlust
Collaborator

Got it, so it looks like it doesn't work.
In order to debug this I would:

  1. put a breakpoint before obtaining the second predictions and make sure that the data is the same and in the same order
  2. look at the outputs of the first few datapoints of the first batch manually; this will tell us whether the predictions are close or entirely off.
  3. if they are entirely off, I would look at the weights of any layer in the model: if they are different it means the loading didn't work, if they are the same then the loading worked and the error is in the inference.

@jimthompson5802
Collaborator Author

Yeah, like I wrote earlier...this seemed too easy. :-) No conclusions yet but I have some observations to pass along.

When I look in the debugger after loading the weights, it appears that the weight lists are empty, even though the load_weights() method appears to work.

I noticed this in the API documentation:

When graph building, restore ops are run automatically as soon as the network is built (on first call for user-defined classes inheriting from Model, immediately if it is already built).

Given that we have user-defined classes, it sounds like "a first call" is needed.

The following example may illustrate this observation. I create a custom Model class and define a couple of Dense layers. After "running" one input tensor through the custom model, I save the weights and then reload them into a second instance of the custom model.

import numpy as np

import tensorflow as tf
from tensorflow import keras

class MyModel(keras.models.Model):
    def __init__(self):
        super(MyModel, self).__init__()

        self.t1 = keras.layers.Dense(5,
                                     kernel_initializer=keras.initializers.GlorotNormal(
                                         seed=13))
        self.t2 = keras.layers.Dense(5,
                                     kernel_initializer=keras.initializers.GlorotNormal(
                                         seed=15))

    def call(self, inputs, training=None, mask=None):
        return self.t2(self.t1(inputs))

# define input tensor (float32 to match the Dense layers' dtype)
in1 = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))

# define first model
my_model = MyModel()
t2 = my_model(in1)

# save model weights
my_model.save_weights('./my_model_weights')

# load weights into a new model instance
my_model2 = MyModel()
my_model2.load_weights('./my_model_weights')

t3 = my_model2(in1)  #<=== comment out to generate error#################

# compare weights from the two models
print("# weights my_model", len(my_model.get_weights()))
print("# weights my_model2", len(my_model2.get_weights()))

weights = zip(my_model.get_weights(), my_model2.get_weights())

print("weights matches")
for w in weights:
    print(np.all(np.isclose(w[0], w[1])))

If this line t3 = my_model2(in1) is executed, then the saved weights are "available" in the second model instance.

# weights my_model 4
# weights my_model2 4
weights matches
True
True
True
True

Process finished with exit code 0

OTOH, if I comment out t3 = my_model2(in1), then the weights appear not to be available in the second custom model.

# weights my_model 4
# weights my_model2 0
weights matches
WARNING:tensorflow:Unresolved object in checkpoint: (root).t1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).t1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).t2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).t2.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

Process finished with exit code 0

Anyway, I'm going to pursue this avenue to see how it could resolve the issue.

@w4nderlust
Collaborator

I may be wrong, but from your example my guess is that the weights are initialized lazily the first time you execute a call(). But it's weird that you can save them and reload them before then... not really sure what's going on :)
I'll spend some time reading the docs about saving.

@jimthompson5802
Collaborator Author

Still working on the model training resume function. In looking over TF2's docs, I found this discussion: https://www.tensorflow.org/guide/checkpoint. I think this is the kind of functionality we are looking for. I'm looking to see how I can adapt this to Ludwig.

@w4nderlust
Collaborator

Still working on the model training resume function. In looking over TF2's docs, I found this discussion: https://www.tensorflow.org/guide/checkpoint. I think this is the kind of functionality we are looking for. I'm looking to see how I can adapt this to Ludwig.

Checkpointing saves every k steps and so accumulates the model at all steps, which is not really what we want (although in that guide they also talk about restoring). This one should be more relevant I guess: https://www.tensorflow.org/guide/keras/save_and_serialize

@jimthompson5802
Collaborator Author

re:

This one should be more relevant I guess: https://www.tensorflow.org/guide/keras/save_and_serialize

Actually, this was the first approach I tried. When model.save() was executed, this error occurred:

E         ValueError: call() should not modify its Python input arguments. Check if it modifies any lists or dicts passed as arguments. Modifying a copy is allowed.

This is how the save_weights() method came into the picture.

Anyway, I'm still working on understanding how weight restoration operates.

@w4nderlust
Collaborator

Where does that error come from? We could try to dive deeper and understand its origin.

@jimthompson5802
Collaborator Author

OK, I'll take another look at the error. If interested, here is the full error stack trace

../../tests/integration_tests/test_model_training_options.py:186: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../ludwig/experiment.py:349: in full_experiment
    **kwargs
../../ludwig/experiment.py:124: in experiment
    **kwargs
../../ludwig/train.py:355: in full_train
    debug=debug
../../ludwig/train.py:521: in train
    **model_definition['training']
../../ludwig/models/model.py:668: in train
    skip_save_model
../../ludwig/models/model.py:958: in check_progress_on_validation
    self.save_weights(model_weights_path)
../../ludwig/models/model.py:1115: in save_weights
    self.ecd.save(save_path)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:1052: in save
    signatures, options)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py:138: in save_model
    signatures, options)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save.py:78: in save
    save_lib.save(model, filepath, signatures, options)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py:951: in save
    obj, export_dir, signatures, options, meta_graph_def)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py:1008: in _build_meta_graph
    checkpoint_graph_view)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py:75: in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py:143: in list_functions
    self._serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1656: in _list_functions_for_serialization
    Model, self)._list_functions_for_serialization(serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:2750: in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py:87: in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py:77: in functions_to_serialize
    serialization_cache).functions_to_serialize)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py:92: in _get_serialized_attributes
    serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py:53: in _get_serialized_attributes_internal
    serialization_cache))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py:101: in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:153: in wrap_layer_functions
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:272: in _replace_child_layer_functions
    serialization_cache).functions)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py:92: in _get_serialized_attributes
    serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py:53: in _get_serialized_attributes_internal
    serialization_cache))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py:101: in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:163: in wrap_layer_functions
    '{}_layer_call_and_return_conditional_losses'.format(layer.name))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:503: in add_function
    self.add_trace(*self._input_signature)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:418: in add_trace
    trace_with_training(True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:416: in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py:547: in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:959: in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:865: in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:506: in _initialize
    *args, **kwds))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2446: in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2777: in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2667: in _create_graph_function
    capture_by_value=self._capture_by_value),
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:988: in func_graph_from_py_func
    check_mutation(func_args_before, func_args, original_func)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

n1 = ((({'combiner_output': <tf.Tensor 'input_1_1_combiner_output:0' shape=(None, 1) dtype=float32>}, {'y': <tf.Tensor 'input_1_2_y:0' shape=(None, 64) dtype=float32>}), <tf.Tensor 'input_2:0' shape=(None,) dtype=int32>), True, None)
n2 = ((({'combiner_output': <tf.Tensor 'input_1_1_combiner_output:0' shape=(None, 1) dtype=float32>}, {'y': <tf.Tensor 'fc_...activation_9/Relu:0' shape=(None, 64) dtype=float32>}), <tf.Tensor 'input_2:0' shape=(None,) dtype=int32>), True, None)
func = <bound method OutputFeature.call of <ludwig.features.numerical_feature.NumericalOutputFeature object at 0x7f7ed6e0a9b0>>

    def check_mutation(n1, n2, func):
      """Check if two list of arguments are exactly the same."""
      func_name = getattr(func, "__name__", func)
    
      errmsg = ("{}() should not modify its Python input arguments."
                " Check if it modifies any lists or dicts passed as"
                " arguments. Modifying a copy is allowed.".format(func_name))
      try:
        # TODO(mdan): Compare more robustly so that argument names can be reported.
        nest.assert_same_structure(n1, n2, expand_composites=True)
      except ValueError:
        raise ValueError(errmsg)
    
      for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                            nest.flatten(n2, expand_composites=True)):
        if arg1 is not arg2:
>         raise ValueError(errmsg)
E         ValueError: call() should not modify its Python input arguments. Check if it modifies any lists or dicts passed as arguments. Modifying a copy is allowed.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:1070: ValueError

@w4nderlust
Collaborator

w4nderlust commented Jun 18, 2020

So I don't really know what n1 and n2 are, but this 'y': <tf.Tensor 'fc_...activation_9/Relu:0' shape=(None, 64) dtype=float32>}) is suspicious. I guess those are two different calls to the same function, and the function that takes combiner outputs and y is likely somewhere in here: https://github.com/uber/ludwig/blob/tf2_porting/ludwig/features/base_feature.py#L177-L202

Or maybe TensorFlow internally uses the names of variables in a way we don't really know and requires that you don't change the variables a call function receives as inputs (which is what the error message seems to suggest). In that case, the responsible code is probably this: https://github.com/uber/ludwig/blob/tf2_porting/ludwig/features/base_feature.py#L170-L173
I would try replacing it with something like:

        if isinstance(inputs, tuple):
            local_inputs, target = inputs
        else:
            local_inputs = inputs
            target = None

And refactor the remaining code to use local_inputs or any other better name.

@jimthompson5802
Collaborator Author

Or maybe TensorFlow internally uses the names of variables in a way we don't really know and requires that you don't change the variables a call function receives as inputs (which is what the error message seems to suggest). In that case, the responsible code is probably this: https://github.com/uber/ludwig/blob/tf2_porting/ludwig/features/base_feature.py#L170-L173
I would try replacing it with something like:

At first blush this fits the error message. I'll make your recommended changes and test.

@jimthompson5802
Collaborator Author

Before turning in, I wanted to provide an update.
I made the recommended change.

        # account for output feature target
        if isinstance(inputs, tuple):
            local_inputs, target = inputs
        else:
            local_inputs = inputs
            target = None

        combiner_outputs, other_output_hidden = local_inputs

Still encountering the same error.

Digging deeper, I noticed that the prepare_decoder_inputs() method in the OutputFeature class is invoked as part of the wrapping OutputFeature.call():
https://github.com/jimthompson5802/ludwig/blob/91bc289ceee24916533e81e56b5497413b4a95fe/ludwig/features/base_feature.py#L353-L394. Line 393 is

other_output_features[self.feature_name] = feature_hidden

other_output_features is a parameter to the prepare_decoder_inputs() method. This is an example of an input argument being modified.

When I pick this up tomorrow, I'll see if I can redesign this part of the code so that its input parameter is not modified.
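As a standalone illustration of the constraint (not Ludwig code): tf.function tracing compares the Python arguments before and after the traced call, so adding a key to a dict argument trips that check, while working on a shallow copy does not.

import tensorflow as tf

inputs = {'hidden': tf.ones([2, 4])}

@tf.function
def mutates_argument(features):
    # adding a key mutates the caller's dict and changes the traced input
    # structure, which triggers the "should not modify its Python input
    # arguments" error during tracing
    features['extra'] = tf.zeros([2, 4])
    return features

@tf.function
def copies_argument(features):
    # working on a shallow copy leaves the original argument untouched
    features = dict(features)
    features['extra'] = tf.zeros([2, 4])
    return features

copies_argument(inputs)      # traces and runs fine
# mutates_argument(inputs)   # raises ValueError during tracing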

@w4nderlust
Collaborator

I added the test_model_save_reload_API test, which enables us to test saving and loading of models in a fine-grained way. Some feature types are commented out; those are the feature types for which it's not currently working. Let me know if the way it works is clear. It's not super polished (in particular directories, data creation, etc. can be improved) but it's a starting point.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Jun 30, 2020

I added the test_model_save_reload_API test which enables us to test in a fine grained way saving and loading of models.

I pulled the new unit test and it works for me.

I'm still working on the issues re: sequence model save/restore procedure.

@w4nderlust
Collaborator

That test can help you with the sequence input and output features: if you uncomment them at the beginning of the test you'll see that when they are present it doesn't pass (and the same for all other commented features).

@w4nderlust
Collaborator

Updates on my side: using the newly added test I solved most of the problems with input and output features. The only input feature that is untested is the timeseries feature (because it still needs to be ported) and the only two output features that are not covered yet are sequence and text. All other feature types work fine. I guess this is good complementary work with yours, as you were focusing on the sequence output. I also discovered some minor bugs in pre- and post-processing on some features and solved them along the way, which is good :)

@w4nderlust
Collaborator

I also ported the timeseries feature and added it to the test. It works fine. Now the only missing features are the sequence and text output features.

@jimthompson5802
Collaborator Author

Great news on the other features.

re: model weights save/restore...

I'd probably look into it by creating a small reproducible failure of TFA saving weights and post it on their issue page, but hopefully everything works and we just have to figure out how to make it work.

At this point, I've exhausted all the possibilities, so I'm going to take your advice and submit an issue with the TFA project. The example will be a custom model and layers with train_step and predict_step. I'm planning to submit two examples using this custom model structure:

  • One example is a simple regression that shows the expected behavior.
  • The second example is the one with the "generator" decoder, showing the issue.

I should have the examples ready in the next day or two and will make the submission.

@jimthompson5802
Collaborator Author

Commit eedc838 is the fix for the error that occurs when making a prediction with the Generator decoder after restoring weights. We still have the issue of validating the restored weight values; I just wanted to get this into the baseline for future work.

@w4nderlust
Collaborator

I should have the examples ready in the next day or two and will make the submission.

Sounds good. In the meantime I will keep on working on the restore, trying to figure out the metrics / optimizer issue.
I'm thinking that we should probably rename this PR to "TF2: save / load" and start a new one for the save / load of sequence / text features and a separate one for the resume, what do you think?

@jimthompson5802
Collaborator Author

I'm thinking that we should probably rename this PR to "TF2: save / load" and start a new one for the save / load of sequence / text features and a separate one for the resume, what do you think?

OK by me, though I'll point out that the PR started by re-enabling the early stopping function; the save/load of weights was then layered on top.

One name could be TF2 porting: Enable early stopping and model weight save and load

Or if you prefer, the rename is TF2 porting: Model weights save and load

@w4nderlust w4nderlust changed the title TF2 porting: Enable early stopping TF2 porting: Enable early stopping + model save and load Jul 2, 2020
@w4nderlust w4nderlust merged commit 552f19a into ludwig-ai:tf2_porting Jul 2, 2020
@jimthompson5802 jimthompson5802 deleted the tf2_early_stopping branch July 2, 2020 00:59
@jimthompson5802
Collaborator Author

@w4nderlust This is just an update. As we discussed, I'm creating a minimal example to open an issue in TFA re: saving and loading weights when the Generator decoder is used. Actually, I'm creating two minimal examples:

  • The first example demonstrates the following on a simple regression model. This demonstrates the expected output, i.e., matching predictions and weight comparisons:
    • Define custom layer and custom model
    • Custom model training,
    • saving weights,
    • loading weights into the restored custom model
    • comparing predictions and weights between first model and restored model
  • The second example will perform the same sequence processing with custom layers representing Ludwig's custom encoder and decoder (Generator) layers. This will demonstrate the weights not matching.

The reason for this posting is that I've confirmed that restoring the optimizer is key to avoiding this situation:

This is great, although that sudden jump to 18 for the loss is kinda suspicious.

Using the simple regression example, I can recreate the "sudden jump" in the training loss value. The "sudden jump" occurs if I use a new optimizer when resuming training. However, if I reuse the optimizer from the initial training, there is no "sudden jump".

Here are log files demonstrating the two situations:

Not reusing the optimizer... there is a "sudden jump" in the training loss when resuming training

initial model training
2020-07-03 04:45:51.663124: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-07-03 04:45:51.663174: E tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: UNKNOWN ERROR (303)
2020-07-03 04:45:51.663197: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (bade7d96437c): /proc/driver/nvidia/version does not exist
2020-07-03 04:45:51.663718: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-07-03 04:45:51.670588: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 2791415000 Hz
2020-07-03 04:45:51.671385: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe808000b20 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-07-03 04:45:51.671429: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
training loss metric, epoch 1 1321975.0
training loss metric, epoch 2 80907.58
training loss metric, epoch 3 14119.333
training loss metric, epoch 4 2678.0925
training loss metric, epoch 5 1134.4769
training loss metric, epoch 6 570.3199
training loss metric, epoch 7 245.0745
training loss metric, epoch 8 117.75263
training loss metric, epoch 9 77.32518
training loss metric, epoch 10 56.18808
saving initial trained model weights

restore saved model weights

comparing predicitons and weights
predictions match:  True
weights match:  True

resume training
training loss metric, epoch 1 3286.9758
training loss metric, epoch 2 355.24738
training loss metric, epoch 3 45.247036
training loss metric, epoch 4 32.90129
training loss metric, epoch 5 22.105053
training loss metric, epoch 6 14.485716
training loss metric, epoch 7 12.618556
training loss metric, epoch 8 9.494252
training loss metric, epoch 9 8.483568
training loss metric, epoch 10 7.7146688
all done

Process finished with exit code 0

Reusing the optimizer from initial training...no sudden jump in training loss when resuming training

initial model training
2020-07-03 05:11:36.412566: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-07-03 05:11:36.412617: E tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: UNKNOWN ERROR (303)
2020-07-03 05:11:36.412641: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (f0e654b21256): /proc/driver/nvidia/version does not exist
2020-07-03 05:11:36.413145: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-07-03 05:11:36.420651: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 2791415000 Hz
2020-07-03 05:11:36.421571: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f4a60000b20 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-07-03 05:11:36.421617: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
training loss metric, epoch 1 1321975.0
training loss metric, epoch 2 80907.58
training loss metric, epoch 3 14119.333
training loss metric, epoch 4 2678.0925
training loss metric, epoch 5 1134.4769
training loss metric, epoch 6 570.3199
training loss metric, epoch 7 245.0745
training loss metric, epoch 8 117.75263
training loss metric, epoch 9 77.32518
training loss metric, epoch 10 56.18808
saving initial trained model weights

restore saved model weights

comparing predicitons and weights
predictions match:  True
weights match:  True

resume training
training loss metric, epoch 1 48.457054
training loss metric, epoch 2 42.57572
training loss metric, epoch 3 35.806385
training loss metric, epoch 4 32.360104
training loss metric, epoch 5 28.593756
training loss metric, epoch 6 26.034916
training loss metric, epoch 7 24.362673
training loss metric, epoch 8 20.92635
training loss metric, epoch 9 20.282171
training loss metric, epoch 10 18.658442
all done

Process finished with exit code 0

Here is the code if you're interested in seeing the effect of the optimizer. See line 176; comment or uncomment the line as desired.
mwe_model_save_load_works.txt
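For reference, one standard way to carry the optimizer state across the save/restore boundary is to checkpoint the optimizer together with the model via tf.train.Checkpoint. A minimal sketch with hypothetical paths and names (not what the attached example does, which simply reuses the optimizer object in the same process):

import numpy as np
import tensorflow as tf

# checkpoint the model weights and the optimizer slots together, so resumed
# training continues from the same optimizer state and the loss does not jump
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer, loss='mse')

x = np.random.rand(32, 3).astype(np.float32)
y = np.random.rand(32, 1).astype(np.float32)
model.fit(x, y, epochs=2, verbose=0)  # creates the Adam slot variables

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, './training_ckpts', max_to_keep=1)
manager.save()

# later (e.g. in a fresh process): rebuild the same model and optimizer,
# then restore both before resuming training
ckpt.restore(manager.latest_checkpoint)
model.fit(x, y, epochs=2, verbose=0)  # resumes with the restored optimizer state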

@w4nderlust
Collaborator

Funny, I just posted this:
tensorflow/tensorflow#41053
which is basically the same thing :)

It shows how, by saving and reloading a stateless optimizer, things work fine, while saving and reloading a stateful optimizer breaks things.

My take is that we should separate these 2 aspects: the resume aspect (making sure that when we resume, things keep on working) and the plain save and load aspect (we only care about the weights and predictions being the same, not whether training can resume correctly). For the second one, all features work except the sequence and text output features, which is what's important to fix on your side I believe. Anyway, it's great that we both confirmed the same behavior; now we know exactly where the errors are (optimizer restoration and TFA loading).

@jimthompson5802
Collaborator Author

@w4nderlust Two topics:

  1. Could you run the unit test tests/integration_tests/test_simple_features.py? The reason I'm asking is that this test now fails when I run it; I want to see if it fails for you.

The failure indicates an "unexpected keyword" during initialization of the Adam optimizer; the "unexpected keyword" is beta1. Looking at the tf.keras.optimizers.Adam documentation, the keywords listed are beta_1 and beta_2. In Ludwig's defaults.py the default keywords for Adam are specified as beta1 and beta2. From what I can tell, the Ludwig code (defaults.py) that specifies the default Adam keywords has not changed in a couple of years.
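For reference, a minimal illustration of the naming difference (the values shown are the documented defaults):

import tensorflow as tf

# TF2 / Keras expects the underscored names
opt = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)

# the TF1-style names used in defaults.py raise an unexpected-keyword error:
# tf.keras.optimizers.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999)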

Let me know if this test works for you.

  2. The reason I asked the first question: the minimal example for the sequence model save and load that I was working on to submit to TFA does not fail. It is possible that I accidentally fixed the issue when creating the minimal example.

To eliminate that possibility, I tried recreating the issue with the standard Ludwig unit tests. This is when I encountered the Adam optimizer issue. While not all tests were passing in test_experiment.py, the last time I ran it (2 weeks ago) about 320 out of 348 passed. Now all 348 fail.

Looks like a recent change may have broken something.

@jimthompson5802
Collaborator Author

I dug deeper into the Adam initialization error and submitted PR #749 to fix it.

@jimthompson5802
Collaborator Author

After fixing the Adam initialization error, I tried the original Ludwig program where the weights for the sequence model do not appear to be loaded correctly. The problem still exists in that version of the program.
ludwig_api_test_generator.txt

Assuming this may be related to TFA, I worked on a "minimal example" to illustrate the issue with plans to open an issue with TFA. Here is the minimal example.
mwe_model_save_load_sequence.txt
I cannot recreate the issue with the minimal example. In trying to keep it minimal I recreated by hand the key custom subclasses of tf.keras.models.Model and tf.keras.layers.Layer that Ludwig implements.

Right now I'm going back over both Ludwig's TF2 implementation and how I constructed the minimal example. I'm hoping by comparing the two, I find the difference that may help us with save/load of the sequence model.

@w4nderlust
Collaborator

Regarding the Adam parameters, yes, those were the names of the TF1 parameters; changing to the TF2 names fixed it.

Regarding the sequence feature, I think you are on the right track. If the custom example works, then there's something in our current implementation that is different from the example and makes it not work; we should identify what it is.
