
TF2 porting: initial work #639

Merged
merged 18 commits on Feb 27, 2020

Conversation

jimthompson5802
Collaborator

@w4nderlust This is a test of sending the results of my work. Initially I tried pushing my changes directly to uber/ludwig, but that failed because I did not have write permission on uber/ludwig.

I'm now trying to submit my changes through a PR from the tf2_porting branch in my fork jimthompson5802/ludwig. The target for this PR is the tf2_porting branch of uber/ludwig. If this works, I can just add commits to this PR.

The Docker image I use for my development environment is built with the updated requirements.txt on the tf2_porting branch, which contains:

tensorflow==2.1
tensorflow-addons

With this initial set of commits, training completes and some, though not all, of the data are saved for TensorBoard. For my test, I'm using the Titanic Survivor example. Here is the log from training:
tf2_sample_train_log.txt

Here is a screenshot of TensorBoard showing the data that was collected.
Screen Shot 2020-02-15 at 01 12 30

Let me know what you think.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Feb 15, 2020

I thought it would be good to establish a baseline for the unit tests. In doing this, I noticed strange behavior. The initial run of the unit tests led to this result:

============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-5.3.5, py-1.8.1, pluggy-0.13.1
rootdir: /opt/project
plugins: typeguard-2.7.1
collected 81 items

../../tests/integration_tests/test_api.py .                              [  1%]
../../tests/integration_tests/test_contrib_wandb.py FE                   [  2%]
../../tests/integration_tests/test_experiment.py EEEEEEEEEEEEEEEEEEEEEEEE [ 17%]
EEEEEEEEEEEEEE                                                           [ 25%]
../../tests/integration_tests/test_kfold_cv.py EEEEEE                    [ 29%]
../../tests/integration_tests/test_server.py EE                          [ 30%]
../../tests/integration_tests/test_visualization.py EEEEEEEEEEEEEEEEEEEE [ 43%]
EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE                                         [ 62%]
../../tests/integration_tests/test_visualization_api.py EEEEEEEEEEEEEEEE [ 72%]
EEEEEEEEEEEEEEEEEEEEEEEEEEEE                                             [ 90%]
../../tests/ludwig/models/modules/test_encoder.py EEEEEEEE               [ 95%]
../../tests/ludwig/utils/test_data_utils.py EE                           [ 96%]
../../tests/ludwig/utils/test_image_utils.py EEEE                        [ 98%]
../../tests/ludwig/utils/test_normalization.py EE                        [100%]

==================================== ERRORS ====================================

Here is the complete log file:
pytest_all_output.txt

All the errors were instances of this example:

____________ ERROR at teardown of test_experiment_sequence_combiner ____________

self = <contextlib._GeneratorContextManager object at 0x7fc6dc29f5f8>
type = None, value = None, traceback = None

    def __exit__(self, type, value, traceback):
        if type is None:
            try:
>               next(self.gen)
E               OSError: [Errno 29] Illegal seek

/usr/lib/python3.6/contextlib.py:88: OSError

It's not clear, at least to me, how this may be related to TF2. I'll note that the Travis CI run encounters the same problem as above.

Have you encountered this kind of problem with test_contrib_wandb.py before?

These are the results of the pytest run with test_contrib_wandb.py removed:

============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-5.3.5, py-1.8.1, pluggy-0.13.1
rootdir: /opt/project
plugins: typeguard-2.7.1
collected 80 items

../../tests/integration_tests/test_api.py .                              [  1%]
../../tests/integration_tests/test_experiment.py ....FF..FFF..FF....     [ 25%]
../../tests/integration_tests/test_kfold_cv.py ...                       [ 28%]
../../tests/integration_tests/test_server.py F                           [ 30%]
../../tests/integration_tests/test_visualization.py .................... [ 55%]
......                                                                   [ 62%]
../../tests/integration_tests/test_visualization_api.py ................ [ 82%]
......                                                                   [ 90%]
../../tests/ludwig/models/modules/test_encoder.py FF..                   [ 95%]
../../tests/ludwig/utils/test_data_utils.py .                            [ 96%]
../../tests/ludwig/utils/test_image_utils.py ..                          [ 98%]
../../tests/ludwig/utils/test_normalization.py .                         [100%]

Here is the complete log file for this run:
pytest_output_initial.txt

Whatever the issue is with test_contrib_wandb.py, it appears to corrupt the run-time environment for pytest.

@jimthompson5802
Collaborator Author

Commit 1dbfdfb disables running tests/integration_tests/test_contrib_wandb.py temporarily. Once the issue with this test and TF2 is resolved, the test should be re-enabled.
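For reference, one way to temporarily disable an entire test module (not necessarily what commit 1dbfdfb actually does) is a module-level pytest skip marker:

import pytest

# Module-level marker: pytest skips every test in this file at collection time;
# the reason string documents why and when to re-enable it.
pytestmark = pytest.mark.skip(
    reason="temporarily disabled during the TF2 porting work; "
           "re-enable once the TF2 issue with this test is resolved"
)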

@jimthompson5802
Collaborator Author

@w4nderlust A question: is this code fragment in ludwig/models/modules/loss_modules.py, with its reference to tf2.keras..., guidance for me to focus on getting these functions to work, or is it just an artifact of earlier testing and not specific guidance on the approach?
https://github.com/uber/ludwig/blob/21b5f296e5955130a8826a7597b1ffd1201db4f3/ludwig/models/modules/loss_modules.py#L324-L328

@w4nderlust
Collaborator

Have you encountered this kind of problem with test_contrib_wandb.py before?

No, never. I think it's fine to disable the test temporarily and re-add it when the rest of the work is done. There are not that many tests failing (I expected worse, to be honest), but there's still quite some work to do.

Regarding the regularizers, my understanding is that in TF2 all regularizers are actually keras regularizers. So I tried to replace the old ones with the keras ones, but it did not work. I believe we can ignore the regularizers for now, port everything to TF2, and only later reintroduce them.
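For context, in TF2 a keras regularizer is attached to the layer itself and its penalty accumulates in the layer's .losses list rather than in a TF1 graph collection. A minimal sketch (not Ludwig code):

import tensorflow as tf

# The regularizer is passed to the layer; calling the layer creates the penalty
# tensor, which is collected in the layer's (and enclosing model's) `.losses`.
dense = tf.keras.layers.Dense(
    64,
    activation='relu',
    kernel_regularizer=tf.keras.regularizers.l2(1e-4),
)

x = tf.random.normal([8, 16])
_ = dense(x)
print(dense.losses)  # e.g. [<tf.Tensor ... the l2 penalty ...>]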

@jimthompson5802
Collaborator Author

jimthompson5802 commented Feb 16, 2020

Commit 6689168 eliminates 5 of the 10 errors in the unit tests. Local run of the unit tests:

============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-5.3.5, py-1.8.1, pluggy-0.13.1
rootdir: /opt/project
plugins: typeguard-2.7.1
collected 80 items

../../tests/integration_tests/test_api.py .                              [  1%]
../../tests/integration_tests/test_experiment.py ....F...FFF..F.....     [ 25%]
../../tests/integration_tests/test_kfold_cv.py ...                       [ 28%]
../../tests/integration_tests/test_server.py .                           [ 30%]
../../tests/integration_tests/test_visualization.py .................... [ 55%]
......                                                                   [ 62%]
../../tests/integration_tests/test_visualization_api.py ................ [ 82%]
......                                                                   [ 90%]
../../tests/ludwig/models/modules/test_encoder.py ....                   [ 95%]
../../tests/ludwig/utils/test_data_utils.py .                            [ 96%]
../../tests/ludwig/utils/test_image_utils.py ..                          [ 98%]
../../tests/ludwig/utils/test_normalization.py .                         [100%]

=================================== FAILURES ===================================
<<<<< DELETED LINES >>>>>
=========== 5 failed, 75 passed, 7040 warnings in 294.51s (0:04:54) ============

test_encoder.py and test_server.py are now error free. Two errors in test_experiment.py were also eliminated.

@jimthompson5802
Collaborator Author

Commit b2673eb fixes 1 of the 5 remaining errors in test_experiment.py.

Local testing:

============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-5.3.5, py-1.8.1, pluggy-0.13.1
rootdir: /opt/project
plugins: typeguard-2.7.1
collected 80 items

../../tests/integration_tests/test_api.py .                              [  1%]
../../tests/integration_tests/test_experiment.py ....F...F.F..F.....     [ 25%]
../../tests/integration_tests/test_kfold_cv.py ...                       [ 28%]
../../tests/integration_tests/test_server.py .                           [ 30%]
../../tests/integration_tests/test_visualization.py .................... [ 55%]
......                                                                   [ 62%]
../../tests/integration_tests/test_visualization_api.py ................ [ 82%]
......                                                                   [ 90%]
../../tests/ludwig/models/modules/test_encoder.py ....                   [ 95%]
../../tests/ludwig/utils/test_data_utils.py .                            [ 96%]
../../tests/ludwig/utils/test_image_utils.py ..                          [ 98%]
../../tests/ludwig/utils/test_normalization.py .                         [100%]

=================================== FAILURES ===================================
<<<<< DELETED LINES >>>>>
=========== 4 failed, 76 passed, 6972 warnings in 297.38s (0:04:57) ============

@jimthompson5802
Collaborator Author

Just a quick update: 3 of the 4 remaining errors in the unit tests relate to sequence-related models. Regarding the seq2seq test, I've made some progress in the sense that errors now occur later in processing. I'm slowly reverse-engineering the conversion from the tf.contrib API semantics to the tfa semantics. Commit 3994e2c shows the work to date. Let me know if I'm heading down the wrong path.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Feb 17, 2020

A quick update: I'm focused on fixing the multiple seq2seq unit test. I'm assuming that getting this to work will inform how to resolve the other failing tests.

From what I can tell, I've accounted for all of the "input parameters" and converted the deprecated tf.contrib.seq2seq.TrainingHelper to its TF2 addons equivalent.

Although the multiple seq2seq test still fails, it now fails later in processing than when I first started working on it. I'm assuming that's progress. :-) The current error appears to involve some dependency on the legacy tf.contrib.seq2seq.TrainingHelper class attributes/methods.
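For reference, the TrainingHelper-to-tensorflow_addons mapping I'm assuming looks roughly like this (a standalone sketch with dummy sizes and made-up variable names, not the exact code in the commits):

import tensorflow as tf
import tensorflow_addons as tfa

batch_size, max_time, embed_dim, units, vocab_size = 4, 7, 32, 64, 100  # dummy sizes
dec_embed_input = tf.random.normal([batch_size, max_time, embed_dim])   # stand-in targets

cell = tf.keras.layers.LSTMCell(units)
# tf.contrib.seq2seq.TrainingHelper -> tfa.seq2seq.TrainingSampler
sampler = tfa.seq2seq.TrainingSampler()
# tf.contrib.seq2seq.BasicDecoder   -> tfa.seq2seq.BasicDecoder
decoder = tfa.seq2seq.BasicDecoder(
    cell, sampler, output_layer=tf.keras.layers.Dense(vocab_size))

initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)

# tf.contrib.seq2seq.dynamic_decode -> tfa.seq2seq.dynamic_decode; the tensors
# that used to go to the helper's constructor are now passed through
# decoder_init_input / decoder_init_kwargs
outputs, state, lengths = tfa.seq2seq.dynamic_decode(
    decoder,
    maximum_iterations=max_time,
    decoder_init_input=dec_embed_input,
    decoder_init_kwargs={
        'initial_state': initial_state,
        'sequence_length': tf.fill([batch_size], max_time),
    })
print(outputs.rnn_output.shape)  # should be (4, 7, 100)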

I've pushed all changes to date to this PR. Any feedback will be appreciated, especially on whether the data structures are mapped correctly to their TF2 equivalents.

@jimthompson5802
Collaborator Author

@w4nderlust Re: an earlier comment

There are not so many tests failing, I expected worse to be honest, but there's still quite some work to do.

You may have noticed that the number of unit test failures took a turn for the worse in some of the commits (04470ae, 7e4261c). This occurred because I changed this line of code https://github.com/uber/ludwig/blob/2f55f8bd4922890a738d89dc10d953298037e25d/ludwig/models/modules/recurrent_modules.py#L36 to use tf2.keras.layers.SimpleRNNCell.

For the moment I went back to tf.nn.rnn_cell.BasicRNNCell so as not to distract from getting the remaining tests converted to use the tensorflow_addons package. However, once that is taken care of and we start making other changes, we'll see new types of failures in the unit tests.

@jimthompson5802
Collaborator Author

I've been looking at the last set of errors and could use a second set of eyes.

The tests now fail with

      try:
        c_op = c_api.TF_FinishOperation(op_desc)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
>       raise ValueError(str(e))
E       ValueError: Shape must be rank 1 but is rank 0 for 'sequence_5CF70/predictions_sequence_5CF70/rnn_decoder/decoder/concat_3' (op: 'ConcatV2') with input shapes: [1], [], [].

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1622: ValueError

As I understand the error, the code is expecting a vector but received a scalar value instead. The likely cause is that I mis-mapped a data structure when I converted to the tfa API.
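To make that concrete, here is a tiny standalone illustration of what I believe is happening inside tfa's _prepend_batch (the shapes below are my assumption, not the decoder's real output structures):

import tensorflow as tf

batch_size = tf.constant(16)

# _prepend_batch effectively does: tf.concat(([batch_size], shape), axis=0)
rank1_shape = tf.constant([10])  # a proper shape vector: concat works -> [16 10]
rank0_shape = tf.constant(10)    # a scalar where a vector is expected

print(tf.concat(([batch_size], rank1_shape), axis=0))
# tf.concat(([batch_size], rank0_shape), axis=0)  # raises a rank-mismatch error;
# in graph mode the message is the "Shape must be rank 1 but is rank 0" above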

Here is an example error stack trace

../../tests/integration_tests/test_experiment.py:83: in run_experiment
    exp_dir_name = full_experiment(**args)
../../ludwig/experiment.py:349: in full_experiment
    **kwargs
../../ludwig/experiment.py:124: in experiment
    **kwargs
../../ludwig/train.py:354: in full_train
    debug=debug
../../ludwig/train.py:501: in train
    debug=debug
../../ludwig/models/model.py:116: in __init__
    **kwargs
../../ludwig/models/model.py:192: in __build
    is_training=self.is_training
../../ludwig/models/outputs.py:45: in build_outputs
    **kwargs
../../ludwig/models/outputs.py:98: in build_single_output
    **kwargs
../../ludwig/features/base_feature.py:314: in concat_dependencies_and_build_output
    **kwargs
../../ludwig/features/sequence_feature.py:265: in build_output
    kwarg=kwargs
../../ludwig/features/sequence_feature.py:295: in build_sequence_output
    regularizer=regularizer
../../ludwig/features/sequence_feature.py:392: in sequence_predictions
    is_timeseries=is_timeseries
../../ludwig/models/modules/sequence_decoders.py:123: in __call__
    regularizer=regularizer
../../ludwig/models/modules/recurrent_modules.py:469: in recurrent_decoder
    inputs=encoder_outputs
../../ludwig/models/modules/recurrent_modules.py:430: in decode
    decoder_init_kwargs={'initial_state': initial_state}
/usr/local/lib/python3.6/dist-packages/tensorflow_addons/seq2seq/decoder.py:346: in dynamic_decode
    decoder.output_dtype,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/nest.py:568: in map_structure
    structure[0], [func(*x) for x in entries],
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/nest.py:568: in <listcomp>
    structure[0], [func(*x) for x in entries],
/usr/local/lib/python3.6/dist-packages/tensorflow_addons/seq2seq/decoder.py:343: in <lambda>
    _prepend_batch(decoder.batch_size, shape), dtype=dtype
/usr/local/lib/python3.6/dist-packages/tensorflow_addons/seq2seq/decoder.py:534: in _prepend_batch
    return tf.concat(([batch_size], shape), axis=0)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/dispatch.py:180: in wrapper
    return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/array_ops.py:1517: in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_array_ops.py:1126: in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/op_def_library.py:742: in _apply_op_helper
    attrs=attr_protos, op_def=op_def)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:3322: in _create_op_internal
    op_def=op_def)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1786: in __init__
    control_input_ops)

Here is the full pytest log file, if needed:
pytest_output_log2.txt

@w4nderlust
Collaborator

Your work has been great! I'm suggesting a refocus just because it seems unfair to ask you to spend your time getting that deeper level of understanding. On the other hand, if you are willing to and you believe it would be beneficial to you, we can definitely keep working on it. It's your call; I just don't want you to feel obliged, because I understand how much of a learning curve there is :)

@jimthompson5802
Collaborator Author

I'm still willing to work on the TF1 to TF2 upgrade. While it is a lot of effort and slow progress, I'm getting practical experience in making deep learning software work. So if you're OK with the current pace of work and me bugging you with questions, I'd like to continue.

We can have this understanding. If, for any reason, you want to take over the task, just let me know that you want to take over. At the same time, if I feel overwhelmed, I'll reach out to you and pass the work back to you.

These can help me in my work.

  • Your insight into code segments, such as:

Most of them are just there for timeseries decoding, which was attempted, made partially working, but never completed. So I would say most of those classes can go and don't need to be ported.

  • Or guidance such as this

Right now Ludwig creates the model by adding operations to a graph, and then that graph is used to start a session that is run in session.run calls. Things would have to change to become functions that compute prediction and probabilities, and those functions are called both when training and when predicting.

Basically setting boundaries for the work will help me focus in the right area.

@w4nderlust
Collaborator

I'm still willing to work on the TF1 to TF2 upgrade. While it is a lot of effort and slow progress, I'm getting practical experience in making deep learning software work. So if you're OK with the current pace of work and me bugging you with questions, I'd like to continue.

We can have this understanding. If, for any reason, you want to take over the task, just let me know that you want to take over. At the same time, if I feel overwhelmed, I'll reach out to you and pass the work back to you.

that's a perfect arrangement for me.

These can help me in my work.

  • Your insight into code segments, such as:

Most of them are just there for timeseries decoding, which was attempted, made partially working, but never completed. So I would say most of those classes can go and don't need to be ported.

Let me do a pass on the current recurrent_modules in the PR to remove everything related to timeseries; that way things will be much cleaner to begin with.

  • Or guidance such as this

Right now Ludwig creates the model by adding operations to a graph, and then that graph is used to start a session that is run in session.run calls. Things would have to change to become functions that compute prediction and probabilities, and those functions are called both when training and when predicting.

Basically setting boundaries for the work will help me focus in the right area.

Makes sense. Part of the reason it is difficult is that you expect some things to work in a certain way in TF2, that you just need to replace calls with corresponding ones, but then we discover it's not the case, so one needs to adjust. But that's fine; it's an iterative process.

Regarding the graph-building part, the function you should look at is model.__build(). That function reads the model definition and, depending on its contents, builds a computation graph. It does this in a modular way: for inputs, outputs, and combiners it calls the respective functions, and at the end it collects the tensors that need to be provided to sess.run for both training and prediction.
Each of these modules gets the part of the model definition it needs and adds its operations to the graph. For instance, the inputs part cycles over the inputs and, depending on the type of each input, calls the build_input function of the specific feature. For instance, if you specify an input sequence feature, build_inputs() is called https://github.com/uber/ludwig/blob/master/ludwig/models/model.py#L162 which in turn cycles through the input features and calls build_single_input https://github.com/uber/ludwig/blob/master/ludwig/models/inputs.py#L38 which calls the build_input function of the feature depending on its type https://github.com/uber/ludwig/blob/master/ludwig/models/inputs.py#L61-L69 . In this case, the build_input function of a sequence encoder would be used https://github.com/uber/ludwig/blob/master/ludwig/features/sequence_feature.py#L152 . Then within that function, in most cases an encoder is selected, like a parallel_cnn or an rnn, depending on what the user specifies. You can follow the implementation of build_input to get an idea. The same concept is then applied to combiners and outputs. Outputs are a bit more complicated, because they have nodes both for prediction and for training (like a placeholder for the ground truth, and nodes for losses and measures that matter at training and evaluation time but not at prediction time). Check out the categorical output feature class as a comprehensive example and the binary one as a simple example. What I suggest you do is put a breakpoint at the beginning of model.__build and follow all the nested steps up until the end of that function; that will give you a great example of how things work. Once you have that understanding, you'll be much better equipped to figure out how to make the whole thing eager.
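Schematically, the flow described above looks something like this (a toy, runnable paraphrase with stand-in names and string placeholders, not the actual Ludwig code):

# Toy paraphrase of model.__build(): dispatch per input feature, combine, then
# build one output per output feature; strings stand in for graph operations.
def build_single_input(feature):
    return "encoded({})".format(feature['name'])        # stand-in feature encoder

def build_combiner(combiner_def, feature_hidden):
    return "{}({})".format(combiner_def['type'], ', '.join(feature_hidden.values()))

def build_single_output(feature, combined):
    # stand-in for prediction nodes plus training-time loss/measure nodes
    return "decoded({} <- {})".format(feature['name'], combined)

def build(model_definition):
    feature_hidden = {f['name']: build_single_input(f)
                      for f in model_definition['input_features']}
    combined = build_combiner(model_definition['combiner'], feature_hidden)
    return {f['name']: build_single_output(f, combined)
            for f in model_definition['output_features']}

print(build({
    'input_features': [{'name': 'x1', 'type': 'numerical'},
                       {'name': 'x2', 'type': 'numerical'}],
    'combiner': {'type': 'concat'},
    'output_features': [{'name': 'y', 'type': 'numerical'}],
}))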

@jimthompson5802
Collaborator Author

Thank you for the guidance. As a proof of concept, I'm planning to demonstrate eager execution on this model:

    input_features = [
        {'name': 'x1', 'type': 'numerical'},
        {'name': 'x2', 'type': 'numerical'},
        {'name': 'x3', 'type': 'numerical'}
    ]
    output_features = [
        {'name': 'y', 'type': 'numerical'}
    ]

    model_definition = {
        'input_features': input_features,
        'output_features': output_features,
        'combiner': {
            'type': 'concat',  # 'concat',
            'num_fc_layers': 2,
            'fc_size': 64
        },
        'training': {'epochs': 20}
    }

This way I only have to focus on one type of feature and a simple fully connected layer. If I can get this to work for the train, validation, and test functions, it will form the foundation for the rest of the work.

I'll provide updates as the work progresses.

@jimthompson5802
Collaborator Author

Kind of "Daily stand-up" :-)

What was accomplished

This code and the resulting output are a proof of concept showing the creation and use of the relevant TF2 functions and confirming the network structure.

from functools import reduce

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

raw_df = pd.read_csv('./data/train.csv')
raw_X = raw_df.loc[:,'x1':'x3']
raw_y = raw_df['y']
train_X, test_X, train_y, test_y = train_test_split(raw_X, raw_y, test_size=0.2)
print(train_X.shape, test_X.shape, train_y.shape, test_y.shape)

#
# after the input features are combined
# 
inputs = Input(shape=(train_X.shape[1],))

###
# For fully connected layer add code here: 
# two alternative methods
###
# Alternative 1: standard out-of-the-book method.  One drawback is keeping track of
# the argument, e.g., "x", as the layers are chained
# x = Dense(64, activation='relu')(inputs)
# x = Dense(64, activation='relu')(x)
#predictions = Dense(1, activation='linear')(x)

# Alternative 2:  just append functions to a list and use reduce() to combine them
# this has the advantage of not having to explicitly track the argument "x"
# 'function_layers' will be added as an instance variable to 'Model' class 
function_layers = []
function_layers.append(inputs)
function_layers.append(Dense(64, activation='relu'))
function_layers.append(Dense(64, activation='relu'))
function_layers.append(Dense(1, activation='linear'))
predictions = reduce(lambda x, y: y(x), function_layers)


# create model structure
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
              loss='MSE',
              metrics=['mse'])

# train model
print("\ntraining")
model.fit(train_X, train_y, batch_size=5, epochs=10)  # starts training


# evaluate model on test set
print("\nrunning evlaution")
eval = model.evaluate(test_X, test_y)
print(eval)

print(model.summary())


<<<<<<<< Output >>>>>>>>>>
e012dd85d84b:python -u /opt/project/sandbox/tf2_port/sandbox_model_keras.py
2020-02-27 11:57:44.979921: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-02-27 11:57:44.980095: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-02-27 11:57:44.980117: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
(799, 3) (200, 3) (799,) (200,)
2020-02-27 11:57:45.714441: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-02-27 11:57:45.714487: E tensorflow/stream_executor/cuda/cuda_driver.cc:351] failed call to cuInit: UNKNOWN ERROR (303)
2020-02-27 11:57:45.714511: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (e012dd85d84b): /proc/driver/nvidia/version does not exist
2020-02-27 11:57:45.714726: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-02-27 11:57:45.721055: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2791110000 Hz
2020-02-27 11:57:45.722479: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5fd0b10 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-02-27 11:57:45.722523: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version

training
Train on 799 samples
Epoch 1/10
799/799 [==============================] - 1s 650us/sample - loss: 0.1142 - mse: 0.1142
Epoch 2/10
799/799 [==============================] - 0s 332us/sample - loss: 0.0877 - mse: 0.0877
Epoch 3/10
799/799 [==============================] - 0s 316us/sample - loss: 0.0869 - mse: 0.0869
Epoch 4/10
799/799 [==============================] - 0s 582us/sample - loss: 0.0850 - mse: 0.0850
Epoch 5/10
799/799 [==============================] - 0s 336us/sample - loss: 0.0867 - mse: 0.0867
Epoch 6/10
799/799 [==============================] - 0s 345us/sample - loss: 0.0849 - mse: 0.0849
Epoch 7/10
799/799 [==============================] - 0s 333us/sample - loss: 0.0847 - mse: 0.0847
Epoch 8/10
799/799 [==============================] - 0s 331us/sample - loss: 0.0852 - mse: 0.0852
Epoch 9/10
799/799 [==============================] - 1s 845us/sample - loss: 0.0845 - mse: 0.0845
Epoch 10/10
799/799 [==============================] - 0s 321us/sample - loss: 0.0826 - mse: 0.0826

running evlaution
200/200 [==============================] - 0s 299us/sample - loss: 0.1009 - mse: 0.1009
[0.1008630508184433, 0.100863054]


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
dense (Dense)                (None, 64)                256       
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65        
=================================================================
Total params: 4,481
Trainable params: 4,481
Non-trainable params: 0
_________________________________________________________________
None


Process finished with exit code 0

My current thinking on where the above code fragments fit into the Ludwig code structure

To fit in the above code fragments, I expect to remove/replace surrounding code in the cited locations.

This fragment

inputs = Input(shape=(train_X.shape[1],))

is added in this region
https://github.com/uber/ludwig/blob/1fda07c555f3f3d3a9f989ab63441c8021a50391/ludwig/models/combiners.py#L83-L104

I'm leaning toward Alternative 2. This fragment

# Alternative 2:  just append functions to a list and use reduce() to combine them
# this has the advantage of not having to explicitly track the argument "x"
# 'function_layers' will be added as an instance variable to 'Model' class 
function_layers = []
function_layers.append(inputs)
function_layers.append(Dense(64, activation='relu'))
function_layers.append(Dense(64, activation='relu'))
function_layers.append(Dense(1, activation='linear'))
predictions = reduce(lambda x, y: y(x), function_layers)

is added in this region https://github.com/uber/ludwig/blob/1fda07c555f3f3d3a9f989ab63441c8021a50391/ludwig/models/modules/fully_connected_modules.py#L122-L137

This fragment

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
              loss='MSE',
              metrics=['mse'])

is added in this region https://github.com/uber/ludwig/blob/1fda07c555f3f3d3a9f989ab63441c8021a50391/ludwig/models/model.py#L203-L222

This fragment

print("\ntraining")
model.fit(train_X, train_y, batch_size=5, epochs=10)  # starts training

is added in this region https://github.com/uber/ludwig/blob/1fda07c555f3f3d3a9f989ab63441c8021a50391/ludwig/models/model.py#L531-L566

This fragment

print("\nrunning evlaution")
eval = model.evaluate(test_X, test_y)

is added in this region https://github.com/uber/ludwig/blob/1fda07c555f3f3d3a9f989ab63441c8021a50391/ludwig/models/model.py#L857-L893

Next steps

  • Make the necessary modifications to the Ludwig code base to incorporate the above.
  • Demonstrate successful training and evaluation of this simple Ludwig model:
input_features = [
    {'name': 'x1', 'type': 'numerical'},
    {'name': 'x2', 'type': 'numerical'},
    {'name': 'x3', 'type': 'numerical'}
]
output_features = [
    {'name': 'y', 'type': 'numerical'}
]

model_definition = {
    'input_features': input_features,
    'output_features': output_features,
    'combiner': {
        'type': 'concat',  # 'concat',
        'num_fc_layers': 2,
        'fc_size': 64
    },
    'training': {'epochs': 20}
}

Blockers

None

@w4nderlust
Collaborator

A few comments on this.

  1. The starting point should be this instead: https://www.tensorflow.org/tutorials/quickstart/advanced The reason is that it matches much better with Ludwig's training loop: it uses functions for obtaining predictions, other functions to compute losses, and objects to define the optimization steps. Keras's compile function hides all of this, and those things are needed in Ludwig.
  2. The first option, the functional/eager approach, is what we want to lean towards, and the example that I sent you only covers that one.
  3. Looking at your case, you have 3 numerical features. In Ludwig they are treated separately, not as a single vector, and then the combiner concatenates them. The input part in your code maps to numerical_inputFeatures._get_placeholder(). It will be called 3 times, and the 3 outputs will be given as inputs to the combiner, which will contain the dense layers, so it will be inside combiners.ConcatCombiner.
  4. model.fit would not be there; if you look at the example I posted there's a more direct mapping with Ludwig's training loop.
  5. Also the evaluation part will not be a single call, but a bunch of them, still within the model.evaluate function.

Let me know if this is not clear.

@w4nderlust w4nderlust merged commit 681ece0 into ludwig-ai:tf2_porting Feb 27, 2020
Ludwig Development automation moved this from In progress to Done Feb 27, 2020
@w4nderlust
Collaborator

Oh, weird. I added a commit to remove the timeseries stuff from recurrent modules and pushed the usual way, but it got merged into the tf2_porting branch. I guess we can continue from there.

@w4nderlust w4nderlust moved this from Done to In progress in Ludwig Development Feb 27, 2020
@jimthompson5802
Collaborator Author

I understand this branch was accidentally merged. When you write, "I guess we can continue from there," do you mean that I should continue pushing changes to this branch, jimthompson5802:tf2_porting?

@w4nderlust
Collaborator

I think so. Anyway, it was merged into the tf2_porting branch, not into master, which is OK.

@jimthompson5802
Collaborator Author

I believe I now understand how the custom eager-mode training loop works. Using the advanced tutorial you pointed out, I worked out a native TF2 custom training example. This code will serve as a guide for future work. My next step is to begin modifying Ludwig for eager mode.

Note: while working on the sample program, it appeared that training was not converging. While researching this, I found this issue. I'm not clear whether it's relevant or not. Any thoughts?

This is the native TF2 sample program:


import pandas as pd

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model


raw_df = pd.read_csv('./data/train.csv')

raw_X = raw_df.loc[:,'x1':'x3']
raw_y = raw_df['y']

train_X, test_X, train_y, test_y = train_test_split(raw_X, raw_y, test_size=0.2)
print(train_X.shape, test_X.shape, train_y.shape, test_y.shape)

dataset_train = tf.data.Dataset.from_tensor_slices(
    (
        tf.cast(train_X.values, tf.float32),
        tf.cast(train_y.values, tf.float32)
    )
)

dataset_test = tf.data.Dataset.from_tensor_slices(
    (
        tf.cast(test_X.values, tf.float32),
        tf.cast(test_y.values, tf.float32)
    )
)


print("<<<<<< CUSTOM TRAINING LOOP>>>>>>>>>>")
class MyModel(Model):
    def __init__(self):
        super().__init__()

        self.d1 = Dense(64, activation='relu')
        self.d2 = Dense(64, activation='relu')
        self.prediction = Dense(1, activation='linear')
        self.model_layers = {}  # registry used by add_layer()

    def call(self, inputs):
        x = self.d1(inputs)
        x = self.d2(x)
        return self.prediction(x)

    def add_layer(self, name, layer):
        self.model_layers[name] = layer


model = MyModel()

loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_metric = tf.keras.metrics.MeanSquaredError(name='train_metric')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_metric = tf.keras.metrics.MeanSquaredError(name='test_metric')

tf.config.experimental_run_functions_eagerly(True)

@tf.function
def train_step(model, optimizer, loss_object, x, y):
    #  issue?: https://github.com/tensorflow/tensorflow/issues/28901
    #y = y[:, tf.newaxis]
    with tf.GradientTape() as tape:
        y_hat = model(x, training=True)
        # print("in training", y.shape, y_hat.shape)
        loss = loss_object(y, y_hat)
        # print("training lost:", loss.numpy())
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_metric(y, y_hat)


@tf.function
def test_step(model, loss_object, x, y):
    #y = y[:, tf.newaxis]
    y_hat = model(x, training=False)
    # print("in testing", y.shape, y_hat.shape)
    t_loss = loss_object(y, y_hat)

    test_loss(t_loss)
    test_metric(y, y_hat)


EPOCHS = 20

for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_metric.reset_states()
    test_loss.reset_states()
    test_metric.reset_states()

    for X, y in dataset_train.batch(16):
        #print(X.shape)
        train_step(model, optimizer, loss_object, X, y)

    for X, y in dataset_test.batch(16):
        test_step(model, loss_object, X, y)

    template = 'Epoch {}, train Loss: {}, : train metric {}, Test Loss: {}, Test Metric: {}'
    print(template.format(epoch+1,
                        train_loss.result(),
                        train_metric.result(),
                        test_loss.result(),
                        test_metric.result()))

model.summary()


print('<<<<<<<<<<<<<<< SIMPLE MODEL >>>>>>>>>>>>>>>')
# Compare results with simple model
inputs = Input(shape=(3,))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
prediction = Dense(1)(x)

model = Model(inputs, prediction)

model.compile(optimizer='adam',
              loss='MSE',
              metrics=['mse'])
print("\ntraining")
model.fit(dataset_train.batch(16), epochs=20)  # starts training
print("\nrunning evlaution")
eval = model.evaluate(test_X, test_y)
print(eval)

print(model.summary())


--------------------------------------------------------
2e19d158f757:python -u /opt/project/sandbox/tf2_port/sandbox_model_keras.py
2020-02-29 01:58:05.937796: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-02-29 01:58:05.937935: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-02-29 01:58:05.937956: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
(799, 3) (200, 3) (799,) (200,)
<<<<<< CUSTOM TRAINING LOOP>>>>>>>>>>
2020-02-29 01:58:06.569722: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-02-29 01:58:06.569745: E tensorflow/stream_executor/cuda/cuda_driver.cc:351] failed call to cuInit: UNKNOWN ERROR (303)
2020-02-29 01:58:06.569773: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (2e19d158f757): /proc/driver/nvidia/version does not exist
2020-02-29 01:58:06.570113: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-02-29 01:58:06.575861: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2791535000 Hz
2020-02-29 01:58:06.576537: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4599350 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-02-29 01:58:06.576557: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Epoch 1, train Loss: 0.12295467406511307, : train metric 0.12295467406511307, Test Loss: 0.09914495050907135, Test Metric: 0.09914495050907135
Epoch 2, train Loss: 0.09224193543195724, : train metric 0.09224193543195724, Test Loss: 0.0926184132695198, Test Metric: 0.0926184132695198
Epoch 3, train Loss: 0.08829674869775772, : train metric 0.08829674869775772, Test Loss: 0.09075924009084702, Test Metric: 0.09075924009084702
Epoch 4, train Loss: 0.08678507059812546, : train metric 0.08678507059812546, Test Loss: 0.09065888822078705, Test Metric: 0.09065888822078705
Epoch 5, train Loss: 0.0861264318227768, : train metric 0.0861264318227768, Test Loss: 0.0906728133559227, Test Metric: 0.0906728133559227
Epoch 6, train Loss: 0.08584615588188171, : train metric 0.08584615588188171, Test Loss: 0.09036322683095932, Test Metric: 0.09036322683095932
Epoch 7, train Loss: 0.085653156042099, : train metric 0.085653156042099, Test Loss: 0.09012754261493683, Test Metric: 0.09012754261493683
Epoch 8, train Loss: 0.08558208495378494, : train metric 0.08558208495378494, Test Loss: 0.09015573561191559, Test Metric: 0.09015573561191559
Epoch 9, train Loss: 0.0854942575097084, : train metric 0.0854942575097084, Test Loss: 0.09005601704120636, Test Metric: 0.09005601704120636
Epoch 10, train Loss: 0.08544553816318512, : train metric 0.08544553816318512, Test Loss: 0.09012728184461594, Test Metric: 0.09012728184461594
Epoch 11, train Loss: 0.08539620041847229, : train metric 0.08539620041847229, Test Loss: 0.09005728363990784, Test Metric: 0.09005728363990784
Epoch 12, train Loss: 0.08536931127309799, : train metric 0.08536931127309799, Test Loss: 0.0901104137301445, Test Metric: 0.0901104137301445
Epoch 13, train Loss: 0.08532935380935669, : train metric 0.08532935380935669, Test Loss: 0.09014303237199783, Test Metric: 0.09014303237199783
Epoch 14, train Loss: 0.0852506086230278, : train metric 0.0852506086230278, Test Loss: 0.09023386985063553, Test Metric: 0.09023386985063553
Epoch 15, train Loss: 0.08521439880132675, : train metric 0.08521439880132675, Test Loss: 0.09024890512228012, Test Metric: 0.09024890512228012
Epoch 16, train Loss: 0.08518398553133011, : train metric 0.08518398553133011, Test Loss: 0.09023744612932205, Test Metric: 0.09023744612932205
Epoch 17, train Loss: 0.08514454215765, : train metric 0.08514454215765, Test Loss: 0.09019779413938522, Test Metric: 0.09019779413938522
Epoch 18, train Loss: 0.0850943773984909, : train metric 0.0850943773984909, Test Loss: 0.09019617736339569, Test Metric: 0.09019617736339569
Epoch 19, train Loss: 0.08504881709814072, : train metric 0.08504881709814072, Test Loss: 0.09022630006074905, Test Metric: 0.09022630006074905
Epoch 20, train Loss: 0.08499737083911896, : train metric 0.08499737083911896, Test Loss: 0.0901600793004036, Test Metric: 0.0901600793004036
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  256       
_________________________________________________________________
dense_1 (Dense)              multiple                  4160      
_________________________________________________________________
dense_2 (Dense)              multiple                  65        
=================================================================
Total params: 4,481
Trainable params: 4,481
Non-trainable params: 0
_________________________________________________________________
<<<<<<<<<<<<<<< SIMPLE MODEL >>>>>>>>>>>>>>>

training
Train for 50 steps
Epoch 1/20
50/50 [==============================] - 0s 8ms/step - loss: 0.1314 - mse: 0.1315
Epoch 2/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0936 - mse: 0.0936
Epoch 3/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0892 - mse: 0.0892
Epoch 4/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0875 - mse: 0.0875
Epoch 5/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0866 - mse: 0.0866
Epoch 6/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0862 - mse: 0.0862
Epoch 7/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0859 - mse: 0.0859
Epoch 8/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0857 - mse: 0.0857
Epoch 9/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0856 - mse: 0.0856
Epoch 10/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0855 - mse: 0.0855
Epoch 11/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0854 - mse: 0.0854
Epoch 12/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0853 - mse: 0.0853
Epoch 13/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0853 - mse: 0.0853
Epoch 14/20
50/50 [==============================] - 0s 10ms/step - loss: 0.0852 - mse: 0.0852
Epoch 15/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0852 - mse: 0.0852
Epoch 16/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0851 - mse: 0.0852
Epoch 17/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0851 - mse: 0.0851
Epoch 18/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0851 - mse: 0.0851
Epoch 19/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0850 - mse: 0.0850
Epoch 20/20
50/50 [==============================] - 0s 8ms/step - loss: 0.0850 - mse: 0.0850

running evlaution
200/200 [==============================] - 0s 143us/sample - loss: 0.0902 - mse: 0.0902
[0.09019238114356995, 0.09019238]
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
dense_3 (Dense)              (None, 64)                256       
_________________________________________________________________
dense_4 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 65        
=================================================================
Total params: 4,481
Trainable params: 4,481
Non-trainable params: 0
_________________________________________________________________
None

Process finished with exit code 0

@w4nderlust
Collaborator

Following this comment, you may want to check the shapes of y and y_hat.

@jimthompson5802
Collaborator Author

I'm now looking at how to adapt Ludwig's data pipeline to support the design described in the previous post. From what I can tell, the current design converts training data primarily into a dictionary structure, which is fed into the current TF1 training loop.

For TF2 eager execution, I would have to convert the pre-processed training data into a tf.data.Dataset structure. I'm looking at doing the conversion in this section of code. By doing the conversion here, I leave the current Ludwig data preprocessing intact. This seems to be the place that allows for minimal changes. https://github.com/uber/ludwig/blob/4039f93e605ed929e4f4a75c725bb70e148df92b/ludwig/data/dataset.py#L21-L37

Does this sound like a reasonable approach?

Are there other factors I need to be aware of? For example, is the ludwig.data.dataset.Dataset class used for purposes beyond training, evaluation, and prediction?
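For concreteness, the kind of conversion I have in mind looks like this (a sketch that assumes the preprocessed data arrives as a dict of column name to ndarray):

import numpy as np
import tensorflow as tf

# Assumed layout of the preprocessed data: dict of column name -> ndarray
data = {
    'x1': np.random.rand(100).astype(np.float32),
    'x2': np.random.rand(100).astype(np.float32),
    'y':  np.random.rand(100).astype(np.float32),
}

# from_tensor_slices accepts the dict directly; each batch is then a dict of
# tensors keyed by column name, ready to dispatch to per-feature encoders
dataset = tf.data.Dataset.from_tensor_slices(data).batch(16)
for batch in dataset.take(1):
    print({name: t.shape for name, t in batch.items()})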

@w4nderlust
Collaborator

I'm not sure you actually need to do the conversion to a tf dataset. I believe the inputs to MyModel.call() in your example will gladly accept numpy ndarrays, which is what Ludwig's batcher provides. Actually, Ludwig's batcher provides a dictionary with the column names as keys and ndarrays as values, so internally, in the overall model call, the different columns' ndarrays need to be dispatched to the right subparts of the model (the different input encoders). The way to do it is by using the name of the column. If my model contains a dictionary of input_feature_name -> encoder function (let's call it self.input_feature_encoders), when a new input batch arrives you just use the name of the column to obtain the output of the encoder, and you do the same for the decoders, something like:

def __call__(self, inputs):
    encoder_outputs = []
    for input_feature_name, input_values in inputs.items():
        encoder_output = self.input_feature_encoders[input_feature_name](input_values)
        encoder_outputs.append(encoder_output)
    combiner_outputs = self.combiner(encoder_outputs)
    output_tensors = {}
    for output_feature_name, decoder in self.output_feature_decoders.items():
        decoder_output = decoder(combiner_outputs)
        output_tensors[output_feature_name] = decoder_output
    return output_tensors

Does this make sense?

@jimthompson5802
Collaborator Author

OK, as I understand the guidance, I should add these two dictionary attributes to the MyModel class:

  • input_feature_encoders
  • output_feature_decoders

The dictionary keys will be the input/output column names, i.e., 'x1', 'x2', 'y'. The dictionary values will be the functions that encode or decode the column's data type. Depending on the answer to the next question, I may have a follow-up question on the encoding functions.

re:

numpy ndarrays, which is what Ludwig's batcher provides

Please point me to the relevant code section for the batcher.

Have the values returned by the "Ludwig batcher" already been pre-processed? For example, if the column is a numerical data type, has it been "z-scored" (if that was requested) and have missing values been filled in? Or does the function encoder_output = self.input_feature_encoders[input_feature_name](input_values) have to take care of these operations? I'm trying to understand how functionality is partitioned in Ludwig.

@jimthompson5802
Collaborator Author

I see now how the Batcher function works, so there's no need to point out the code section.

I'm still curious to know if the values returned by

batch = batcher.next_batch()

have already been pre-processed?

@jimthompson5802
Collaborator Author

In thinking more about the recent guidance, I realize I may have misunderstood what you meant.

Initially, when I read the term "encoder" I thought about preparing the data for training. That is why I had those questions about pre-processing.

However, now I think that when you used the terms "encoder" and "decoder", this is what you had in mind. Is this correct?
encoder-decoder

@jimthompson5802
Collaborator Author

With some experimentation and use of the debugger, I believe I have the answer to my question, "Has the data been pre-processed?", at this point in the code:

batch = batcher.next_batch()

From what I can tell, the answer is "Yes." If this is true, then some of the changes I thought were needed are no longer relevant.

Again, thank you for the guidance.

@w4nderlust
Collaborator

w4nderlust commented Mar 1, 2020

I'm still curious to know if the values returned by the batcher have already been pre-processed?

Yes, data is already preprocessed when it is provided as a batch. There are other PRs for refactoring that preprocessing code to make it more generic, but they could be completely independent from this, I believe.

And yes, the image is an example of an autoencoder with an encoder and a decoder. In Ludwig you have a different encoder for each input feature, a different decoder for each output feature, and a combiner. Maybe the presentations here https://w4nderlu.st/projects/ludwig or Ludwig's whitepaper can help you get a better understanding of this notion.

Let me know if you have other questions.

@borisdayma
Contributor

borisdayma commented Mar 19, 2020

Commit 1dbfdfb disables running tests/integration_tests/test_contrib_wandb.py temporarily. Once the issue with this test and TF2 is resolved, the test should be re-enabled.

I had a similar issue due to the fact that wandb uses the context manager which conflicts with pytest. It was solved by using a fixture such as with wandb.init(…) as run: return run.

In my case wandb.init() was called outside of the callback which has some pros (more flexibility to the user) but cannot really be done here. A simpler solution would be to just use a mock.
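A minimal sketch of that fixture pattern (assuming wandb is importable in the test environment; yield keeps the context manager open for the duration of the test):

import pytest
import wandb

@pytest.fixture
def wandb_run():
    # Keep wandb's context manager open for the whole test; a mock of
    # wandb.init could be substituted to avoid network/filesystem side effects.
    with wandb.init() as run:
        yield run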

This may not be the problem here, since it works in the master branch.

In any case let me know if you need any help with this test.

@w4nderlust
Collaborator

w4nderlust commented Mar 19, 2020

@borisdayma Thank you for offering to help! This branch has many things that are not working yet, so this is just one of many, but when we get to the point where most things work, we'll start working on the contribs, and at that point your experience will likely be valuable. Please monitor the open PR (this one was merged by mistake) #646. I will mention you there when needed. Thank you again!

@w4nderlust w4nderlust moved this from In progress to Done in Ludwig Development Mar 27, 2020
@w4nderlust w4nderlust changed the title from "Tf2 porting" to "TF2 porting: initial work" Mar 27, 2020
@iamatsundere

I'm stuck with converting tf.contrib.seq2seq.TrainingHelper to tf2.2, please help me.

@w4nderlust
Collaborator

I'm stuck with converting tf.contrib.seq2seq.TrainingHelper to tf2.2, please help me.

use tensorflow_addons instead

@iamatsundere

Thanks @w4nderlust, I fixed it, but now I have a problem with:

training_logits, _, _ = tfa.seq2seq.decoder.dynamic_decode(decoder=training_decoder,
                                                                   maximum_iterations=max_target_length,
                                                                   swap_memory=True,
                                                                   scope="Training_Decoder",
                                                                   decoder_init_input=dec_embed_input,
                                                                   decoder_init_kwargs={
                                                                       'initial_state': initial_state
                                                                   })

It always returns TypeError: call() got an unexpected keyword argument 'training', although I followed the tutorial on Stack Overflow. Can you explain why I'm getting this issue?

@w4nderlust
Collaborator

w4nderlust commented Jul 3, 2020

It's very difficult for me to help out without further context, unfortunately. Plus, this is not really Ludwig related :) What I suggest you do is look at Ludwig's sequence decoders module and compare it with your implementation. Hopefully that helps.
