The Keras functional API is a way to create models that are more flexible than those built with the tf.keras.Sequential API. The functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.

The main idea is that a deep learning model is usually a directed acyclic graph (DAG) of layers, so the functional API is a way to build graphs of layers.
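To make the "graph of layers" idea concrete, here is a hypothetical toy sketch in plain Python. It is not how Keras is implemented. Each node names a function and the nodes it reads from. Evaluating the output first evaluates its parents, which is only well defined because the graph has no cycles. The example uses a non-linear topology in which one input feeds two branches that merge again. The names run_graph, branch_a, branch_b and merged are made up for illustration.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical toy graph: node name -> (function, names of the parent nodes it consumes).
Graph = Dict[str, Tuple[Callable[..., float], List[str]]]


def run_graph(graph: Graph, inputs: Dict[str, float], output: str) -> float:
    """Evaluate `output`, recursively evaluating (and caching) its parents first."""
    cache: Dict[str, float] = dict(inputs)

    def evaluate(name: str) -> float:
        if name not in cache:
            fn, parents = graph[name]
            cache[name] = fn(*(evaluate(p) for p in parents))
        return cache[name]

    return evaluate(output)


# Non-linear topology: one input, two parallel branches, one merge node.
graph: Graph = {
    "branch_a": (lambda x: 2 * x, ["inputs"]),
    "branch_b": (lambda x: x + 1, ["inputs"]),
    "merged": (lambda a, b: a + b, ["branch_a", "branch_b"]),
}

print(run_graph(graph, {"inputs": 3.0}, "merged"))  # 2*3 + (3+1) = 10.0
```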

In [1]:
import numpy as np
import tensorflow as tf

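An Input specifies only the shape of each sample, so the batch size is always omitted: shape=(784,) means every sample is a vector of 784 values, and Keras leaves the batch dimension unspecified.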

In [2]:
# Each sample is a 784-value vector; the batch dimension is left unspecified.
inputs = tf.keras.Input(shape=(784,))

In [3]:
inputs.dtype

tf.float32

In [4]:
inputs.shape

TensorShape([None, 784])
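The None in this shape is the omitted batch dimension. A layer's weights depend only on the per-sample shape, so the same layer accepts batches of any size. The following minimal NumPy sketch relies on the documented Dense formula, activation(dot(input, kernel) + bias). Its randomly initialized weights are hypothetical stand-ins for what Dense(64, activation='relu') would learn.

```python
import numpy as np

rng = np.random.default_rng(0)

# The parameter shapes depend only on the per-sample shape (784 features in, 64 units out),
# never on the batch size.
kernel = rng.standard_normal((784, 64)).astype("float32")
bias = np.zeros(64, dtype="float32")


def dense_relu(batch: np.ndarray) -> np.ndarray:
    """Compute relu(batch @ kernel + bias) for a (batch_size, 784) array."""
    return np.maximum(batch @ kernel + bias, 0.0)


for batch_size in (1, 32, 100):
    batch = rng.standard_normal((batch_size, 784)).astype("float32")
    print(batch_size, dense_relu(batch).shape)  # (1, 64), then (32, 64), then (100, 64)
```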

In [5]:
dense = tf.keras.layers.Dense(64, activation='relu')

In [6]:
x = dense(inputs)

Calling a layer on a tensor is like drawing an arrow from "inputs" to the layer you created: you "pass" the inputs to the dense layer and get x as its output.
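The arrow-drawing idea can be sketched with a hypothetical toy, which is not Keras's actual implementation. Calling a layer computes no numbers. It only records which layer produced the new symbolic tensor and from which input, and it infers the output shape while keeping the batch dimension as None. Walking back from the output then recovers the chain of layers. Node, ToyDense and trace are made-up names.

```python
class Node:
    """Hypothetical symbolic tensor: remembers its shape, producing layer and parent."""

    def __init__(self, shape, layer=None, parent=None):
        self.shape, self.layer, self.parent = shape, layer, parent


class ToyDense:
    """Hypothetical stand-in for a Dense layer; calling it only records an arrow."""

    def __init__(self, units, name):
        self.units, self.name = units, name

    def __call__(self, node):
        # Draw the arrow parent -> this layer; replace the last dimension with `units`.
        return Node(shape=node.shape[:-1] + (self.units,), layer=self, parent=node)


def trace(node):
    """List the layer names on the path from the input to `node`."""
    names = []
    while node.layer is not None:
        names.append(node.layer.name)
        node = node.parent
    return list(reversed(names))


toy_inputs = Node(shape=(None, 784))
toy_outputs = ToyDense(10, "dense_2")(ToyDense(64, "dense_1")(ToyDense(64, "dense")(toy_inputs)))
print(toy_outputs.shape)   # (None, 10)
print(trace(toy_outputs))  # ['dense', 'dense_1', 'dense_2']
```

Because every node keeps a link to its parent, giving a model only its inputs and outputs is enough to recover the whole chain in between.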

In [7]:
dense_2 = tf.keras.layers.Dense(64, activation='relu')

In [8]:
y = dense_2(x)

In [9]:
dense_3 = tf.keras.layers.Dense(10, activation='softmax')

In [10]:
outputs = dense_3(y)

Create a Model by specifying its inputs and outputs in the graph of layers:

In [11]:
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="new")

In [12]:
model.summary()

Model: "new"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 784)]             0
_________________________________________________________________
dense (Dense)                (None, 64)                50240
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________
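The Param # column follows directly from the layer sizes. A Dense layer with n_in inputs and n_out units has an n_in x n_out weight matrix plus n_out biases, and the InputLayer has no parameters. The following small check reproduces the numbers in the summary. The helper name dense_params is hypothetical.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


layer_sizes = [784, 64, 64, 10]  # input width, then the units of each Dense layer
counts = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(counts)       # [50240, 4160, 650]
print(sum(counts))  # 55050
```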


In [13]:
tf.keras.utils.plot_model(model, "new.png")

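plot_model needs the pydot package and Graphviz. They were not installed when this notebook was run, so instead of drawing new.png the call printed this message:

('Failed to import pydot. You must `pip install pydot` and install graphviz (https://graphviz.gitlab.io/download/), ', 'for `pydotprint` to work.')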


In [14]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [None]:
# Flatten each 28x28 image into a 784-value vector and scale pixels from [0, 255] to [0, 1].
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255
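This preprocessing turns uint8 images of shape (num_samples, 28, 28) into float32 rows of 784 values in [0, 1], which matches the model's Input(shape=(784,)). The sketch below applies the same reshape and scaling to three toy constant images. It uses made-up pixel values rather than MNIST data.

```python
import numpy as np

# Toy stand-in for MNIST images: uint8 pixels in [0, 255], shape (num_samples, 28, 28).
images = np.array([np.full((28, 28), v, dtype=np.uint8) for v in (0, 128, 255)])

# Same steps as above: flatten to 784 values per image, convert to float32, scale to [0, 1].
flat = images.reshape(len(images), 784).astype("float32") / 255

print(images.shape, images.dtype)  # (3, 28, 28) uint8
print(flat.shape, flat.dtype)      # (3, 784) float32
print(flat.min(), flat.max())      # 0.0 1.0
```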

In [None]:
model.compile(
    # The last layer already applies softmax, so the model outputs probabilities, not logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)
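Because dense_3 ends in a softmax, the model outputs probabilities. SparseCategoricalCrossentropy with from_logits=False expects exactly that. With from_logits=True, the loss applies softmax to its input, so softmax would effectively run twice. The following plain NumPy sketch on toy scores illustrates the difference. The helper names softmax and sparse_ce are made up, and the Keras loss additionally clips probabilities, which this sketch omits.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sparse_ce(probs: np.ndarray, labels: np.ndarray) -> float:
    """Mean of -log(probability assigned to the true class)."""
    return float(-np.log(probs[np.arange(len(labels)), labels]).mean())


logits = np.array([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])  # toy scores for 3 classes
labels = np.array([0, 1])                               # integer class ids
probs = softmax(logits)                                 # what a softmax output layer emits

loss_on_probs = sparse_ce(probs, labels)         # roughly what from_logits=False computes
loss_double = sparse_ce(softmax(probs), labels)  # probabilities wrongly treated as logits
print(round(loss_on_probs, 4), round(loss_double, 4))
```

The second softmax squashes the already-normalized probabilities toward uniform. This makes the loss larger and flattens its gradients, which is why the flag should match the model's last activation.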

In [None]:
history = model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)
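With validation_split=0.2, fit holds out part of the training arrays for validation. According to the Model.fit documentation, these are the last samples of x and y, selected before shuffling. The following arithmetic sketch shows the resulting split for 60,000 training images with batch_size=64. The floor-based rounding is an assumption about Keras internals.

```python
import math

num_samples = 60000      # number of MNIST training images
validation_split = 0.2
batch_size = 64

# Assumed rounding: the first floor(n * (1 - split)) samples train, the rest validate.
split_at = math.floor(num_samples * (1.0 - validation_split))
num_train, num_val = split_at, num_samples - split_at
steps_per_epoch = math.ceil(num_train / batch_size)
print(num_train, num_val, steps_per_epoch)  # 48000 12000 750

# The held-out part is the tail of the arrays, shown on a toy list of sample indices.
indices = list(range(10))
cut = math.floor(len(indices) * (1.0 - validation_split))
print(indices[:cut], indices[cut:])  # [0, 1, 2, 3, 4, 5, 6, 7] [8, 9]
```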

In [None]:
test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])