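### Transfer learning
Training U-Net and FPN segmentation models, both from scratch and from ImageNet-pretrained encoders, with the [`segmentation_models`](https://github.com/hasibzunair/segmentation_models) library.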

In [1]:
from segmentation_models import Unet

Using TensorFlow backend.


Segmentation Models: using `keras` framework.


#### Train from scratch

In [14]:
# Hyper-parameters for the from-scratch model.
BACKBONE = 'resnet34'   # encoder architecture
N = 1                   # number of input channels

# With encoder_weights=None nothing is pretrained, so the network is built
# directly for N-channel 320x320 inputs (no 3-channel adapter layer needed).
model = Unet(
    backbone_name=BACKBONE,
    input_shape=(320, 320, N),
    encoder_weights=None,
)

# Print the layer-by-layer architecture.
model.summary()

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
data (InputLayer)               (None, 320, 320, 1)  0                                            
__________________________________________________________________________________________________
bn_data (BatchNormalization)    (None, 320, 320, 1)  3           data[0][0]                       
__________________________________________________________________________________________________
zero_padding2d_87 (ZeroPadding2 (None, 326, 326, 1)  0           bn_data[0][0]                    
__________________________________________________________________________________________________
conv0 (Conv2D)                  (None, 160, 160, 64) 3136        zero_padding2d_87[0][0]          
____________________________________________________________________________________________

#### Train with pretrained weights

In [3]:
from segmentation_models import Unet
from keras.layers import Input, Conv2D
from keras.models import Model

# define number of channels
N = 1

# Set up the base model: a ResNet34 U-Net whose encoder is initialised with
# ImageNet weights. No input_shape is passed, so the base model is built with
# the library's default input shape; an adapter layer is put in front of it.
base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet')

# A trainable 1x1 convolution maps the N-channel input to the 3 channels the
# ImageNet-pretrained encoder expects.
inp = Input(shape=(320, 320, N))
l1 = Conv2D(3, (1, 1))(inp)
out = base_model(l1)

# Wrap adapter + base model into a single model, reusing the base model's name.
model = Model(inp, out, name=base_model.name)

# summary() prints the table itself and returns None; wrapping it in print()
# only appends a stray "None" line to the cell output.
model.summary()

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 320, 320, 1)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 320, 320, 3)       6         
_________________________________________________________________
model_4 (Model)              multiple                  24456154  
Total params: 24,456,160
Trainable params: 24,438,810
Non-trainable params: 17,350
_________________________________________________________________
None


### Fine tuning

In [6]:
# Fine-tuning recipe, kept commented out: `x` and `y` (presumably training
# images and masks — TODO confirm) are not defined in this notebook's cells,
# so uncommenting it as-is would raise a NameError.
# Step 1 trains only the decoder with the pretrained encoder frozen;
# step 2 unfreezes all layers and continues training.
#from segmentation_models import Unet
#from segmentation_models.utils import set_trainable

# freeze_encoder=True -->> backbone is frozen
# NOTE(review): segmentation_models >= 1.0 (the version that prints
# "Segmentation Models: using `keras` framework.") appears to name this
# argument `encoder_freeze`; confirm against the installed version before
# uncommenting.

#model = Unet(backbone_name='resnet34', encoder_weights='imagenet', freeze_encoder=True)
#model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# pretrain model decoder
#model.fit(x, y, epochs=2)

# release all layers for training
#set_trainable(model) # set all layers trainable and recompile model

# continue training
#model.fit(x, y, epochs=100)


### FPN (Feature Pyramid Network) with pretrained weights

In [4]:
import segmentation_models as sm
# Import the Keras pieces this cell uses, so it runs on a fresh kernel instead
# of relying on names left over from the U-Net cell.
from keras.layers import Input, Conv2D
from keras.models import Model

# define number of channels
N = 1

# Set up the base model: an FPN with a ResNet18 encoder initialised from
# ImageNet weights. No input_shape is passed, so the library default is used.
base_model = sm.FPN(backbone_name='resnet18', encoder_weights='imagenet')

# A trainable 1x1 convolution maps the N-channel input to the 3 channels the
# ImageNet-pretrained encoder expects.
inp = Input(shape=(320, 320, N))
l1 = Conv2D(3, (1, 1))(inp)
out = base_model(l1)

# Wrap adapter + base model into a single model, reusing the base model's name.
model = Model(inp, out, name=base_model.name)

# summary() prints the table itself and returns None; wrapping it in print()
# only appends a stray "None" line to the cell output.
model.summary()


Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 320, 320, 1)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 320, 320, 3)       6         
_________________________________________________________________
model_6 (Model)              multiple                  13838430  
Total params: 13,838,436
Trainable params: 13,828,190
Non-trainable params: 10,246
_________________________________________________________________
None


In [13]:
# Debug helper (disabled): list the sub-layers of the last layer of `model`.
# After the FPN cell, that last layer is presumably the nested base model
# (the summary output lists it last) — confirm before relying on it.
#x = model.layers[-1]
#for l in x.layers:
    #print(l)