
In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
xception = tf.keras.applications.xception.Xception(include_top=False, weights='imagenet',input_shape=(224,224,3))
#xception.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


**Replacing one type of layer throughout an architecture is difficult for models with a non-sequential (branching) graph, because all the inbound and outbound node connections of the graph have to be rewired. To demonstrate a way of doing this, I replace every batch normalization layer in Xception with an instance normalization layer. The model is rebuilt by walking the existing graph layer by layer and storing each output in a dictionary that maps the name of the original output tensor to the new tensor that replaces it.**
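
The bookkeeping described above can be sketched without TensorFlow. The block below is only an illustration, assuming a toy graph format (tuples of layer name, layer type, input tensor names, output tensor name); `TOY_GRAPH`, `rebuild` and all layer names are hypothetical and are not part of Keras or of this notebook's models.

```python
# Toy stand-in for a functional graph. Each entry is
# (layer_name, layer_type, input_tensor_names, output_tensor_name).
# Every name here is hypothetical and only illustrates the dictionary remapping.
TOY_GRAPH = [
    ("input", "Input", [], "input:0"),
    ("conv1", "Conv", ["input:0"], "conv1:0"),
    ("conv1_bn", "BatchNorm", ["conv1:0"], "conv1_bn:0"),
    ("shortcut", "Conv", ["input:0"], "shortcut:0"),
    ("add", "Add", ["conv1_bn:0", "shortcut:0"], "add:0"),
]


def rebuild(graph, replace_type, new_type):
    """Rebuild `graph`, swapping every `replace_type` layer for a `new_type` layer."""
    out_layer = {}  # original output tensor name -> name of the tensor replacing it
    new_graph = []
    counter = 0
    for name, kind, inputs, output in graph:
        # Look up the rebuilt tensors that feed this layer (one or many).
        new_inputs = [out_layer[tensor] for tensor in inputs]
        if kind == replace_type:
            counter += 1
            name, kind = f"{new_type.lower()}_{counter}", new_type
        new_output = f"{name}:0"
        # Register the result under the ORIGINAL output name, so that
        # downstream layers find it when they look up their inputs.
        out_layer[output] = new_output
        new_graph.append((name, kind, new_inputs, new_output))
    return new_graph, out_layer


new_graph, mapping = rebuild(TOY_GRAPH, "BatchNorm", "InstanceNorm")
for row in new_graph:
    print(row)
```

The key point is that the dictionary is keyed by the old tensor name: the `add` layer still asks for `conv1_bn:0`, and the lookup hands it the output of the replacement layer instead.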

In [None]:
def xcp_edit():

  out_layer = {}
  x = xception.input 

  out_layer[xception.layers[0].output.name] = x
  # Start at 1 to skip the input layer; iterate up to and including the last layer.
  for i in range(1,len(xception.layers)):
    layer = xception.layers[i]
    if layer.name.endswith("bn") or layer.name.startswith("batch_normalization"):
      #print("InstNorm here")
      #print(layer.input.name,out_layer[layer.input.name].name)
      x = out_layer[layer.input.name]
      x = tfa.layers.InstanceNormalization()(x)
      out_layer[layer.output.name] = x
    else:
      #print(layer.name)
     
      if isinstance(layer.input, list):
        in_l = []
        for inp in layer.input:
          in_l.append(out_layer[inp.name])
        x = layer(in_l)
        out_layer[layer.output.name] = x
      else:
        #print(layer.input.name,out_layer[layer.input.name].name)
        x = layer(out_layer[layer.input.name])
        out_layer[layer.output.name] = x  

  model = tf.keras.Model(inputs=xception.input,outputs=x)
  #model.trainable = False
  model.summary()
  return model

In [None]:
model = xcp_edit()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 111, 111, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
instance_normalization_40 (Inst (None, 111, 111, 32) 64          block1_conv1[3][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 111, 111, 32) 0           instance_normalization_40[0][0]  
______________________________________________________________________________________________

In [None]:
xception.summary()

**Another use case: replacing convolution layers with separable convolution layers in the DenseNet architecture. Each layer inside a dense block contains one 1x1 convolution and one 3x3 convolution; only the 3x3 convolution is replaced. With this switch, the model goes from about 7 million to about 5 million parameters. The downside is that the separable convolution weights are randomly initialized, so retraining is required.**
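
The parameter saving can be checked with a little arithmetic. The sketch below assumes the standard DenseNet-121 configuration (growth rate 32, a 128-channel bottleneck feeding each 3x3 convolution, dense blocks of 6, 12, 24 and 16 layers, no bias on the original 3x3 convolutions) and a replacement separable convolution with the default bias and depth multiplier 1. The helper names are hypothetical.

```python
def conv2d_params(kernel, c_in, c_out, use_bias=True):
    """Weights in a standard kernel x kernel convolution."""
    return kernel * kernel * c_in * c_out + (c_out if use_bias else 0)


def separable_conv2d_params(kernel, c_in, c_out, use_bias=True, depth_multiplier=1):
    """Weights in a depthwise convolution followed by a 1x1 pointwise convolution."""
    depthwise = kernel * kernel * c_in * depth_multiplier
    pointwise = c_in * depth_multiplier * c_out
    return depthwise + pointwise + (c_out if use_bias else 0)


# Assumed DenseNet-121 numbers: 128 input channels, 32 output channels per 3x3 conv.
original = conv2d_params(3, 128, 32, use_bias=False)
replacement = separable_conv2d_params(3, 128, 32, use_bias=True)

n_replaced = 6 + 12 + 24 + 16  # one 3x3 conv per dense-block layer
saved = n_replaced * (original - replacement)

print(original, replacement, saved)
```

Under these assumptions each replaced layer drops from 36,864 to 5,280 weights, and the 58 replacements together save roughly 1.8 million parameters, which is consistent with the 7 million to 5 million figure quoted above.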

In [None]:
dense = tf.keras.applications.densenet.DenseNet121(include_top=False, weights='imagenet',input_shape=(224,224,3))

In [None]:
dense.summary()

Model: "densenet121"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 230, 230, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1/conv (Conv2D)             (None, 112, 112, 64) 9408        zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
conv1/bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1/conv[0][0]                 
________________________________________________________________________________________

In [None]:
dense.get_layer('conv5_block16_2_conv').get_config()
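
The config inspected here could also be used to build the replacement layer instead of hard-coding its arguments. The following is a minimal sketch, assuming the config dictionary contains the usual `Conv2D` constructor arguments; `SHARED_KEYS`, `separable_kwargs_from_conv_config` and the `example_config` values are hypothetical and are not the actual output of the cell above.

```python
# Constructor arguments that Conv2D and SeparableConv2D have in common.
SHARED_KEYS = ("filters", "kernel_size", "strides", "padding",
               "dilation_rate", "use_bias")


def separable_kwargs_from_conv_config(config):
    """Keep only the Conv2D settings that SeparableConv2D also accepts."""
    return {key: config[key] for key in SHARED_KEYS if key in config}


# Illustrative excerpt of a Conv2D config (assumed values, not real output).
example_config = {
    "name": "conv5_block16_2_conv",
    "filters": 32,
    "kernel_size": (3, 3),
    "strides": (1, 1),
    "padding": "same",
    "dilation_rate": (1, 1),
    "groups": 1,
    "activation": "linear",
    "use_bias": False,
}

kwargs = separable_kwargs_from_conv_config(example_config)
print(kwargs)

# With TensorFlow available, the replacement layer could then be created as:
#   new_layer = tf.keras.layers.SeparableConv2D(**kwargs)
```

Copying `use_bias` from the original layer keeps the replacement closer to the layer it stands in for; the name is deliberately left out so that Keras assigns a fresh one.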

In [None]:
def dense_edit():

  out_layer = {}
  x = dense.input 

  out_layer[dense.layers[0].output.name] = x
  # Start at 1 to skip the input layer; iterate up to and including the last layer.
  for i in range(1,len(dense.layers)):
    layer = dense.layers[i]
    if layer.name.endswith("_2_conv"):
      #print("SeparableConv here")
      #print(layer.input.name,out_layer[layer.input.name].name)
      x = out_layer[layer.input.name]
      x = tf.keras.layers.SeparableConv2D(32,(3,3),padding='same')(x)
      out_layer[layer.output.name] = x
    else:
      #print(layer.name)
     
      if isinstance(layer.input, list):
        in_l = []
        for inp in layer.input:
          in_l.append(out_layer[inp.name])
        x = layer(in_l)
        out_layer[layer.output.name] = x
      else:
        #print(layer.input.name,out_layer[layer.input.name].name)
        x = layer(out_layer[layer.input.name])
        out_layer[layer.output.name] = x  

  model = tf.keras.Model(inputs=dense.input,outputs=x)

  model.summary()
  return model
         

In [None]:
dense_model = dense_edit()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 230, 230, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1/conv (Conv2D)             (None, 112, 112, 64) 9408        zero_padding2d_2[1][0]           
__________________________________________________________________________________________________
conv1/bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1/conv[1][0]                 
____________________________________________________________________________________________