In [1]:

import numpy as np
import pandas as pd
from tensorflow import keras

In [2]:

# Why the functional API: keras.Sequential can only express a plain stack of
# layers with exactly one input and one output. Here we want two branches with
# different architectures, each fed by its own input and trained jointly, so we
# build the graph from explicit keras.Input tensors and wrap it in keras.Model.
# Both branches take flat vectors of 784 features.
# NOTE(review): 784 is presumably flattened 28x28 images — confirm against the data.
INPUT_SHAPE = (784,)
m1_inputs = keras.Input(shape=INPUT_SHAPE)
m2_inputs = keras.Input(shape=INPUT_SHAPE)

In [3]:

# Model 1 branch: this cell only defines and connects layers; nothing is trained yet.
# In the functional API each layer object is called on the previous tensor:
#   m1_inputs -> Dense(12, sigmoid) -> Dense(4, softmax)
m1_layer1 = keras.layers.Dense(units=12, activation="sigmoid")(m1_inputs)
m1_layer2 = keras.layers.Dense(units=4, activation="softmax")(m1_layer1)

In [4]:

# Model 2 branch: same layer sizes as model 1, wired the same way:
#   m2_inputs -> Dense(12, relu) -> Dense(4, softmax)
# The only architectural difference is the hidden-layer activation
# (relu here vs. sigmoid in model 1); both output layers use softmax.
m2_layer1 = keras.layers.Dense(units=12, activation="relu")(m2_inputs)
m2_layer2 = keras.layers.Dense(units=4, activation="softmax")(m2_layer1)

In [5]:

# Combine the two branches element-wise, then wrap the whole graph (both
# inputs, one merged output) in a single functional Model.
# NOTE(review): each branch ends in softmax, so every output row sums to 1 and
# their sum sums to 2 — not a probability distribution. If this model is later
# compiled with a categorical cross-entropy loss, keras.layers.average may be
# what was intended; confirm before training.
merged = keras.layers.Add()([m1_layer2, m2_layer2])
model = keras.Model(inputs=[m1_inputs, m2_inputs], outputs=merged)

In [6]:
# Model.summary() prints the layer table itself and returns None; wrapping it
# in print() would append a stray "None" line after the table.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 784)]        0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 784)]        0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 12)           9420        input_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 12)           9420        input_2[0][0]                    
______________________________________________________________________________________________