# Transfer learning
Transfer learning is a key technique for saving time during model training: weights already learned by one model are reused to initialize another.

In [1]:
import tifffile as tif
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import carreno.nn.unet as unet

## 2D weights to 3D
We want to transfer the weights of a 2D UNet to a 3D UNet.

In [2]:
# Display the UNet constructor's signature and parameter documentation
# before instantiating it in the next cell.
help(unet.UNet.__init__)

Help on function __init__ in module carreno.nn.unet:

__init__(self, shape, n_class=3, depth=3, n_feat=32)
    Create a UNet architecture
    Parameters
    ----------
    shape : (int, int, int)
        Image shape. Even if grayscale, we must have a color channel
    n_class : int
        Number of unique labels
    depth : int
        UNet number of levels (nb of encoder block + 1)
    n_feat : int
        Number of features for the first encoder block (will increase and decrease according to UNet architecture)
    Returns
    -------
    model : tf.keras.Model
        Keras model waiting to be compiled for training



In [3]:
# Build the 2D UNet: 64x64 input with a single channel, 3 labels, 3 levels.
unet2D = unet.UNet(shape=[64, 64, 1], n_class=3, depth=3)
# Keep a direct handle on the wrapped Keras model; later cells work on `model`.
model = unet2D.model
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 64, 64, 32)   320         ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 64, 64, 32)  128         ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 leaky_re_lu (LeakyReLU)        (None, 64, 64, 32)   0           ['batch_normalization[0][0]']

In [4]:
# Replace the first convolution's kernel with a constant so the effect of the
# 2D -> 3D weight transfer is easy to recognise afterwards.
first_conv = model.layers[1]
kernel, bias = first_conv.get_weights()
layer = np.ones_like(kernel)  # every kernel weight becomes 1 (same shape and dtype)
first_conv.set_weights([layer, bias])  # bias is written back unchanged

In [5]:
# create 3D UNet from 2D UNet
# The second argument is the shape requested for the 3D model: it has one more
# 64-sized axis than the [64,64,1] shape used to build the 2D UNet.
# The returned object is used as a Keras model below (`.layers`, `.summary()`).
unet3D = unet.unet2D_to_unet3D(model, [64,64,64,1])

In [6]:
# Sanity check of the transfer: compare the kernel of the first convolution
# before (2D) and after (3D) the conversion. With every 2D weight set to 1,
# the 3D kernel is expected to hold a single value too (1/3 in the recorded
# output, presumably the 2D weight divided by the size of the added axis).
layer3D = unet3D.layers[1].get_weights()[0]

for label, kernel_shape in (("Old shape :", layer.shape), ("New shape :", layer3D.shape)):
    print(label, kernel_shape)
print('---')
for label, kernel_values in (("Old weights :", layer), ("New weights :", layer3D)):
    print(label, np.unique(kernel_values))

Old shape : (3, 3, 1, 32)
New shape : (3, 3, 3, 1, 32)
---
Old weights : [1.]
New weights : [0.33333334]


In [7]:
# Print the layer table of the converted model to inspect its 3D layers and shapes.
unet3D.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 64, 64, 64,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d (Conv3D)                (None, 64, 64, 64,   896         ['input_2[0][0]']                
                                32)                                                               
                                                                                                  
 batch_normalization_10 (BatchN  (None, 64, 64, 64,   128        ['conv3d[0][0]']                 
 ormalization)                  32)                                                         

In [8]:
# We could use an approach which maps every layer type name to a 2D and a 3D
# function for a very general conversion. This cell only classifies the layers
# of the 2D model and flags the ones that would need a 3D counterpart.
inp = tf.keras.layers.Input([64, 64, 64, 1])  # 3D input tensor; not used by the loop below

# Layer type names as they appear in `model.summary()`, without numeric suffix.
# Defined once instead of being rebuilt on every loop iteration.
input_name = 'input'
conv2d_name = 'conv2d'
conv2d_transpose_name = 'conv2d_transpose'
batch_norm_name = 'batch_normalization'
leaky_relu_name = 'leaky_re_lu'
max_pooling2d_name = 'max_pooling2d'
concatenate_name = 'concatenate'
known_layer_types = {
    input_name,
    conv2d_name,
    conv2d_transpose_name,
    batch_norm_name,
    leaky_relu_name,
    max_pooling2d_name,
    concatenate_name,
}

print("nb layers", len(model.layers))
# NOTE: the loop variable rebinds `layer`, which the fake-weights cell uses for
# the kernel array; re-run that cell before re-running the weight comparison.
for i, layer in enumerate(model.layers):
    # Strip only a trailing numeric suffix ("conv2d_3" -> "conv2d").
    # The first layer of each type has no suffix ("batch_normalization"), so a
    # blind rsplit('_', 1) would truncate it to "batch" and report type "?".
    base, sep, suffix = layer.name.rpartition('_')
    name = base if sep and suffix.isdigit() else layer.name
    print('-layer{:<2} {:<20}'.format(i, name), end="")

    # A recognised name is its own type; anything else is reported as unknown.
    t = name if name in known_layer_types else "?"

    # Layers whose name contains "2d", plus concatenations, are the ones flagged
    # as needing work for the 3D conversion.
    print(' type {:<20}'.format(t), 'do something?', '2d' in name or name == concatenate_name)

nb layers 38
-layer0  input                type input                do something? False
-layer1  conv2d               type conv2d               do something? True
-layer2  batch_normalization  type batch_normalization  do something? False
-layer3  leaky_re_lu          type leaky_re_lu          do something? False
-layer4  conv2d               type conv2d               do something? True
-layer5  batch_normalization  type batch_normalization  do something? False
-layer6  leaky_re_lu          type leaky_re_lu          do something? False
-layer7  max_pooling2d        type max_pooling2d        do something? True
-layer8  conv2d               type conv2d               do something? True
-layer9  batch_normalization  type batch_normalization  do something? False
-layer10 leaky_re_lu          type leaky_re_lu          do something? False
-layer11 conv2d               type conv2d               do something? True
-layer12 batch_normalization  type batch_normalization  do something? False
-lay