-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet2.py
103 lines (83 loc) · 4.12 KB
/
unet2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
U-Net style neural network for image segmentation (MobileNetV2 encoder, pix2pix upsampling decoder)
Copyright (c) 2022 Giansalvo Gusinu <profgusinu@gmail.com>
Code adapted from TensorFlow Tutorial
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import tensorflow as tf
# pip install -q git+https://github.com/tensorflow/examples.git
from tensorflow_examples.models.pix2pix import pix2pix
# Values accepted by the `transfer_learning` argument of create_model_UNet2.
TRANSF_LEARN_IMAGENET_AND_FREEZE_ENCODER = "imagenet_freeze_encoder"
TRANSF_LEARN_IMAGENET_AND_FREEZE_DECODER = "imagenet_freeze_decoder"
TRANSF_LEARN_FREEZE_ENCODER = "freeze_encoder"
# Decoder-only freeze without ImageNet weights. This must be its own name:
# re-using TRANSF_LEARN_FREEZE_ENCODER here would overwrite the value above.
TRANSF_LEARN_FREEZE_DECODER = "freeze_decoder"


def create_model_UNet2(output_channels:int, input_size=128, classes=3, transfer_learning=None):
    """Build a U-Net with a MobileNetV2 encoder and a pix2pix upsampling decoder.

    Args:
        output_channels: Number of filters of the final Conv2DTranspose layer,
            i.e. the channel count of the model output.
        input_size: Height and width of the square RGB input image.
            NOTE(review): the skip connections only line up when this is a
            multiple of 32 (MobileNetV2 downsamples by 32 in total) — confirm
            before using other sizes.
        classes: Unused; reserved for future use (a warning is printed).
        transfer_learning: None, or one of the TRANSF_LEARN_* constants:
            - TRANSF_LEARN_IMAGENET_AND_FREEZE_ENCODER: load ImageNet weights
              into the encoder and freeze the encoder.
            - TRANSF_LEARN_IMAGENET_AND_FREEZE_DECODER: load ImageNet weights
              into the encoder and freeze the upsampling blocks.
            - TRANSF_LEARN_FREEZE_ENCODER: random init, freeze the encoder.
            - TRANSF_LEARN_FREEZE_DECODER: random init, freeze the upsampling
              blocks.
            Any other value trains everything from scratch; a warning is
            printed for values other than None.
            The final Conv2DTranspose layer is never frozen.

    Returns:
        An uncompiled tf.keras.Model named "U-Net2".
    """
    print("unet2.py: WARNING parameter 'classes' not used. Reserved for future uses.") # TODO parameter classes not used

    use_imagenet = transfer_learning in (TRANSF_LEARN_IMAGENET_AND_FREEZE_ENCODER,
                                         TRANSF_LEARN_IMAGENET_AND_FREEZE_DECODER)
    freeze_encoder = transfer_learning in (TRANSF_LEARN_IMAGENET_AND_FREEZE_ENCODER,
                                           TRANSF_LEARN_FREEZE_ENCODER)
    freeze_decoder = transfer_learning in (TRANSF_LEARN_IMAGENET_AND_FREEZE_DECODER,
                                           TRANSF_LEARN_FREEZE_DECODER)
    if transfer_learning is not None and not (use_imagenet or freeze_encoder or freeze_decoder):
        # Previously ignored silently; surface typos instead of training from scratch unnoticed.
        print(f"unet2.py: WARNING unknown transfer_learning option {transfer_learning!r}; ignored.")

    if use_imagenet:
        print("unet2.py: detected transfer learning: schedule imagenet weights to be loaded.")
        weights = "imagenet"
    else:
        weights = None
    base_model = tf.keras.applications.MobileNetV2(input_shape=[input_size, input_size, 3],
                                                   include_top=False,
                                                   weights=weights)

    # Encoder activations used as skip connections; sizes are for input_size=128.
    layer_names = [
        'block_1_expand_relu',   # 64x64
        'block_3_expand_relu',   # 32x32
        'block_6_expand_relu',   # 16x16
        'block_13_expand_relu',  # 8x8
        'block_16_project',      # 4x4
    ]
    base_model_outputs = [base_model.get_layer(name).output for name in layer_names]
    # Feature-extraction model: one input, one output per entry in layer_names.
    down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
    if freeze_encoder:
        print("unet2.py: detected transfer learning: freeze down_stack weights.")
        # down_stack shares its layers with base_model; freeze both explicitly.
        base_model.trainable = False
        down_stack.trainable = False

    # Decoder blocks; sizes are for input_size=128.
    up_stack = [
        pix2pix.upsample(512, 3),  # 4x4 -> 8x8
        pix2pix.upsample(256, 3),  # 8x8 -> 16x16
        pix2pix.upsample(128, 3),  # 16x16 -> 32x32
        pix2pix.upsample(64, 3),   # 32x32 -> 64x64
    ]
    if freeze_decoder:
        print("unet2.py: detected transfer learning: freeze up_stack weights.")
        # Only the upsampling blocks carry decoder weights; Concatenate has none.
        for up in up_stack:
            up.trainable = False

    inputs = tf.keras.layers.Input(shape=[input_size, input_size, 3])
    # Downsampling through the encoder.
    skips = down_stack(inputs)
    x = skips[-1]
    # Upsample and concatenate with the matching encoder activation,
    # deepest skip first (skips[-1] is the bottleneck itself).
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final layer restores the input resolution (64x64 -> 128x128 for input_size=128).
    last = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2,
        padding='same')
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x, name="U-Net2")