-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet3.py
80 lines (61 loc) · 2.73 KB
/
unet3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Neural Network implementation for image segmentation
Copyright (c) 2022 Giansalvo Gusinu <profgusinu@gmail.com>
Copyright (c) 2021 Nikhil Tomar
Code adapted from:
https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/blob/main/TensorFlow/unet.py
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
# Sentinel compared against the `transfer_learning` argument of
# create_model_UNet3; selects the "freeze the encoder" transfer-learning mode.
TRANSF_LEARN_FREEZE_ENCODER: str = "freeze_encoder"
def conv_block(input, num_filters):
    """Stack two (3x3 convolution -> batch normalization -> ReLU) stages.

    Args:
        input: Keras tensor fed to the first convolution.
        num_filters: Number of output channels of both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    out = input
    # Two identical stages; layers are created in the same order as an
    # unrolled Conv/BN/ReLU, Conv/BN/ReLU sequence.
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out
def encoder_block(input, num_filters):
    """Run one contracting-path stage of the U-Net.

    Args:
        input: Keras tensor entering this stage.
        num_filters: Channel count passed to ``conv_block``.

    Returns:
        A ``(features, pooled)`` pair: ``features`` is the ``conv_block``
        output (the caller feeds it to the decoder as a skip connection) and
        ``pooled`` is that output after 2x2 max pooling.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(features)
    return features, pooled
def decoder_block(input, skip_features, num_filters):
    """Run one expanding-path stage of the U-Net.

    Args:
        input: Keras tensor coming from the deeper stage.
        skip_features: Encoder tensor concatenated with the upsampled input;
            presumably it must match the upsampled spatial size — confirm
            against the input shape used.
        num_filters: Channel count of the transposed convolution and of the
            following ``conv_block``.

    Returns:
        The ``conv_block`` output computed on the concatenated tensor.
    """
    upsampled = Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input)
    # Concatenate() joins along its default (last) axis.
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def create_model_UNet3(input_shape=(128, 128, 3), classes=3, transfer_learning=None):
    """Build a U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        input_shape: Shape of one sample, excluding the batch axis
            (passed to ``Input``).
        classes: Number of channels produced by the final 1x1 convolution.
        transfer_learning: ``None`` for a fully trainable network, or
            ``TRANSF_LEARN_FREEZE_ENCODER`` to mark every encoder layer as
            non-trainable. Any other value is ignored with a printed warning.

    Returns:
        An uncompiled ``Model`` named "U-Net3" whose output passes through a
        sigmoid activation.
    """
    inputs = Input(input_shape)

    # Contracting path; s1..s4 are kept as skip connections for the decoder.
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    # Bottleneck.
    b1 = conv_block(p4, 1024)

    # Expanding path, consuming the skip connections in reverse order.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(classes, 1, padding="same", activation="sigmoid")(d4)
    model = Model(inputs, outputs, name="U-Net3")

    if transfer_learning == TRANSF_LEARN_FREEZE_ENCODER:
        print("unet3.py: detected transfer learning: freezing encoder layers.")
        # Setting `trainable` on the symbolic tensors s1..s4 (as the previous
        # version did) has no effect: only Layer objects honour it. A
        # functional sub-model from the input to p4 collects exactly the
        # layers of the four encoder stages. Those Layer objects are shared
        # with `model`, so freezing them here freezes them in the full network.
        encoder = Model(inputs, p4, name="U-Net3_encoder")
        for layer in encoder.layers:
            layer.trainable = False
    elif transfer_learning is not None:
        print(
            "unet3.py: WARNING unsupported transfer_learning value "
            f"{transfer_learning!r} ignored."
        )

    return model
if __name__ == "__main__":
    # Smoke check: build the network for a 512x512x3 input and print its
    # layer-by-layer summary.
    shape = (512, 512, 3)
    unet = create_model_UNet3(input_shape=shape)
    unet.summary()