# Model Splitting: Creating Nested Models While Maintaining Functionality


In this tutorial, we will demonstrate how to split an existing Keras model into a sequence of nested models. The goal is to preserve the same underlying function of the original model but restructure it into smaller, modular components for easier inspection or experimentation.
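The idea behind the split is function composition: if a model applies its layers one after another, grouping consecutive layers into sub-models and chaining those sub-models computes exactly the same function. The NumPy sketch below illustrates this with a hypothetical toy model made of dense + ReLU layers. It is not Keras code and does not use the `keras_custom` API; the split points are arbitrary.

```python
import numpy as np

# Hypothetical toy "model": an ordered list of layer functions
rng = np.random.default_rng(0)
weights = [rng.standard_normal((4, 4)) for _ in range(6)]


def make_layer(w):
    # Dense layer (no bias) followed by ReLU
    return lambda x: np.maximum(x @ w, 0.0)


layers = [make_layer(w) for w in weights]


def run(layer_list, x):
    # Apply the layers in order, feeding each output into the next layer
    for layer in layer_list:
        x = layer(x)
    return x


# Split the layer list into consecutive segments ("nested models")
boundaries = [2, 4]  # hypothetical split points
edges = [0] + boundaries + [len(layers)]
segments = [layers[a:b] for a, b in zip(edges[:-1], edges[1:])]

x = rng.standard_normal((1, 4))
full = run(layers, x)

nested = x
for seg in segments:
    nested = run(seg, nested)  # each segment consumes the previous segment's output

# Same operations in the same order, so the results are identical
np.testing.assert_array_equal(full, nested)
```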

## Step 1: Setting Up the Environment

If you're running this tutorial on **Google Colab**, the following cell installs the required libraries and dependencies (it detects Colab automatically and skips the installation elsewhere):

In [1]:
# On Colab: install the library
on_colab = "google.colab" in str(get_ipython())
if on_colab:
    import sys  # noqa: avoid having this import removed by pycln

    # install dev version for dev doc, or release version for release doc
    !{sys.executable} -m pip install -U pip
    !{sys.executable} -m pip install git+https://github.com/ducoffeM/keras_custom@main
    # install desired backend (by default torch)
    !{sys.executable} -m pip install "torch"
    !{sys.executable} -m pip install "keras"

    # extra library used in this notebook
    !{sys.executable} -m pip install "numpy"

## Step 2: Import Required Libraries
Next, we import the necessary libraries for our model and image processing.

In [None]:
import keras
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

import keras.backend as K
from keras.models import Model, Sequential
from keras.layers import Activation
import os

from IPython.display import Image, display, HTML


## Step 3: Download and Preprocess the Image
We will use an image of an elephant for our prediction. If the image file is not present, it will be downloaded from the web.

In [None]:
# Check if the image is already available
if not os.path.isfile('elephant.jpg'):
    !wget https://upload.wikimedia.org/wikipedia/commons/f/f9/Zoorashia_elephant.jpg -O elephant.jpg

# Load and preprocess the image
img_path = 'elephant.jpg'
img = keras.utils.load_img(img_path, target_size=(224, 224))
x = keras.utils.img_to_array(img)
x = np.expand_dims(x, axis=0)  # Add batch dimension
x = preprocess_input(x)  # Preprocess image for ResNet50
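For ResNet50, Keras documents `preprocess_input` as "caffe"-style preprocessing: the image is converted from RGB to BGR and each channel is zero-centered with respect to the ImageNet means, without scaling. The NumPy sketch below reproduces that idea for illustration only; `caffe_preprocess` is a hypothetical helper, and it assumes a channels-last array. The tensor shapes printed later in this notebook suggest it was run with a channels-first layout, where the channel axis would be axis 1 instead.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by "caffe"-mode preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(batch_rgb):
    """Hypothetical sketch of ResNet50-style preprocessing.

    batch_rgb: float array of shape (N, H, W, 3), RGB values in [0, 255],
    channels-last layout (assumption).
    """
    bgr = batch_rgb[..., ::-1]       # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR   # zero-center each channel, no scaling


# Dummy image with distinct R, G, B values so the channel flip is visible
img = np.zeros((224, 224, 3), dtype=np.float32)
img[...] = [200.0, 100.0, 50.0]            # R, G, B
batch = np.expand_dims(img, axis=0)        # add the batch dimension
out = caffe_preprocess(batch)
print(out.shape)  # (1, 224, 224, 3)
```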


## Step 4: Load the Pre-trained Model
We will use the ResNet50 model pre-trained on ImageNet to make predictions.

In [None]:
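# Load ResNet50 with ImageNet weights; classifier_activation=None removes the
# final softmax, so the "predictions" layer returns raw logits
model = ResNet50(weights='imagenet', classifier_activation=None)

# Make a prediction
preds = model.predict(x)

# Decode the predictions to show the top 3 classes
# (one list of (class, description, score) tuples per sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 204ms/step
Predicted: [('n02504013', 'Indian_elephant', 16.045454), ('n02504458', 'African_elephant', 14.072982), ('n01871265', 'tusker', 13.382182)]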
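Because `classifier_activation=None` drops the softmax, the scores returned by `decode_predictions` here are logits rather than probabilities. Softmax is strictly increasing, so applying it does not change which classes rank in the top 3. The sketch below checks this on hypothetical logits for a 5-class problem (not real ResNet50 outputs); `softmax` is a helper written for this example.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the max for numerical stability; this does not change the result
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Hypothetical logits for one sample and 5 classes
logits = np.array([[2.0, 16.0, 14.0, 13.4, -1.0]])
probs = softmax(logits)

# Indices of the 3 highest scores, best first
top3_logits = np.argsort(logits[0])[::-1][:3]
top3_probs = np.argsort(probs[0])[::-1][:3]
print(top3_logits)  # [1 2 3]
```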


## Step 5: Split the Model into Nested Models
The goal is to break down the ResNet50 model into smaller, modular nested models. Each nested model covers the part of the original model between two consecutive split points, and the split points are the outputs of selected activation layers.

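**Identify Layers to Split**

We first collect the `Activation` layers whose names end in `_out`; in ResNet50 these are the ReLU outputs of the residual blocks. For simplicity, we then use the entries at indices [0, 4, 8, 12, -1] of that list as split points, and append the model's final layer so that the last nested model ends at the original output.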

In [None]:
import keras_custom
from keras_custom.model import get_nested_model

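# Collect the residual-block outputs: Activation (ReLU) layers whose name ends in '_out'
relu_name = [e.name for e in model.layers if isinstance(e, Activation) and e.name.split('_')[-1] == 'out']

# Select split points (indices into relu_name, not into model.layers) and append
# the final layer so that the last nested model ends at the original output
indices = [0, 4, 8, 12, -1]
split = [relu_name[i] for i in indices] + [model.layers[-1].name]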
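To see which layers these indices pick, the pure-Python sketch below rebuilds the expected `relu_name` list from ResNet50's stage structure (stages `conv2` to `conv5` with 3, 4, 6 and 3 residual blocks). The naming pattern `conv{stage}_block{b}_out` is an assumption based on the Keras ResNet implementation; in the notebook, the real list is the one computed from `model.layers`.

```python
# Hypothetical reconstruction of the residual-block output names in ResNet50
blocks_per_stage = {2: 3, 3: 4, 4: 6, 5: 3}
relu_name_sketch = [
    f"conv{stage}_block{b}_out"
    for stage, n_blocks in blocks_per_stage.items()
    for b in range(1, n_blocks + 1)
]

indices = [0, 4, 8, 12, -1]
selected = [relu_name_sketch[i] for i in indices]
print(len(relu_name_sketch))  # 16
print(selected)
# ['conv2_block1_out', 'conv3_block2_out', 'conv4_block2_out',
#  'conv4_block6_out', 'conv5_block3_out']
```

These names are consistent with the channel counts printed when the nested models are built (256, 512, 1024, 1024 and 2048 channels).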


**Create Nested Models**

Now, we create the list of nested models from the selected split points. Each nested model starts at the output of the previous split layer (or at the model input, for the first one) and ends at the next split layer.

In [24]:
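layer_in = None  # the first nested model starts at the original model input
input_shape_wo_batch = list(model.input.shape[1:])  # input shape without the batch axis
nested_models = []
for name in split:
    layer_out = model.get_layer(name)
    # Sub-model going from the output of layer_in up to the output of layer_out
    nested_model = get_nested_model(model, layer_out, layer_in, input_shape_wo_batch)
    layer_in = layer_out  # the next nested model starts where this one ends
    nested_models.append(nested_model)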

<KerasTensor shape=(None, 3, 224, 224), dtype=float32, sparse=False, name=keras_tensor_745>
<KerasTensor shape=(None, 256, 56, 56), dtype=float32, sparse=False, name=keras_tensor_770>
<KerasTensor shape=(None, 512, 28, 28), dtype=float32, sparse=False, name=keras_tensor_925>
<KerasTensor shape=(None, 1024, 14, 14), dtype=float32, sparse=False, name=keras_tensor_1080>
<KerasTensor shape=(None, 1024, 14, 14), dtype=float32, sparse=False, name=keras_tensor_1231>
<KerasTensor shape=(None, 2048, 7, 7), dtype=float32, sparse=False, name=keras_tensor_1310>
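The loop above pairs each split point with the previous one. The six tensors printed by this cell match, in order, the model input followed by the outputs of the first five split layers, which suggests each line is the input of one nested model. The pure-Python sketch below lists the resulting (start, end) pairs; `split_sketch` is a hypothetical copy of `split`, assuming the layer names derived above and the Keras name `predictions` for the final layer.

```python
# Hypothetical copy of `split` (layer names assumed, see Step 5)
split_sketch = [
    "conv2_block1_out", "conv3_block2_out", "conv4_block2_out",
    "conv4_block6_out", "conv5_block3_out", "predictions",
]

# Each nested model starts where the previous one ended; the first one starts
# at the model input, represented by None as with layer_in in the loop
starts = [None] + split_sketch[:-1]
segments = list(zip(starts, split_sketch))
for start, end in segments:
    print(f"{start or 'input'} -> {end}")
```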


We can now chain the nested models into a single `Sequential` model, which should compute the same function as the original ResNet50:

In [25]:
model_seq = Sequential(layers=nested_models)

In [28]:
# Run the chained nested models and check that they reproduce the original outputs
preds_ = model_seq.predict(x)

np.testing.assert_almost_equal(preds, preds_)

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 151ms/step


We obtain the same prediction:

In [29]:
print('Predicted:', decode_predictions(preds_, top=3)[0])

Predicted: [('n02504013', 'Indian_elephant', 16.045454), ('n02504458', 'African_elephant', 14.072982), ('n01871265', 'tusker', 13.382182)]
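The assertion at `In [28]` uses `np.testing.assert_almost_equal` with its default `decimal=7`, an absolute test of roughly `abs(desired - actual) < 1.5e-7`. For float32 logits around 16, adjacent representable values are already about 1.9e-6 apart, so any rounding-level difference (for example, from a different operation order) would fail it. The sketch below shows this with a value taken from the output above and a relative-tolerance check as a more forgiving alternative; the `rtol` value is an arbitrary choice.

```python
import numpy as np

# A logit of the magnitude seen above, stored in float32
a = np.array([16.045454], dtype=np.float32)
b = a + np.spacing(a)  # the next representable float32 value (one ulp away)

gap = float(b[0] - a[0])
print(gap)  # 1.9073486328125e-06, i.e. 2**-19

# The default absolute tolerance of assert_almost_equal rejects a one-ulp gap
try:
    np.testing.assert_almost_equal(a, b)
    strict_ok = True
except AssertionError:
    strict_ok = False

# A relative tolerance accepts rounding-level differences
np.testing.assert_allclose(a, b, rtol=1e-6)
```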


In [34]:
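# Save a diagram of the sequential model (plot_model requires the pydot package and Graphviz)
dot_img_file_backward = './ResNet50.png'
keras.utils.plot_model(model_seq, to_file=dot_img_file_backward, show_shapes=True, show_layer_names=True)

# Display the saved diagram, centered in the notebook
display(HTML('<div style="text-align: center;"><img src="{}" width="800"/></div>'.format(dot_img_file_backward)))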