# Transfer learning
Transfer learning saves a lot of training time: instead of starting from random weights, a model is initialized with weights learned on a related task.

In [5]:
import tifffile as tif
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras
import pandas as pd
import carreno.nn.unet as unet
import carreno.nn.layers as ly

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 2D weights to 3D
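We want to transfer the weights of a 2D UNet to a 3D UNet. To check the conversion, we set the first convolution's kernel to ones and inspect the resulting 3D kernel.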

In [6]:
# Build a small 2D UNet (64x64 single-channel input, 3 classes, depth 3) that will
# serve as the source of the weights; `model` is the short alias used below.
model = unet2D = unet.UNet([64, 64, 1], n_class=3, depth=3, top_activation="relu")
model.summary()

create norm relu
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 conv2d_15 (Conv2D)             (None, 64, 64, 64)   640         ['input_2[0][0]']                
                                                                                                  
 batch_normalization_14 (BatchN  (None, 64, 64, 64)  256         ['conv2d_15[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_15 (Activation)     (None, 64, 64, 64)   0           ['batch_no

In [7]:
# Replace the first conv layer's kernel with ones (bias left as is) so the effect of
# the 2D -> 3D conversion on the weights is easy to read off.
kernel, bias = model.layers[1].get_weights()
layer = np.ones_like(kernel)  # all weights will be 1
model.layers[1].set_weights([layer, bias])

In [8]:
# Convert the 2D UNet into a 3D UNet; the converted model takes volumes with 16
# slices on the new leading axis (its input is (None, 16, 64, 64, 1)).
n_slices = 16
unet3D = ly.model2D_to_3D(model, n_slices)

In [9]:
# If everything went well, the all-ones 2D kernel of the first conv layer should
# have been averaged over the new depth axis, so every 3D weight is 1 / kernel depth.
# Re-read the 2D kernel from the model rather than reusing the global `layer`, so
# this cell stays correct even if a later cell reuses that name.
kernel2D = model.layers[1].get_weights()[0]
layer3D = unet3D.layers[1].get_weights()[0]
print("Old shape :", kernel2D.shape)
print("New shape :", layer3D.shape)
print('---')
print("Old weights :", np.unique(kernel2D))
print("New weights :", np.unique(layer3D))

# Turn the visual check into an assertion. Assumes the new depth axis is axis 0
# (Keras Conv3D kernels are (depth, height, width, in, out)); here all three
# spatial sizes are 3, so the expected value is 1/3 either way.
np.testing.assert_allclose(layer3D, kernel2D.max() / layer3D.shape[0], rtol=1e-6)

Old shape : (3, 3, 1, 64)
New shape : (3, 3, 3, 1, 64)
---
Old weights : [1.]
New weights : [0.33333334]


In [10]:
# Inspect the converted architecture: the 2D convolutions now appear as Conv3D
# layers and every output shape gains a leading dimension of size 16.
unet3D.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 16, 64, 64,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d (Conv3D)                (None, 16, 64, 64,   1792        ['input_3[0][0]']                
                                64)                                                               
                                                                                                  
 batch_normalization_28 (BatchN  (None, 16, 64, 64,   256        ['conv3d[0][0]']                 
 ormalization)                  64)                                                         

In [14]:
# We could use an approach that maps every layer type to a 2D and a 3D function for
# a very general conversion. For now, list each layer's type (recovered from its
# auto-generated Keras name) and whether it would need a 3D counterpart.

def base_layer_name(layer_name):
    """Return a Keras layer name without its trailing "_<number>" suffix.

    Auto-generated names look like "conv2d_15", but an instance may also carry no
    suffix at all ("conv2d", "batch_normalization"). Only a purely numeric
    trailing part is removed, so type names that themselves contain underscores
    ("batch_normalization", "max_pooling2d", "conv2d_transpose") stay intact.
    """
    stem, sep, suffix = layer_name.rpartition('_')
    return stem if sep and suffix.isdigit() else layer_name


# Layer types this sketch recognises; anything else is reported as "?".
KNOWN_LAYER_TYPES = (
    'input',
    'conv2d',
    'conv2d_transpose',
    'batch_normalization',
    'leaky_re_lu',
    'max_pooling2d',
    'concatenate',
)

print("nb layers", len(model.layers))
# `keras_layer` (not `layer`) so the global `layer` set earlier is not clobbered.
for i, keras_layer in enumerate(model.layers):
    name = base_layer_name(keras_layer.name)
    print('-layer{:<2} {:<20}'.format(i, name), end="")

    layer_type = name if name in KNOWN_LAYER_TYPES else "?"
    # 2D-specific layers (and concatenate) are the ones flagged for conversion.
    needs_conversion = '2d' in name or name == 'concatenate'
    print(' type {:<20}'.format(layer_type), 'do something?', needs_conversion)

nb layers 69
-layer0  input                type input                do something? False
-layer1  conv2d               type conv2d               do something? True
-layer2  batch_normalization  type batch_normalization  do something? False
-layer3  activation           type ?                    do something? False
-layer4  dropout              type ?                    do something? False
-layer5  conv2d               type conv2d               do something? True
-layer6  batch_normalization  type batch_normalization  do something? False
-layer7  activation           type ?                    do something? False
-layer8  dropout              type ?                    do something? False
-layer9  max_pooling2d        type max_pooling2d        do something? True
-layer10 conv2d               type conv2d               do something? True
-layer11 batch_normalization  type batch_normalization  do something? False
-layer12 activation           type ?                    do something? False
-la

In [19]:
# Test save and load: write the converted 3D weights to HDF5 and reload them into a
# freshly built 3D UNet with the same hyper-parameters.
unet3D.save_weights("test.h5")
newUnet3D = unet.UNet([16, 64,64,1], n_class=3, depth=3, top_activation="relu")
newUnet3D.load_weights("test.h5")
newUnet3D.summary()

# summary() only shows the architecture; verify the reloaded weights really match.
saved_weights = unet3D.get_weights()
loaded_weights = newUnet3D.get_weights()
assert len(saved_weights) == len(loaded_weights), "weight count mismatch after reload"
for saved, loaded in zip(saved_weights, loaded_weights):
    np.testing.assert_array_equal(saved, loaded)

Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, 16, 64, 64,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d_60 (Conv3D)             (None, 16, 64, 64,   1792        ['input_10[0][0]']               
                                64)                                                               
                                                                                                  
 batch_normalization_98 (BatchN  (None, 16, 64, 64,   256        ['conv3d_60[0][0]']              
 ormalization)                  64)                                                         

# Practical test

In [32]:
# Load trained 2D weights (VGG16-backbone UNet on 96x96, 3-channel inputs), convert
# the model to 3D with 16 slices and save the 3D weights for the next cell.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
weights_2d_path = "D:/Etude/LINUM/resultat/unet2d_optim_w_unlabeled7.h5"
unet2d_kwargs = dict(
    n_class=3,
    depth=4,
    top_activation="relu",
    n_feat=64,
    norm_order=1,
    dropout=0.4,
    activation='relu',
    backbone='vgg16',
)

unet2dLoader = unet.UNet([96, 96, 3], **unet2d_kwargs)
unet2dLoader.load_weights(weights_2d_path)

transfer3d = ly.model2D_to_3D(unet2dLoader, 16)
transfer3d.save_weights("test2.h5")
transfer3d.summary()

Model: "model_42"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_45 (InputLayer)          [(None, 16, 96, 96,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv3d_297 (Conv3D)            (None, 16, 96, 96,   5248        ['input_45[0][0]']               
                                64)                                                               
                                                                                                  
 conv3d_298 (Conv3D)            (None, 16, 96, 96,   110656      ['conv3d_297[0][0]']             
                                64)                                                        

In [34]:
# Change the input shape: build the 3D UNet for 16x64x64 volumes (instead of the
# 16x96x96 the weights were converted with) and load the converted weights into it.
unet3dLoader = unet.UNet([16,64,64,3], n_class=3, depth=4, top_activation="relu", n_feat=64, norm_order=1, dropout=0.4, activation='relu', backbone='vgg16')
unet3dLoader.load_weights("test2.h5")
unet3dLoader.summary()

# A successful load only proves the shapes fit; check the values carried over too.
converted_weights = transfer3d.get_weights()
reloaded_weights = unet3dLoader.get_weights()
assert len(converted_weights) == len(reloaded_weights), "weight count mismatch after reload"
for converted, reloaded in zip(converted_weights, reloaded_weights):
    np.testing.assert_array_equal(converted, reloaded)

Model: "model_48"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_49 (InputLayer)          [(None, 16, 64, 64,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv3d_341 (Conv3D)            (None, 16, 64, 64,   5248        ['input_49[0][0]']               
                                64)                                                               
                                                                                                  
 conv3d_342 (Conv3D)            (None, 16, 64, 64,   110656      ['conv3d_341[1][0]']             
                                64)                                                        