-
-
Notifications
You must be signed in to change notification settings - Fork 108
/
custom_unet.py
75 lines (62 loc) · 2.54 KB
/
custom_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from keras.models import Model
from keras.layers import BatchNormalization, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, UpSampling2D, Input, concatenate
def upsample_conv(filters, kernel_size, strides, padding):
    """Build a learnable upsampling layer (transposed convolution).

    Args:
        filters: Number of output channels of the Conv2DTranspose layer.
        kernel_size: Kernel size forwarded to Conv2DTranspose.
        strides: Upsampling factor, forwarded as the transposed-conv strides.
        padding: Padding mode forwarded to Conv2DTranspose.

    Returns:
        An un-applied ``Conv2DTranspose`` layer instance.
    """
    layer = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)
    return layer
def upsample_simple(filters, kernel_size, strides, padding):
    """Build a parameter-free upsampling layer (UpSampling2D).

    ``filters``, ``kernel_size`` and ``padding`` are accepted but ignored, so
    this function can be used interchangeably with ``upsample_conv``.

    Args:
        filters: Ignored.
        kernel_size: Ignored.
        strides: Upsampling factor, passed to UpSampling2D as its ``size``.
        padding: Ignored.

    Returns:
        An un-applied ``UpSampling2D`` layer instance.
    """
    return UpSampling2D(size=strides)
def conv2d_block(
        inputs,
        use_batch_norm=True,
        dropout=0.3,
        filters=16,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same'):
    """Apply two identically-configured Conv2D layers to ``inputs``.

    Layer order: Conv2D -> [BatchNorm] -> [Dropout] -> Conv2D -> [BatchNorm].
    BatchNormalization follows each convolution when ``use_batch_norm`` is
    true. Dropout appears only between the two convolutions, and only when
    ``dropout > 0.0``.

    Args:
        inputs: Keras tensor to feed into the first convolution.
        use_batch_norm: Whether to add BatchNormalization after each conv.
        dropout: Dropout rate between the two convs; ``<= 0`` disables it.
        filters: Number of filters for both convolutions.
        kernel_size: Kernel size for both convolutions.
        activation: Activation applied inside each Conv2D.
        kernel_initializer: Kernel initializer for both convolutions.
        padding: Padding mode for both convolutions.

    Returns:
        The Keras tensor produced by the last layer of the block.
    """
    out = inputs
    for stage in range(2):
        out = Conv2D(
            filters,
            kernel_size,
            activation=activation,
            kernel_initializer=kernel_initializer,
            padding=padding,
        )(out)
        if use_batch_norm:
            out = BatchNormalization()(out)
        # Dropout sits only between the two convolutions, never after the last one.
        if stage == 0 and dropout > 0.0:
            out = Dropout(dropout)(out)
    return out
def custom_unet(
        input_shape,
        num_classes=1,
        use_batch_norm=True,
        upsample_mode='deconv',  # 'deconv' or 'simple'
        use_dropout_on_upsampling=False,
        dropout=0.3,
        dropout_change_per_layer=0.0,
        filters=16,
        num_layers=4,
        output_activation='sigmoid',  # 'sigmoid' or 'softmax'
        activation='relu'):
    """Build a U-Net style encoder/decoder Keras model.

    The encoder applies ``num_layers`` conv blocks, each followed by 2x2 max
    pooling. The filter count doubles after each level. A bottleneck conv
    block follows. The decoder mirrors the encoder: at each level it
    upsamples, concatenates the matching encoder output (skip connection)
    and applies another conv block, halving the filter count. A final 1x1
    convolution produces ``num_classes`` output channels.

    NOTE(review): because pooling and upsampling are both 2x per level, the
    spatial dimensions of ``input_shape`` presumably need to be divisible by
    ``2 ** num_layers``. Otherwise ``concatenate`` will see mismatched
    shapes. This is not checked here.

    Args:
        input_shape: Shape passed to ``keras.layers.Input`` (no batch dim).
        num_classes: Number of channels in the output convolution.
        use_batch_norm: Whether conv blocks use BatchNormalization.
        upsample_mode: 'deconv' uses Conv2DTranspose. Any other value falls
            back to parameter-free UpSampling2D.
        use_dropout_on_upsampling: If false, decoder blocks use no dropout.
        dropout: Dropout rate of the first encoder block.
        dropout_change_per_layer: Amount added to the dropout rate per
            encoder level (and subtracted per decoder level when decoder
            dropout is enabled).
        filters: Number of filters in the first encoder block.
        num_layers: Number of down/up-sampling levels.
        output_activation: Activation of the final 1x1 convolution.
        activation: Activation used inside every conv block. Defaults to
            'relu', which matches the previous hard-wired behaviour.

    Returns:
        A ``keras.models.Model`` mapping the input tensor to the output tensor.
    """
    upsample = upsample_conv if upsample_mode == 'deconv' else upsample_simple

    inputs = Input(input_shape)
    x = inputs

    # Encoder: keep each level's pre-pooling output for the skip connections.
    down_layers = []
    for _ in range(num_layers):
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         dropout=dropout, activation=activation)
        down_layers.append(x)
        x = MaxPooling2D((2, 2))(x)
        dropout += dropout_change_per_layer
        filters *= 2  # double the number of filters with each layer

    # Bottleneck at the lowest resolution.
    x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                     dropout=dropout, activation=activation)

    if not use_dropout_on_upsampling:
        dropout = 0.0
        dropout_change_per_layer = 0.0

    # Decoder: walk the skip connections from deepest to shallowest. The
    # dropout is decremented before each block, so every decoder level uses
    # the same rate as the encoder level it mirrors.
    for conv in reversed(down_layers):
        filters //= 2  # decreasing number of filters with each layer
        dropout -= dropout_change_per_layer
        x = upsample(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, conv])
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         dropout=dropout, activation=activation)

    outputs = Conv2D(num_classes, (1, 1), activation=output_activation)(x)
    return Model(inputs=[inputs], outputs=[outputs])