In [1]:
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

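### Using the Subclassing API to Build Dynamic Models

The Subclassing API allows us to use an imperative programming style: instead of declaring a static graph of layers up front, we subclass `keras.Model` and write the forward pass ourselves, so the model can include dynamic behavior such as loops and conditional branches.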

In [2]:
housing = fetch_california_housing()
# Fix the shuffle seed so the train/valid/test split (and therefore every
# downstream result) is identical on each Restart & Run All.
X_train_full, X_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target, random_state=42
)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train_full, y_train_full, random_state=42
)
# Sanity check: the three splits partition the full dataset.
assert len(X_train) + len(X_valid) + len(X_test) == len(housing.data)

In [18]:
class WideAndDeepModel(keras.Model):
    """Wide & Deep regression model built with the Subclassing API.

    The model takes a pair of inputs ``(input_A, input_B)``. ``input_B`` goes
    through two hidden Dense layers (the "deep" path). ``input_A`` is
    concatenated directly with the deep features (the "wide" path) to feed the
    main output. An auxiliary output is computed from the deep features alone.

    Args:
        units: Number of neurons in each of the two hidden Dense layers.
        activation: Activation function used by the hidden layers.
        **kwargs: Forwarded unchanged to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, units=2, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = keras.layers.Dense(units, activation=activation)
        self.hidden2 = keras.layers.Dense(units, activation=activation)
        # Create the concatenation layer once, next to the other layers, rather
        # than building a new Concatenate layer on every call() invocation.
        self.concat = keras.layers.Concatenate()
        self.main_output = keras.layers.Dense(1)
        self.aux_output = keras.layers.Dense(1)

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: A 2-element sequence ``(input_A, input_B)``. ``input_A`` is
                the wide input and ``input_B`` is the deep input.

        Returns:
            A tuple ``(main_output, aux_output)``, each produced by a
            single-unit Dense layer.
        """
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = self.concat([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

This example looks very much like the Functional API, except that we do not need to create the inputs: we just use the `inputs` argument of the `call()` method. We also separate the creation of the layers (in the constructor) from their usage (in the `call()` method).
The big difference from the Functional API is that you can do pretty much anything you want in the `call()` method: `for` loops, `if` statements, low-level TensorFlow operations, and so on.

In [19]:
# Seed TensorFlow so weight initialization is reproducible across Restart & Run All.
tf.random.set_seed(42)
model = WideAndDeepModel()
# `learning_rate` is the supported argument name. `lr` is a deprecated alias
# that recent Keras releases warn about or reject outright.
model.compile(loss="mse", optimizer=keras.optimizers.SGD(learning_rate=1e-3))

In [20]:
def split_wide_deep(X):
    """Split a 2-D feature matrix into the (wide, deep) inputs of WideAndDeepModel.

    The wide input takes feature columns 0-4 and the deep input takes columns
    2 onward, so columns 2-4 are shared by both paths. Rows are untouched.
    """
    return X[:, :5], X[:, 2:]

X_train_A, X_train_B = split_wide_deep(X_train)
X_valid_A, X_valid_B = split_wide_deep(X_valid)
X_test_A, X_test_B = split_wide_deep(X_test)
# Take the first 3 test *rows* from both inputs so the two stay aligned
# sample-for-sample (slicing columns here would mismatch the batch sizes).
X_new_A, X_new_B = X_test_A[:3], X_test_B[:3]
assert X_new_A.shape[0] == X_new_B.shape[0]

In [22]:
history = model.fit(
    x=[X_train_A, X_train_B],
    # The same target supervises both the main and the auxiliary output.
    y=[y_train] * 2,
    epochs=3,
    validation_data=([X_valid_A, X_valid_B], [y_valid] * 2),
)

Train on 11610 samples, validate on 3870 samples
Epoch 1/3
Epoch 2/3
Epoch 3/3
