
TF2 porting: sequence feature #682

Merged
Merged 53 commits into ludwig-ai:tf2_porting from tf2_sequence_feature on Apr 20, 2020

Conversation

jimthompson5802
Collaborator


This is the start of porting the sequence feature to TF2. There is still more to be done. Current state of the code:

  • Placed TF2-related stub methods in sequence_feature.py. These will be filled in as work progresses.
  • Created the SequencePassthroughEncoder and SequenceEmbedEncoder classes as subclasses of tf.keras.layers.Layer and integrated the existing functions to work with the Layer subclasses.

From a small test both encoders work. The test passed an input array of shape [2, 5] through the SequencePassthroughEncoder and SequenceEmbedEncoder. Results of the test:

  • Passthrough encoder with reduce_output=None resulted in an output tensor with shape [2, 5, 1]
  • Embed encoder with embedding_size=3, reduce_output='sum' resulted in an output tensor with shape [2,3]
  • Embed encoder with embedding_size=3, reduce_output='None' resulted in an output tensor with shape [2, 5, 3]

Does this look correct?

My next step is to get a working decoder and to have full_experiment() complete successfully for a simple model.

This is the simple test case and its output.

import numpy as np
import tensorflow as tf

from ludwig.models.modules.sequence_encoders import SequencePassthroughEncoder, \
    SequenceEmbedEncoder

# setup input
input = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

# test Passthrough encoder
print("Sample SequencePassthroughEncoder")
spt = SequencePassthroughEncoder()
output = spt(input)
print('input shape:', input.shape,'\nvalues:\n', input)
print('\noutput shape:\n', output.shape,'\nvalues\n', output.numpy())


# test embed encoder
print("\n\nSample SequenceEmbedEncoder")
vocab = list('abcdefghij')
print('vocab:', len(vocab), vocab)
emb = SequenceEmbedEncoder(vocab, embedding_size=3)
output2 = emb(input)
print('input shape:', input.shape,'\nvalues:\n', input)
print('\noutput shape:\n', output2.shape,'\nvalues\n', output2.numpy())


# test embed encoder with reduce_output=None
print("\n\nSample SequenceEmbedEncoder, reduce_output=None")
vocab = list('abcdefghij')
print('vocab:', len(vocab), vocab)
emb2 = SequenceEmbedEncoder(vocab, embedding_size=3, reduce_output='None')
output3 = emb2(input)
print('input shape:', input.shape,'\nvalues:\n', input)
print('\noutput shape:\n', output3.shape,'\nvalues\n', output3.numpy())

Test output

2516c6abd7cc:python -u /opt/project/sandbox/tf2_port/sequence_encoder_tester.py
2020-04-10 03:41:27.645738: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-04-10 03:41:27.645874: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-04-10 03:41:27.645895: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Sample SequencePassthroughEncoder
2020-04-10 03:41:28.838864: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-04-10 03:41:28.838974: E tensorflow/stream_executor/cuda/cuda_driver.cc:351] failed call to cuInit: UNKNOWN ERROR (303)
2020-04-10 03:41:28.839016: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (2516c6abd7cc): /proc/driver/nvidia/version does not exist
2020-04-10 03:41:28.839430: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-04-10 03:41:28.845809: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2791435000 Hz
2020-04-10 03:41:28.846760: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x439f600 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-04-10 03:41:28.846828: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
input shape: (2, 5) 
values:
 [[0 1 2 3 4]
 [5 6 7 8 9]]

output shape:
 (2, 5, 1) 
values
 [[[0.]
  [1.]
  [2.]
  [3.]
  [4.]]

 [[5.]
  [6.]
  [7.]
  [8.]
  [9.]]]


Sample SequenceEmbedEncoder
vocab: 10 ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
input shape: (2, 5) 
values:
 [[0 1 2 3 4]
 [5 6 7 8 9]]

output shape:
 (2, 3) 
values
 [[ 0.43313265  2.5553155   0.4227686 ]
 [-2.2676423   0.80896974 -0.28734756]]


Sample SequenceEmbedEncoder, reduce_output=None
vocab: 10 ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
input shape: (2, 5) 
values:
 [[0 1 2 3 4]
 [5 6 7 8 9]]

output shape:
 (2, 5, 3) 
values
 [[[ 0.          0.         -0.        ]
  [-0.3263557  -0.9351046   0.81713986]
  [ 0.10236788 -0.8854079   0.736264  ]
  [ 0.6721873  -0.6389749  -0.22343826]
  [ 0.9717505   0.37574863 -0.77134395]]

 [[-0.29198527  0.8138087  -0.12796235]
  [-0.7267971   0.70457935  0.5321882 ]
  [ 0.8782997   0.01055479  0.969434  ]
  [ 0.89878535  0.36403346  0.67668414]
  [ 0.55935645 -0.59581757  0.71213484]]]

Process finished with exit code 0

@w4nderlust
Collaborator

This looks great, seems like a great starting point. All the shapes are correct

@jimthompson5802
Collaborator Author

While working on the SequenceOutputFeature class, I just noticed a difference in the __init__() method signatures between the <Type>InputFeature and <Type>OutputFeature classes we have worked on up to now. This difference goes back to the original numerical feature conversion to TF2.

The method signature for input features is <Type>InputFeature.__init__(self, feature, encoder_obj=None).

The method signature for output features is <Type>OutputFeature.__init__(self, feature).

Output features are missing the option to specify the decoder_obj at instantiation time.

I just want to confirm that this is the intended design.

@w4nderlust
Collaborator

Yes, it is intended. The reason is that in TF1 there was a mechanism to specify the scope of the tf variables that would be created, and users could specify those variables to be tied to variables initialized for another input. This is useful if, for instance, I have two pieces of text and I want to verify whether they match or not: I want the two text encoders to have the same structure and weights. Now in TF2, as the encoder is an object, it's much easier to just pass the encoder object and that's it.
Output features don't need this mechanism; I can't find a reason why I would want the same decoder and thus produce the same predictions.

@jimthompson5802
Collaborator Author

Status update

For the past couple of days, I've spent my time reviewing documentation and prototyping code related to sequence models:

  • Reviewed the short tutorial you provided on sequence models
  • Reviewed the TensorFlow Addons (TFA) documentation and sample code. I specifically focused on this sample code because it involved training with a custom model and layers. This allowed me to understand how to partition the relevant API calls between the __init__() and call() methods and how to use TFA's APIs.
  • Reviewed code in the NLP example located here.
  • Finally, wrote prototype code that converted the Automatic Translation example in the preceding notebook (starting at code cell In [65]:) into a simplified Ludwig custom model and layer construct using TFA APIs.

I'm planning to apply what I learned in this exercise to porting Ludwig's sequence decoders to use TFA. If you look at the code, which is provided at the end, you may notice that some of the methods' return signatures differ from what we do in Ludwig. Right now I think I can still adhere to the current return signatures by using dictionary data structures. At the moment I'm expecting no changes to Ludwig's internal processing; I'm planning to have any changes occur in the specific output feature or decoder class methods.

Let me know if I need to be aware of anything.

Sequence Encoder/Decoder Prototype

Notes:

  • MyModel class is the equivalent of Ludwig's ECD class
  • InputFeatureEncoder class represents Ludwig's input feature encoder
  • OutputFeatureDecoder class represents Ludwig's output feature decoder
from functools import reduce

import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Model
import tensorflow_addons as tfa


tf.random.set_seed(42)
np.random.seed(42)

VOCAB_SIZE = 100
EMBED_SIZE = 10
RNN_UNITS = 512
INPUT_SEQUENCE_SIZE = 10
OUTPUT_SEQUENCE_SIZE = 15

# setup random training data
X = np.random.randint(100, size=INPUT_SEQUENCE_SIZE*1000).reshape(1000, INPUT_SEQUENCE_SIZE)
Y = np.random.randint(100, size=OUTPUT_SEQUENCE_SIZE*1000).reshape(1000, OUTPUT_SEQUENCE_SIZE)


# quick and dirty generator to return random batches from training data
def batcher(total_size, batch_size):
    i = 0
    idx = np.random.permutation(range(total_size))
    for i in range(0, total_size, batch_size):
        yield list(idx[i:(i + batch_size)])


# Custom layer encoder for input sequence feature
class InputFeatureEncoder(keras.layers.Layer):
    def __init__(self, vocab_size, embed_size, rnn_units):
        super(InputFeatureEncoder, self).__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.rnn_units = rnn_units

        self.embeddings_enc = keras.layers.Embedding(self.vocab_size, self.embed_size)
        self.encoder = keras.layers.LSTM(self.rnn_units, return_state=True)

    def call(self, inputs, training=True, mask=None):
        encoder_embeddings = self.embeddings_enc(inputs)

        encoder_outputs, state_h, state_c = self.encoder(encoder_embeddings)
        encoder_state = [state_h, state_c]

        return encoder_outputs, encoder_state

# Custom layer decoder for output sequence feature
class OutputFeatureDecoder(keras.layers.Layer):
    def __init__(self, vocab_size, embed_size, rnn_units):
        super(OutputFeatureDecoder, self).__init__()

        self.embeddings_dec = keras.layers.Embedding(vocab_size, embed_size)
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        self.decoder_cell = keras.layers.LSTMCell(rnn_units)
        self.output_layer = keras.layers.Dense(vocab_size)
        self.decoder = \
            tfa.seq2seq.basic_decoder.BasicDecoder(self.decoder_cell,
                                                    self.sampler,
                                                    output_layer=self.output_layer)

    def build_sequence_lengths(self, batch_size):
        return np.ones((batch_size,)).astype(np.int32) * OUTPUT_SEQUENCE_SIZE

    def build_initial_state(self, batch_size, encoder_state):
        # currently just return the encoder_state for now
        # place holder if we have to do something more complicated
        return encoder_state


    def call(self, decoder_inputs, training=True, mask=None, encoder_end_state=None):
        #print(">>>debug", decoder_inputs.shape)
        decoder_embeddings = self.embeddings_dec(decoder_inputs)

        sequence_lengths = self.build_sequence_lengths(decoder_inputs.shape[0])

        initial_state = self.build_initial_state(decoder_inputs.shape[0],
                                                 encoder_end_state)

        final_outputs, final_state, final_sequence_lengths = self.decoder(
            decoder_embeddings, initial_state=initial_state,
            sequence_length=sequence_lengths)

        return final_outputs.rnn_output, final_outputs, final_state, final_sequence_lengths


tf.config.experimental_run_functions_eagerly(True)

# Custom model equivalent to Ludwig's ECD class
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()

        self.X_encoder = InputFeatureEncoder(VOCAB_SIZE, EMBED_SIZE, RNN_UNITS)

        self.Y_decoder = OutputFeatureDecoder(VOCAB_SIZE, EMBED_SIZE, RNN_UNITS)

        self.optimizer = tf.keras.optimizers.Adam()

        self.train_loss_metric = tf.keras.metrics.Mean(name='train_loss')

    def loss_function(self, y, y_pred):
        #shape of y [batch_size, output_sequence]
        #shape of y_pred [batch_size, output_sequence, output_vocab_size]
        sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                                                      reduction='none')
        loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)
        mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1
        mask = tf.cast(mask, dtype=loss.dtype)
        loss = mask* loss
        loss = tf.reduce_mean(loss)
        return loss


    def call(self, inputs, training=True, initial_state=None):

        encoder_outputs, encoder_state = self.X_encoder(inputs)

        (
            logits,
            final_outputs,
            final_state,
            final_sequence_lengths
        ) = self.Y_decoder(
            encoder_outputs, # output from the encoder
            encoder_end_state=encoder_state  # end state from the encoder
        )

        return logits, final_outputs, final_state, final_sequence_lengths

    def reset_metrics(self):
        self.train_loss_metric.reset_states()

    # @tf.function
    def train_step(self,  inputs, targets):
        # #  issue?: https://github.com/tensorflow/tensorflow/issues/28901
        # y = y[:, tf.newaxis]
        with tf.GradientTape() as tape:
            (
                logits,
                final_output,
                final_state,
                final_sequence_length
            ) = self(inputs, training=True)
            # print("in training", y.shape, y_hat.shape)
            loss = self.loss_function(targets, logits)
            print("\tbatch training loss:", loss.numpy())
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.train_loss_metric.update_state(loss)


if __name__ == '__main__':
    model = MyModel()

    print("<<<<<< CUSTOM TRAINING LOOP>>>>>>>>>>")

    EPOCHS = 10

    for epoch in range(EPOCHS):
        # Reset the metrics at the start of the next epoch
        model.reset_metrics()
        print("starting epoch", epoch+1)
        for batch in batcher(1000, 128):
            model.train_step(X[batch], Y[batch])

        print("epoch loss metric",model.train_loss_metric.result().numpy(),'\n')

Prototype output

b03656f92c6d:python -u /opt/project/sandbox/tf2_port/sandbox_custom_sequence_training.py
2020-04-12 18:55:43.702378: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-04-12 18:55:43.702584: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-04-12 18:55:43.702697: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2020-04-12 18:55:44.745792: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-04-12 18:55:44.745842: E tensorflow/stream_executor/cuda/cuda_driver.cc:351] failed call to cuInit: UNKNOWN ERROR (303)
2020-04-12 18:55:44.745866: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (b03656f92c6d): /proc/driver/nvidia/version does not exist
2020-04-12 18:55:44.746070: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-04-12 18:55:44.753713: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2791455000 Hz
2020-04-12 18:55:44.754867: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4e91b80 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-04-12 18:55:44.754912: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
<<<<<< CUSTOM TRAINING LOOP>>>>>>>>>>
starting epoch 1
	batch training loss: 4.554785
	batch training loss: 4.576291
	batch training loss: 4.5639863
	batch training loss: 4.5733886
	batch training loss: 4.5564466
	batch training loss: 4.5486298
	batch training loss: 4.577191
	batch training loss: 4.522045
epoch loss metric 4.5590954 

starting epoch 2
	batch training loss: 4.5548887
	batch training loss: 4.555095
	batch training loss: 4.538684
	batch training loss: 4.5719156
	batch training loss: 4.5566115
	batch training loss: 4.5438
	batch training loss: 4.547464
	batch training loss: 4.5730786
epoch loss metric 4.5551925 

starting epoch 3
	batch training loss: 4.542449
	batch training loss: 4.5552673
	batch training loss: 4.544885
	batch training loss: 4.5558276
	batch training loss: 4.549805
	batch training loss: 4.528891
	batch training loss: 4.5619445
	batch training loss: 4.57772
epoch loss metric 4.5520988 

starting epoch 4
	batch training loss: 4.543325
	batch training loss: 4.559398
	batch training loss: 4.537799
	batch training loss: 4.557769
	batch training loss: 4.548679
	batch training loss: 4.553115
	batch training loss: 4.554301
	batch training loss: 4.540425
epoch loss metric 4.5493517 

starting epoch 5
	batch training loss: 4.554665
	batch training loss: 4.541597
	batch training loss: 4.546907
	batch training loss: 4.5376945
	batch training loss: 4.5526795
	batch training loss: 4.5203147
	batch training loss: 4.5624995
	batch training loss: 4.5693374
epoch loss metric 4.548212 

starting epoch 6
	batch training loss: 4.5279803
	batch training loss: 4.5553308
	batch training loss: 4.5531473
	batch training loss: 4.522766
	batch training loss: 4.568819
	batch training loss: 4.542282
	batch training loss: 4.5514097
	batch training loss: 4.545376
epoch loss metric 4.545889 

starting epoch 7
	batch training loss: 4.5444236
	batch training loss: 4.5564017
	batch training loss: 4.5565968
	batch training loss: 4.5409017
	batch training loss: 4.541494
	batch training loss: 4.548808
	batch training loss: 4.5122046
	batch training loss: 4.561908
epoch loss metric 4.5453424 

starting epoch 8
	batch training loss: 4.543294
	batch training loss: 4.5538425
	batch training loss: 4.5422764
	batch training loss: 4.5456266
	batch training loss: 4.54426
	batch training loss: 4.5423083
	batch training loss: 4.531722
	batch training loss: 4.5423927
epoch loss metric 4.5432153 

starting epoch 9
	batch training loss: 4.540458
	batch training loss: 4.5446005
	batch training loss: 4.5407934
	batch training loss: 4.534824
	batch training loss: 4.5373025
	batch training loss: 4.553157
	batch training loss: 4.533079
	batch training loss: 4.553448
epoch loss metric 4.5422077 

starting epoch 10
	batch training loss: 4.5383096
	batch training loss: 4.543357
	batch training loss: 4.5497494
	batch training loss: 4.5236692
	batch training loss: 4.5414824
	batch training loss: 4.5261436
	batch training loss: 4.539192
	batch training loss: 4.553021
epoch loss metric 4.539366 


Process finished with exit code 0

@w4nderlust
Collaborator

This all looks great! It can be a blueprint to use as a reference for how to implement things in Ludwig. In Ludwig there are more options obviously (attention and beam search, to name two important ones), but this is a really good starting point.
Let me comment on some points.

encoder_outputs, state_h, state_c = self.encoder(encoder_embeddings)
encoder_state = [state_h, state_c]

return encoder_outputs, encoder_state

Here things may be different, as we expect encoders right now to return just one tensor. We can extend what we did for output feature prediction methods and return a dictionary with keys instead, so that components down the road can use whatever they want from that dictionary. Actually, the previous version of Ludwig was already returning dictionaries, so most components should already work fine if we decide to go this route.

self.embeddings_dec = keras.layers.Embedding(vocab_size, embed_size)
self.output_layer = keras.layers.Dense(vocab_size)

In Ludwig there's already a mechanism to allow reusing the embeddings in the encoder, so that output_layer.weights == embeddings_dec.T. This is not urgent nor extremely important at the moment, but let's keep it in mind.
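For reference, a minimal sketch of that kind of weight tying, assuming a plain Keras Embedding layer rather than Ludwig's actual mechanism:

import tensorflow as tf

class TiedOutputLayer(tf.keras.layers.Layer):
    # projects hidden states to vocab logits by reusing the embedding matrix
    def __init__(self, embedding_layer):
        super().__init__()
        self.embedding_layer = embedding_layer  # an already-built keras.layers.Embedding

    def call(self, hidden):
        # hidden: [batch, seq, embed_size] -> logits: [batch, seq, vocab_size]
        embedding_matrix = self.embedding_layer.embeddings  # [vocab_size, embed_size]
        return tf.matmul(hidden, embedding_matrix, transpose_b=True)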

def build_initial_state(self, batch_size, encoder_state):
    # currently just return the encoder_state for now
    # place holder if we have to do something more complicated
    return encoder_state

As in Ludwig the decoder is agnostic of the encoder, it may be the case that there is no encoder state or that the encoder state is not of the size required. In the first case, one could either initialize the state to be all zeros, or initialize it to be a vector of weights that are learned. The literature suggests the latter works better. We don't have to implement it from the get-go, but again, something to keep in mind.
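A minimal sketch of that fallback logic, assuming the zero-state option discussed here (a learned initial-state variable could be swapped in later):

import tensorflow as tf

def build_initial_state(decoder_cell, batch_size, encoder_state=None):
    # reuse [state_h, state_c] from the encoder when it exists and matches
    if encoder_state is not None:
        return encoder_state
    # otherwise fall back to an all-zeros state sized by the decoder cell
    return decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)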

Another consideration: your example is an example of what in Ludwig is called the generator decoder. There's also the tagger decoder (which is simpler), which has a different structure, makes different assumptions about the inputs and provides outputs in a different way. Let me know if you want me to explain how that works, but again, it's much easier than the generator decoder, and once we make the generator work, the tagger will be much simpler.

Let me know if you have any doubt and I'll be happy to clarify.

@jimthompson5802
Collaborator Author

As in Ludwig the decoder is agnostic of the encoder, it may be the case that there is no encoder state or that the encoder state is not of the size required. In the first case, one could either initialize the state to be all zeros, or one could initialize it to be a vector of weights that are learned

Glad you pointed this out. I was looking at the decoder code and noticed there did not appear to be a link between encoder and decoder states. This saved me from asking a follow-up question. For now I'll plan to use all-zero initial states; it's the easiest to implement.

Here things may be different, as we expect encoders right now to return just one tensor. We can extend what we did for output features prediction methods, and return a dictionary with keys instead, so that components down the road can use whatever they want from that dictionary.

I was thinking of using a dictionary as you described. So I'll proceed with this direction.

There's also the tagger decoder (which is simpler) that has a different structure, makes different assumptions on the inputs and provides outputs in a different way. Let me know if you want me to explain how that works, but again, it's much easier than the generator decoder,

If the tagger is simpler, I'll start with that. If I have questions, I'll reach back out.

In a reply from several weeks ago, you mentioned starting to port code to use tfa apis. How comfortable are you with that code?

@w4nderlust
Collaborator

I was thinking of using a dictionary as you described. So I'll proceed with this direction.

I would assume that every encoder returns a dictionary with at least an encoder_output key, so that all components down the line can assume it is provided, but, if needed, we can extend it to provide both reduced and non-reduced outputs (which will be really useful when adding attention) and in some cases also provide references to other properties of the encoder that could be useful for downstream components to make some choices.

If the tagger is simpler, I'll start with that. If I have questions, I'll reach back out.

Yes; basically it assumes that the input is a sequence, so the shape is [batch, sequence_length, hidden], and it produces a classification for each element of the sequence, so the output logits have shape [batch, sequence_length, num_classes] (num_classes is the vocabulary size of the sequence feature).
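As a quick sanity check of those shapes, a Dense layer applied to a 3-D tensor projects only the last axis, which is essentially what the tagger needs (illustrative numbers only):

import tensorflow as tf

hidden = tf.random.normal([4, 7, 16])   # [batch, sequence_length, hidden]
projection = tf.keras.layers.Dense(10)  # num_classes = 10
logits = projection(hidden)
print(logits.shape)                     # (4, 7, 10) = [batch, sequence_length, num_classes]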

In a reply from several weeks ago, you mentioned starting to port code to use tfa apis. How comfortable are you with that code?

I don't remember exactly the context. In general I'm not familiar with TFA; I was "familiar" with the sequence decoding functions in tf.contrib, which are the ones that have been ported to TFA, but 1. from our previous attempt at a sequence decoder, it looked like they changed some stuff, and 2. those classes and methods were really poorly documented to begin with; many things I discovered, in particular regarding the beam search, I did by looking at the source.

A final consideration: it could be the case that https://www.tensorflow.org/tutorials/text/text_generation is a valuable example; it does not use anything from TFA. The problem is that it will not have the beam search capabilities etc., but it's good to keep as a reference.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Apr 13, 2020

re:

I would assume that every encoder returns a dictionary with at least an encoder_output key, so that all components down the line can assume it is provided, but, if needed, we can extend it to provide both reduced and non-reduced outputs

Just to confirm my understanding: https://github.com/uber/ludwig/blob/212122c87930277ff6ef055dac11f9e694de53d8/ludwig/models/ecd.py#L53-L58
Currently, the encoder_output returned by encoder() is a Tensor.

With this change, the return from encoder() is structured as follows:

encoder_output = {
    'encoder_output': Tensor
}

Would you like me to take a slight pause on the sequence work and submit a new PR implementing the 'encoder returns a dictionary' pattern for the current Numerical, Binary and Categorical encoders? This way, if you have other folks working on TF2 porting, this should establish a baseline for that work.

Or should I just include these changes in this PR on the sequence feature?

@w4nderlust
Collaborator

I think it's more flexible, so yes, I would say it would be a good pattern to follow. The concat combiner would need to be changed accordingly too (and all other combiners, but that will be done later).
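A minimal sketch of how a combiner might consume the new dictionary pattern (illustrative only, not the actual ConcatCombiner code):

import tensorflow as tf

def concat_combine(encoder_outputs):
    # encoder_outputs: {feature_name: {'encoder_output': Tensor, ...}}
    tensors = [out['encoder_output'] for out in encoder_outputs.values()]
    return {'combiner_output': tf.concat(tensors, axis=-1)}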

@jimthompson5802
Collaborator Author

This commit 2ebf5a5 implements the use of the encoder_output key, e.g.,

encoder_output = {
    'encoder_output': Tensor
}

The Numerical, Binary and Categorical encoders and the ConcatCombiner class were updated to reflect the change.

I used the recently added unit test tests/integration_tests/test_simple_features.py to validate the change:

root@c3b25e4bcbc4:/opt/project# pytest tests/integration_tests/test_simple_features.py
=========================================================== test session starts ============================================================
platform linux -- Python 3.6.9, pytest-5.4.1, py-1.8.1, pluggy-0.13.1
rootdir: /opt/project
plugins: typeguard-2.7.1
collected 10 items

tests/integration_tests/test_simple_features.py ..........                                                                           [100%]

============================================================= warnings summary =============================================================
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/pywrap_tensorflow_internal.py:15
  /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/pywrap_tensorflow_internal.py:15: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
    import imp

tests/integration_tests/test_simple_features.py::test_feature[input_test_feature3-output_test_feature3-None]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature4-output_test_feature4-None]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature5-output_test_feature5-output_loss_parameter5]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature6-output_test_feature6-output_loss_parameter6]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature7-output_test_feature7-output_loss_parameter7]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature8-output_test_feature8-output_loss_parameter8]
tests/integration_tests/test_simple_features.py::test_feature[input_test_feature9-output_test_feature9-output_loss_parameter9]
  /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    _warn_prf(average, modifier, msg_start, len(result))

-- Docs: https://docs.pytest.org/en/latest/warnings.html
===================================================== 10 passed, 8 warnings in 11.05s ======================================================
root@c3b25e4bcbc4:/opt/project#

@w4nderlust
Collaborator

Looks good. Can I ask you to actually create another PR just for this commit? And also add the same thing for image features? This way other people working on the branch will have this updated pattern.

@jimthompson5802
Collaborator Author

jimthompson5802 commented Apr 19, 2020

Maybe one question: is "accuracy" in the table below "row-wise accuracy"?

╒══════════╤════════╤════════════╤══════════════════╤═════════════════╤══════════════╤═════════════════╕
│ y        │   loss │   accuracy │   token_accuracy │   last_accuracy │   perplexity │   edit_distance │ 
╞══════════╪════════╪════════════╪══════════════════╪═════════════════╪══════════════╪═════════════════╡
│ training │ 2.2923 │     0.0000 │           0.1374 │          0.1756 │       9.8978 │          0.8465 │
├──────────┼────────┼────────────┼──────────────────┼─────────────────┼──────────────┼─────────────────┤
│ vali     │ 2.2714 │     0.0000 │           0.1429 │          0.2247 │       9.6934 │          0.8411 │
├──────────┼────────┼────────────┼──────────────────┼─────────────────┼──────────────┼─────────────────┤
│ test     │ 2.2819 │     0.0000 │           0.1479 │          0.2111 │       9.7951 │          0.8392 │
╘══════════╧════════╧════════════╧══════════════════╧═════════════════╧══════════════╧═════════════════╛

@w4nderlust
Collaborator

Yes.

This requires another little bit of explanation (I guess eventually we can collect all those explanations and put them in the developer's guide :) ).
If you look at the results, there is always an additional special output feature called combined. When there's only one output feature it is not printed. It exists because single output features may have their own losses and metrics, but if you train on more than one at the same time, you want an understanding of whether you are making progress overall.
The combined loss reported is the weighted sum of each independent loss plus the regularization and any other auxiliary loss, and it is actually the value the optimizer receives as input.
But another value is also reported: the combined accuracy. This is a bit of a problem, and a thing I want to change, as, for instance, if I'm optimizing for a numerical and a category output feature at the same time, the combined accuracy doesn't make much sense because the numerical feature does not return any accuracy... so that's something I have to fix. But assuming all output features have an accuracy metric, the combined metric is there to give an indication of the portion of datapoints where all my output features are returning the correct predictions at the same time. So assuming I have two vectors that contain 1 or 0 depending on whether the predictions are correct, obtained for instance by y == y_pred for a batch, with a batch size of, say, 4, the combined accuracy is calculated this way:
output_feature_1_correct_prediction = y_of_1 == y_pred_of_1 = [1, 1, 1, 0]
output_feature_2_correct_prediction = y_of_2 == y_pred_of_2 = [1, 0, 0, 1]
combined_accuracy = sum( [1, 1, 1, 0] and [1, 0, 0, 1]) / len = sum([1, 0, 0, 0]) / 4 = 0.25
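A tiny numeric sketch of that calculation (illustration only):

import numpy as np

of_1_correct = np.array([1, 1, 1, 0])   # y_of_1 == y_pred_of_1
of_2_correct = np.array([1, 0, 0, 1])   # y_of_2 == y_pred_of_2
all_correct = np.logical_and(of_1_correct, of_2_correct)   # [1, 0, 0, 0]
combined_accuracy = all_correct.sum() / len(all_correct)   # 0.25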

Now, the way I previously implemented this mechanism was to collect these vectors of 1s and 0s for each batch, concatenate them for the length of the whole set, perform the and operation on these vectors from all the output features, and finally sum and divide. That's why in the sequence feature you see those CORRECT_*_PREDICTIONS alongside the accuracy values: the accuracy values are averaged per batch, while those correct-prediction values are not.
Now, this introduces quite some complexity, so for the time being let's ditch it and just keep the combined loss, but in the future we may want to find a clean way to reintroduce this kind of functionality, as it is very useful in particular for validation logic purposes (like deciding when to reduce the learning rate or when to do early stopping).

Finally, the reason why the name accuracy in the sequence feature was assigned to rowwise_accuracy is that the combined accuracy needed a 1 / 0 kind of value to perform the logical and (even if it was implemented as a product) to be meaningful, and the row-wise accuracy gives you that. Plus, it is the most restrictive of the three accuracy metrics, the hardest one to obtain high scores on, so it's also the most conservative option.

@jimthompson5802
Collaborator Author

This is what I understand: "accuracy" in the table is the concept of "combined accuracy". For the moment, this is out of scope.

Of the remaining values in the table, we have "loss" (or combined loss). I pushed commits for "last_accuracy" and "perplexity".

I'm now going to focus on "token accuracy" and "edit distance".

I was looking at the masked_accuracy() function. I'm using this as the model for implementing 'token accuracy' and 'row-wise accuracy'. From what I can tell, the "mask" in this context is there to eliminate from consideration the <PAD> characters that come into play with variable-length sequences.
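For concreteness, a sketch of what such a masked token accuracy could look like, assuming the <PAD> token id is 0 (not the actual masked_accuracy() code):

import tensorflow as tf

def masked_token_accuracy(y_true, y_pred_tokens, pad_id=0):
    # y_true, y_pred_tokens: int tensors of shape [batch, sequence_length]
    correct = tf.cast(tf.equal(y_true, y_pred_tokens), tf.float32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)  # drop <PAD> positions
    return tf.reduce_sum(correct * mask) / tf.reduce_sum(mask)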

Right now the Ludwig data generator for sequences always returns a sequence of max_len. So all my testing right now involves sequences of fixed length.

In looking at tests/utils.py and ludwig/data/dataset_synthesyzer.py, it appears that I can generate variable-length sequences if I add the 'min_len' key to the dictionary returned by the sequence_feature() function in tests/integration_test/utils.py.

Am I correct in my understanding of how I could generate variable-length sequences?

If this is true, then I can be more robust in my testing and we can use this to enhance the unit tests for Ludwig.

feat: add determination for tf dtype on predictions
refactor: api fix up and code clean-up
@jimthompson5802
Collaborator Author

This commit de7b133 contained these changes:

  • added class EditDistanceMetric to metric_modules.py. In making this addition, I used the original TF1 edit_distance() function with some slight modifications.
  • incorporated most of your last set of comments; there was one on which I had a follow-up question.
  • regarding this observation in TF2 porting: sequence feature #682 (comment)

this may be int64 in some cases (when the vocabulary is bigger than a certain amount).

This is because when there's a small vocabulary the ints may be int8 or int16, something I did to save both disk space and memory.

I added some code that could be used as a basis for handling this kind of situation in the SequenceOutputFeature.__init__() method:

        # determine required tf dtype to represent sequence encoded values
        max_base2_exponent = np.int(np.ceil(np.log2(self.num_classes - 1)))
        if max_base2_exponent <= 32:
            self._prediction_dtype = tf.int32
        else:
            self._prediction_dtype = tf.int64

We need to use either tf.int32 or tf.int64 because other int types, e.g., tf.int8 or tf.int16, cause an error in the tf.argmax() function.

then in the predictions() method

        predictions = tf.argmax(
            logits,
            -1,
            name='predictions_{}'.format(self.name),
            output_type=self._prediction_dtype
        )

After completing this work, I thought more about it. Do you think num_classes will ever get larger than 2^32? 2^32 is over 4 billion unique values. It might be reasonable to just specify tf.int32.

Next drop should include the remaining accuracy metrics.

@w4nderlust
Collaborator

This is what I understand: "accuracy" in the table is the concept of "combined accuracy". For the moment, this is out of scope.

Accuracy in the table for a sequence feature is row-wise accuracy.
Accuracy in the table for the combined feature is what I described above.

I'm now going to focus on "token accuracy" and "edit distance".

Let's keep in mind we want to reintroduce them later on.

I was looking at the masked_accuracy() function. I'm using this as the model for implementing 'token accuracy' and 'row-wise accuracy'. From what I can tell, the "mask" in this context is there to eliminate from consideration the <PAD> characters that come into play with variable-length sequences.

Correct.

Right now the Ludwig data generator for sequences always returns a sequence of max_len. So all my testing right now involves sequences of fixed length.

I believe this is incorrect. The logic of the contrib Decoder I was using in TF1 is that it keeps on generating until an EOS (end of sequence) symbol is generated or max_len is reached.
As sequences are generated in batches, the whole process keeps on generating as long as at least one element of the batch hasn't generated EOS or reached max_len. The generated tokens for the elements of the batch that have already reached EOS are 0 (PAD).
For instance:
batch = 3
max_len = 6
one possible generation when max_len is reached by one of the elements of the batch is:

[[GO, 3, 5, EOS, 0, 0],
 [GO, 2, 3, 3, 5, 6],
 [GO, 4, EOS, 0, 0, 0]]

If all elements of the batch reach EOS before max_len the second dimension of the tensor is smaller:

[[GO, 3, 5, EOS],
 [GO, 2, EOS, 0],
 [GO, 4, EOS, 0]]

In the code you'll find various pieces where I do manipulation and fill zeros to make sure the length of true sequences and predicted sequences is the same, and then I do masking on the loss to avoid considering PAD / 0s. This happens in seq2seq_sequence_loss and sequence_sampled_softmax_cross_entropy for example.

In looking at tests/utils.py and ludwig/data/dataset_synthesyzer.py, it appears that I can generate variable length sequences if I add the 'min_len key to the dictionary that is returned by sequence_feature() function in tests/integration_test/utils.py.
Am I correct in my understanding in how I could generate variable length sequences?

Yes.

If this is true, then I can be more robust in my testing and we can use this to enhance the unit tests for Ludwig.

That is fine, but consider there are a bunch of tricky things here. In particular, the tagger decoder and the generator decoder behave quite differently in this regard.

The tagger expects an embedded sequence of fixed length as input (including padding 0 vectors), produces a sequence of fixed length (including padding 0s), and compares it with the ground-truth sequence of the same fixed length (including padding 0s) to obtain a loss for each element of the sequence, and then masks this vector of losses, figuring out where the padding is. There is no EOS concept here.

The generator, on the other hand, can produce all sorts of lengths for the reasons I explained before, plus it also has GO and EOS symbols before and after the sequence. EOS is important to have, as if it is not included in the loss, the model does not learn to stop generating.

So to answer your question, you should test the tagger with sequences of exactly the same length; we should actually add a value error / assert in the tagger if the lengths of inputs and ground truth don't match. For testing the generator, yes, you want to have sequences of different lengths.

@w4nderlust
Collaborator

I added some code that could be used as a basis for handling this kind of situation in the SequenceOutputFeature.__init__() method:

        # determine required tf dtype to represent sequence encoded values
        max_base2_exponent = np.int(np.ceil(np.log2(self.num_classes - 1)))
        if max_base2_exponent <= 32:
            self._prediction_dtype = tf.int32
        else:
            self._prediction_dtype = tf.int64

There is already a function in Ludwig to do that: utils.math_utils.int_type. It returns np types, but it could be modified to return tf types instead (with a parameter like tf or np).

We need to use either tf.int32 or tf.int64 because other int types, e.g., tf.int8 or tf.int16, cause an error in the tf.argmax() function.

Got it. Now we also have to make sure that when we are calculating the metrics / losses the datatype makes no difference. In the docs for Accuracy.update_state() they don't specify the type (https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy#args_4), but we should make sure that any combination of y and y_pred int types works, or otherwise upcast everything to int64.
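A minimal sketch of the upcasting option, assuming a plain Keras Accuracy metric (not the committed Ludwig change):

import tensorflow as tf

accuracy = tf.keras.metrics.Accuracy()

def update_accuracy(y_true, y_pred):
    # cast both sides to a common integer dtype before updating the metric
    accuracy.update_state(tf.cast(y_true, tf.int64), tf.cast(y_pred, tf.int64))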

After completing this work, I thought more about it. Do you think num_classes will ever get larger than 2^32? 2^32 is over 2 billion unique values. It might be reasonable to just specify tf.int32.

Unfortunately it can actually happen. Think about classifying (or, conversely, embedding on the encoder side) users of a social network platform. I actually had to deal with something that was on the verge of int32 already. Also, this being only the prediction, it's not a huge cost to pay in terms of memory, really (the tensor is just rank 2, not rank 3, which would be much more expensive).

@jimthompson5802
Collaborator Author

from this guidance

we should make sure that any combination of y and y_pred int types works, or otherwise upcast everything to int64.

I'm going to follow that approach. It is simpler. This commit def27bd does this.

Re: sequence lengths... in the earlier posting, I wanted to point out that in TF2 the determination of sequence lengths occurs in different components than in TF1. In TF1 the sequence lengths are computed in the sequence decoder. In TF2, sequence lengths will probably have to be computed in several different places, e.g., the loss and metric functions.

For now I'll focus on finishing out computing all the required metrics with the fixed length sequences. We can come back later to deal with variable length sequences.

@w4nderlust
Collaborator

w4nderlust commented Apr 20, 2020

I'm going to follow that approach. It is simpler. This commit def27bd does this.

Sounds good. Let's make sure the same thing happens in the categorical feature too.

Re: sequence lengths... in the earlier posting, I wanted to point out that in TF2 the determination of sequence lengths occurs in different components than in TF1. In TF1 the sequence lengths are computed in the sequence decoder. In TF2, sequence lengths will probably have to be computed in several different places, e.g., the loss and metric functions.

For now I'll focus on finishing out computing all the required metrics with the fixed length sequences. We can come back later to deal with variable length sequences.

That's fine. I believe we can have an additional tensor returned by the predict function, a sequence-length tensor of shape [b,] and type int, so that we compute the lengths once and they are available to all losses and metrics. It can also be used during postprocessing to trim the sequences to the right length.
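A sketch of how such a [b,] length tensor could be derived from the predictions, assuming anything after EOS is padded with 0 as in the generation examples above:

import tensorflow as tf

def sequence_lengths(predictions):
    # predictions: int tensor of shape [batch, max_len], 0 = PAD
    not_pad = tf.cast(tf.not_equal(predictions, 0), tf.int32)
    return tf.reduce_sum(not_pad, axis=-1)  # shape [batch,]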

@w4nderlust
Collaborator

I think we can merge what we have here, as the sequence input feature already works with the embed encoder. I just ported the parallelcnn encoder too (yet to be thoroughly tested), and if we merge I can work on the other encoders while you keep working on the decoders.

@jimthompson5802
Collaborator Author

Sounds good

w4nderlust merged commit ff51198 into ludwig-ai:tf2_porting on Apr 20, 2020
jimthompson5802 deleted the tf2_sequence_feature branch on April 20, 2020 04:44
@w4nderlust
Collaborator

An update to the sequence feature porting: I ported all the encoders and cleaned up a lot of the code underneath. The encoding part of the sequence features can be considered done.

@jimthompson5802
Collaborator Author

Great... I've been looking at the next set of work for the decoders. From what I can tell, I have to retrofit Attention back into the Tagger decoder. Right now I'm looking over the TFA attention APIs and trying to map out how to implement them. I may have a few questions in the next day or two.

Once I'm done with that, then convert the Generator decoder.

Once I have something to post, I'll create a new PR.

Making progress

@w4nderlust
Collaborator

Attention is not needed for the Tagger anyway; it's only needed for the generator. Again, the tagger should be very simple to deal with; the generator is the more complicated one.

@jimthompson5802
Collaborator Author

OK. Before moving on to the generator decoder, I just want to make sure I accounted for the functionality that you want in the Tagger.

The reason I asked about attention is that I saw this code when I started the work:
https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L171-L176
At this time I have not done anything explicit to handle the functionality encompassed by feed_forward_memory_attention().

At this time, I have not explicitly ported this over either:
https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L163-L167

Finally, this Tagger functionality was replaced by a single Dense() layer, which basically translates the decoder input [b, s, h] to [b, s, c]:
https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L179-L201
Right now the Tagger decoder is basically only the Dense() layer.

If this is correct, I'll start looking at the Generator decoder.

@w4nderlust
Collaborator

The reason I asked about attention is that I saw this code when I started the work:

https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L171-L176

At this time I have not done anything explicit to handle the functionality encompassed by feed_forward_memory_attention().

Oh, I see now; I completely forgot about it. It was an experimental thing anyway; I plan to reintroduce it in the future. Please add a todo here saying "todo tf2 add feed forward attention".
In the meantime you can comment it out or use this: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention
You would first initialize it in the constructor and then call it like:

hidden = self.attention_layer([hidden, hidden])

where hidden is a [b, s, h] tensor, so before the Dense layer.
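Putting those two pieces together, a sketch of a tagger-style layer with the suggested self-attention before the projection (assumed shapes and names, not the ported Ludwig decoder):

import tensorflow as tf

class TaggerWithAttention(tf.keras.layers.Layer):
    def __init__(self, num_classes):
        super().__init__()
        self.attention_layer = tf.keras.layers.Attention()
        self.projection = tf.keras.layers.Dense(num_classes)

    def call(self, hidden, training=None, mask=None):
        # query and value are both `hidden` -> simple self-attention
        hidden = self.attention_layer([hidden, hidden])
        return self.projection(hidden)  # [batch, seq_len, num_classes]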

At this time, I have not explicitly ported this over either:

https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L163-L167

For the timeseries it's very simple: the dense output dimension should just be 1, nothing more than that.
For the regularizer, you can pass the parameter to the Dense constructor in __init__; there's no regularizer anymore in the call function, just inputs, training, mask.

Finally, this Tagger functionality was replaced by a single Dense() layer, which basically translates the decoder input [b, s, h] to [b, s, c]:

https://github.com/uber/ludwig/blob/a9b80542dcc47dbf85a8308f69cd67eaa707ca08/ludwig/models/modules/sequence_decoders.py#L179-L201

Right now the Tagger decoder is basically only the Dense() layer.

That's correct. As we discussed, under the hood what happens is reshape-matmul-reshape, which is slightly more complex than what happens in the category classifier decoder case, but Dense actually implements both functionalities, so we are good with that.

If this is correct, I'll start looking at the Generator decoder.

Sounds great! If you want, you can do a PR just for the Tagger stuff and then do another one for the Generator.
