In [1]:
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.python import ipu

In [2]:
# Configure the IPU system: request NUM_IPUS IPUs and attach to them.
# NOTE(review): presumably this must be large enough for the number of
# pipeline stages the model is split into (or a custom stage-to-IPU
# device mapping must be supplied) — confirm against the pipelining docs.
NUM_IPUS = 4
config = ipu.config.IPUConfig()
config.auto_select_ipus = NUM_IPUS
config.configure_ipu_system()

In [3]:
# Pipeline stage boundaries. Each listed layer is the LAST layer of its stage;
# every layer that follows it (in post-order) goes to the next stage. Layer
# names are matched by prefix, so N entries here produce N + 1 stages.
# NOTE(review): this yields 5 stages (0-4) while the config cell requests
# 4 IPUs; confirm a stage-to-IPU device mapping is configured in a later cell.
SPLIT_AFTER_LAYERS = (
    "conv2_block1_add",
    "conv2_block2_add",
    "conv3_block1_add",
    "conv3_block2_add",
)

strategy = ipu.ipu_strategy.IPUStrategy()
with strategy.scope():
  model = ResNet50(weights='imagenet')

  # Get the individual assignments - note that they are returned in post-order.
  assignments = model.get_pipeline_stage_assignment()

  # Assign each layer to the current stage, then advance the stage once a
  # split layer has been assigned (the split layer stays in the earlier stage).
  stage_id = 0
  used_stages = set()
  for assignment in assignments:
    assignment.pipeline_stage = stage_id
    used_stages.add(stage_id)
    for next_stage, split_prefix in enumerate(SPLIT_AFTER_LAYERS, start=1):
      if assignment.layer.name.startswith(split_prefix):
        stage_id = next_stage

  # A mistyped split-layer name would otherwise silently produce fewer (or
  # empty) pipeline stages than intended.
  num_stages = len(SPLIT_AFTER_LAYERS) + 1
  assert used_stages == set(range(num_stages)), (
      f"Expected stages 0..{num_stages - 1}, got {sorted(used_stages)}; "
      "check SPLIT_AFTER_LAYERS against the model's layer names.")

  # Set the assignments to the model.
  model.set_pipeline_stage_assignment(assignments)

  model.print_pipeline_stage_assignment_summary()

Model: "resnet50"
_________________________________________________________________________________________
Layer (type) (node index)          Input Layers                        Pipeline Stage    
conv1_pad (ZeroPadding2D) (0)      input_1                             0                 
_________________________________________________________________________________________
conv1_conv (Conv2D) (0)            conv1_pad                           0                 
_________________________________________________________________________________________
conv1_bn (BatchNormalization) (0)  conv1_conv                          0                 
_________________________________________________________________________________________
conv1_relu (Activation) (0)        conv1_bn                            0                 
_________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (0)      conv1_relu                          0          