In [1]:
"""
The MIT License (MIT)
Copyright (c) 2021 NVIDIA
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""


'\nThe MIT License (MIT)\nCopyright (c) 2021 NVIDIA\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the "Software"), to deal in\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of\nthe Software, and to permit persons to whom the Software is furnished to do so,\nsubject to the following conditions:\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\nTHE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS\nFOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR\nCOPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER\nIN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OU

This code example extends the multimodal network from c17e2_multi_modal with an additional head to build a network that does multitask learning using multimodal inputs. We teach the network to simultaneously do multiclass classification (identify the handwritten digit) and perform a simple question-answering task. The question-answering task is to provide a yes/no answer to a question about the digit in the image. The textual input will look similar to the textual input in c17e2_multi_modal ('upper half', 'lower half', 'odd number', 'even number'). However, instead of correctly describing the digit, the text is chosen randomly and represents a question. The network is then tasked with classifying the image into one of ten classes as well as with determining whether the answer to the question is yes or no (is the statement true or false). More context for this code example can be found in the section "Programming Example: Multiclass classification and question answering with a single network" in Chapter 17 in the book Learning Deep Learning by Magnus Ekman (ISBN: 9780137470358).

As always, we start with initialization code and loading the dataset.


In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.text \
    import text_to_word_sequence
from tensorflow.keras.preprocessing.sequence \
    import pad_sequences
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
import numpy as np
import logging
tf.get_logger().setLevel(logging.ERROR)

EPOCHS = 20
MAX_WORDS = 8
EMBEDDING_WIDTH = 4

# Load training and test datasets.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images,
                               test_labels) = mnist.load_data()

# Standardize the data.
mean = np.mean(train_images)
stddev = np.std(train_images)
train_images = (train_images - mean) / stddev
test_images = (test_images - mean) / stddev


The next step is to extend the MNIST dataset with questions and answers. This is done in the next code snippet. The code alternates between the four questions/statements for each training and test example. It then determines whether the answer is yes or no based on the ground truth label.


In [3]:
# Function to create question and answer text.
def create_question_answer(tokenizer, labels):
    text = []
    answers = np.zeros(len(labels))
    for i, label in enumerate(labels):
        question_num = i % 4
        if question_num == 0:
            text.append('lower half')
            if label < 5:
                answers[i] = 1.0
        elif question_num == 1:
            text.append('upper half')
            if label >= 5:
                answers[i] = 1.0
        elif question_num == 2:
            text.append('even number')
            if label % 2 == 0:
                answers[i] = 1.0
        elif question_num == 3:
            text.append('odd number')
            if label % 2 == 1:
                answers[i] = 1.0
    text = tokenizer.texts_to_sequences(text)
    text = pad_sequences(text)
    return text, answers

# Create second modality for training and test set.
vocabulary = ['lower', 'upper', 'half', 'even', 'odd', 'number']
tokenizer = Tokenizer(num_words=MAX_WORDS)
tokenizer.fit_on_texts(vocabulary)
train_text, train_answers = create_question_answer(tokenizer,
                                                   train_labels)
test_text, test_answers = create_question_answer(tokenizer,
                                                 test_labels)


The next code snippet creates the network. Most of the network is identical to the programming example for the multimodal network. The key difference is that in parallel with the ten-unit output layer for multiclass classification, there is a one-unit output layer for binary classification. Given that there are two separate outputs, we also need to supply two separate loss functions. In addition, we supply weights for these two loss functions to indicate how to weigh the two into a single loss function for training the network. The weights should be treated like any other hyperparameter. A reasonable starting point is to have the same weight for both losses, so we use 50/50. Finally, when calling the fit method, we must provide ground truth for both heads of the model.


In [5]:
# Create model with functional API.
image_input = Input(shape=(28, 28))
text_input = Input(shape=(2, ))

# Declare layers.
embedding_layer = Embedding(output_dim=EMBEDDING_WIDTH,
                            input_dim = MAX_WORDS)
lstm_layer = LSTM(8)
flatten_layer = Flatten()
concat_layer = Concatenate()
dense_layer = Dense(25,activation='relu')
class_output_layer = Dense(10, activation='softmax')
answer_output_layer = Dense(1, activation='sigmoid')

# Connect layers.
embedding_output = embedding_layer(text_input)
lstm_output = lstm_layer(embedding_output)
flatten_output = flatten_layer(image_input)
concat_output = concat_layer([lstm_output, flatten_output])
dense_output = dense_layer(concat_output)
class_outputs = class_output_layer(dense_output)
answer_outputs = answer_output_layer(dense_output)

# Build and train model.
model = Model([image_input, text_input], [class_outputs,
                                          answer_outputs])
model.compile(loss=['sparse_categorical_crossentropy',
                    'binary_crossentropy'], optimizer='adam',
                    metrics=['accuracy'],
                    loss_weights = [0.5, 0.6])
model.summary()
history = model.fit([train_images, train_text],
                    [train_labels, train_answers],
                    validation_data=([test_images, test_text],
                    [test_labels, test_answers]), epochs=EPOCHS,
                    batch_size=64, verbose=2, shuffle=True)


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 2)]          0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 2, 4)         32          input_4[0][0]                    
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 28, 28)]     0                                            
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, 8)            416         embedding_1[0][0]                
____________________________________________________________________________________________

2021-10-20 17:45:06.541223: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-10-20 17:45:06.652808: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-10-20 17:45:06.748349: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-10-20 17:45:20.171383: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-10-20 17:45:20.224511: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


938/938 - 15s - loss: 0.6745 - dense_4_loss: 0.4226 - dense_5_loss: 0.6617 - dense_4_accuracy: 0.8736 - dense_5_accuracy: 0.5912 - val_loss: 0.6709 - val_dense_4_loss: 0.2629 - val_dense_5_loss: 0.7706 - val_dense_4_accuracy: 0.9245 - val_dense_5_accuracy: 0.4933
Epoch 2/20
938/938 - 14s - loss: 0.5583 - dense_4_loss: 0.2358 - dense_5_loss: 0.6290 - dense_4_accuracy: 0.9312 - dense_5_accuracy: 0.6272 - val_loss: 0.6574 - val_dense_4_loss: 0.2209 - val_dense_5_loss: 0.7813 - val_dense_4_accuracy: 0.9362 - val_dense_5_accuracy: 0.5111
Epoch 3/20
938/938 - 15s - loss: 0.5118 - dense_4_loss: 0.2016 - dense_5_loss: 0.5872 - dense_4_accuracy: 0.9404 - dense_5_accuracy: 0.6747 - val_loss: 0.5927 - val_dense_4_loss: 0.1995 - val_dense_5_loss: 0.7041 - val_dense_4_accuracy: 0.9417 - val_dense_5_accuracy: 0.5830
Epoch 4/20
938/938 - 14s - loss: 0.4513 - dense_4_loss: 0.1846 - dense_5_loss: 0.5128 - dense_4_accuracy: 0.9455 - dense_5_accuracy: 0.7365 - val_loss: 0.5662 - val_dense_4_loss: 0.2077 