-
Notifications
You must be signed in to change notification settings - Fork 34
/
simclr_tf.py
235 lines (196 loc) · 8.9 KB
/
simclr_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Copyright 2021 The FastEstimator Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The FastEstimator implementation of SimCLR with ResNet9 on ciFAIR-10.
This code is based on Google's reference implementation (https://github.com/google-research/simclr).
Note that we use the ciFAIR-10 dataset (https://cvjena.github.io/cifair/) instead of CIFAR-10.
"""
import tempfile
import tensorflow as tf
from tensorflow.keras import layers
import fastestimator as fe
from fastestimator.dataset.data.cifair10 import load_data
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import ColorJitter, GaussianBlur, ToFloat, ToGray
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import ModelSaver
from fastestimator.trace.metric import Accuracy
def ResNet9(input_size=(32, 32, 3), head_len=128, classes=10):
    """Build a small 9-layer ResNet TensorFlow backbone with a contrastive head and a classification head.

    The backbone architecture is from https://github.com/davidcpage/cifar10-fast. Both returned models are built
    on the same input and backbone layers, so they share backbone weights; they differ only in the dense head
    attached to the pooled feature vector.

    Args:
        input_size: The size of the input tensor (height, width, channels).
        head_len: The output dimension of the linear projection head used for contrastive pretraining.
        classes: The number of outputs the classification model should generate.

    Raises:
        ValueError: Length of `input_size` is not 3.

    Returns:
        A tuple `(model_con, model_finetune)`. `model_con` maps an image to a `head_len`-dimensional projection
        (no activation). `model_finetune` maps an image to `classes` softmax probabilities computed in float32.
    """
    if len(input_size) != 3:
        raise ValueError(f"Length of `input_size` is not 3 (height, width, channels), got {input_size}")
    # prep layers
    inp = layers.Input(shape=input_size)
    x = layers.Conv2D(64, 3, padding='same')(inp)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer1: conv + downsample, followed by a residual branch
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 128)])
    # layer2: conv + downsample, no residual branch
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer3: conv + downsample, followed by a residual branch
    x = layers.Conv2D(512, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 512)])
    # head: global pooling yields the feature vector shared by both output heads
    x = layers.GlobalMaxPool2D()(x)
    # Flatten is a no-op on the already 2-D pooled output; kept so the layer graph matches the original model.
    code = layers.Flatten()(x)
    # projection head for contrastive pretraining (linear, no activation)
    p_head = layers.Dense(head_len)(code)
    model_con = tf.keras.Model(inputs=inp, outputs=p_head)
    # classification head; the softmax layer is explicitly given dtype float32
    s_head = layers.Dense(classes)(code)
    s_head = layers.Activation('softmax', dtype='float32')(s_head)
    model_finetune = tf.keras.Model(inputs=inp, outputs=s_head)
    return model_con, model_finetune
def residual(x, num_channel):
    """Build the residual branch used by ResNet9: two stacked Conv-BN-LeakyReLU stages.

    Args:
        x: Input Keras tensor.
        num_channel: Number of filters in each of the two 3x3 convolutions.

    Return:
        Output Keras tensor with `num_channel` channels and the same spatial size as `x` ('same' padding).
    """
    out = x
    # Both stages are identical, so build them in a loop (layer creation order is conv, bn, act, conv, bn, act).
    for _ in range(2):
        out = layers.Conv2D(num_channel, 3, padding='same')(out)
        out = layers.BatchNormalization(momentum=0.8)(out)
        out = layers.LeakyReLU(alpha=0.1)(out)
    return out
class NTXentOp(TensorOp):
    """TensorOp that computes the NT-Xent contrastive loss between two batches of embeddings.

    Args:
        arg1: Data key of the first view's embeddings.
        arg2: Data key of the second view's embeddings.
        outputs: Keys for the (loss, logit, label) triple returned by `NTXent`.
        temperature: Temperature forwarded to `NTXent`.
        mode: Mode(s) in which this op executes; forwarded unchanged to `TensorOp`.
    """
    def __init__(self, arg1, arg2, outputs, temperature=1.0, mode=None):
        super().__init__(inputs=(arg1, arg2), outputs=outputs, mode=mode)
        self.temperature = temperature

    def forward(self, data, state):
        # `data` holds the two input tensors in the order declared by `inputs`.
        view_a, view_b = data
        # NTXent already returns the (loss, logit, label) tuple expected by `outputs`.
        return NTXent(view_a, view_b, self.temperature)
def NTXent(A, B, temperature):
    """Compute the normalized temperature-scaled cross-entropy (NT-Xent) loss used by SimCLR.

    Row i of `A` and row i of `B` form the positive pair. Every other row of `A` and of `B` acts as a negative.
    Self-similarities are excluded by subtracting a large constant on the diagonal.

    Args:
        A: Embeddings of the first augmented view; assumed shape (batch, dim) — TODO confirm.
        B: Embeddings of the second augmented view, same shape as `A`.
        temperature: Scalar that the cosine similarities are divided by.

    Returns:
        `(loss, logits, labels)` where:
        - `loss` is the batch mean of the summed A->B and B->A cross entropies.
        - `logits` is the (batch, batch) A-to-B similarity matrix.
        - `labels` is the (batch, 2 * batch) one-hot target. Its first `batch` columns line up with `logits`.
    """
    large_number = 1e9
    n = tf.shape(A)[0]
    diag_idx = tf.range(n)
    a_hat = tf.math.l2_normalize(A, -1)
    b_hat = tf.math.l2_normalize(B, -1)
    # Huge penalty on the diagonal so an embedding is never scored against itself.
    self_penalty = tf.one_hot(diag_idx, n) * large_number
    labels = tf.one_hot(diag_idx, 2 * n)

    def scaled_sim(u, v):
        return tf.matmul(u, v, transpose_b=True) / temperature

    logits_ab = scaled_sim(a_hat, b_hat)
    logits_ba = scaled_sim(b_hat, a_hat)
    logits_aa = scaled_sim(a_hat, a_hat) - self_penalty
    logits_bb = scaled_sim(b_hat, b_hat) - self_penalty
    # The positive for row i sits at column i of the cross-view block, matching the one-hot labels.
    xent_a = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ab, logits_aa], 1))
    xent_b = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ba, logits_bb], 1))
    return tf.reduce_mean(xent_a + xent_b), logits_ab, labels
def _simclr_view_ops(out_key):
    """Return the SimCLR augmentation ops that turn the padded image "x" into one augmented view `out_key`."""
    return [
        RandomCrop(32, 32, image_in="x", image_out=out_key),
        Sometimes(HorizontalFlip(image_in=out_key, image_out=out_key), prob=0.5),
        Sometimes(
            ColorJitter(inputs=out_key, outputs=out_key, brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
            prob=0.8),
        Sometimes(ToGray(inputs=out_key, outputs=out_key), prob=0.2),
        Sometimes(GaussianBlur(inputs=out_key, outputs=out_key, blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                  prob=0.5),
        ToFloat(inputs=out_key, outputs=out_key),
    ]


def pretrain_model(epochs, batch_size, train_steps_per_epoch, save_dir):
    """Pretrain the ResNet9 backbone with the SimCLR contrastive objective on ciFAIR-10 training images.

    Args:
        epochs: Number of pretraining epochs.
        batch_size: Pipeline batch size. Each sample yields two augmented views.
        train_steps_per_epoch: Forwarded to `fe.Estimator`.
        save_dir: Directory where `ModelSaver` writes the contrastive model.

    Returns:
        The classification model built by `ResNet9`. It shares the backbone that was just pretrained.
    """
    # step 1: prepare dataset (only the training split is needed for self-supervised pretraining)
    train_data, _ = load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        batch_size=batch_size,
        # Pad once, then draw two independently augmented views of the same padded image.
        ops=[PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x")] + _simclr_view_ops("x_aug") +
        _simclr_view_ops("x_aug2"))
    # step 2: prepare network
    model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
    network = fe.Network(ops=[
        ModelOp(model=model_con, inputs="x_aug", outputs="y_pred"),
        ModelOp(model=model_con, inputs="x_aug2", outputs="y_pred2"),
        NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
        UpdateOp(model=model_con, loss_name="NTXent")
    ])
    # step 3: prepare estimator
    traces = [
        # Accuracy of picking the matching view among the in-batch candidates.
        Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
        ModelSaver(model=model_con, save_dir=save_dir),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch)
    estimator.fit()
    return model_finetune
def finetune_model(model, epochs, batch_size, train_steps_per_epoch):
    """Fine-tune `model` with cross-entropy on a labeled subset of ciFAIR-10 and evaluate on the test split.

    Args:
        model: Classification model to train (here, the one returned by `pretrain_model`).
        epochs: Number of fine-tuning epochs.
        batch_size: Pipeline batch size.
        train_steps_per_epoch: Forwarded to `fe.Estimator`.
    """
    full_train, test_data = load_data()
    # NOTE(review): `split(0.1)` presumably returns a 10% subset of the training data; only that subset is used
    # for supervised fine-tuning — confirm against FastEstimator's Dataset.split semantics.
    labeled_subset = full_train.split(0.1)
    estimator = fe.Estimator(
        pipeline=fe.Pipeline(train_data=labeled_subset,
                             eval_data=test_data,
                             batch_size=batch_size,
                             ops=[ToFloat(inputs="x", outputs="x")]),
        network=fe.Network(ops=[
            ModelOp(model=model, inputs="x", outputs="y_pred"),
            CrossEntropy(inputs=["y_pred", "y"], outputs="ce"),
            UpdateOp(model=model, loss_name="ce"),
        ]),
        epochs=epochs,
        traces=[Accuracy(true_key="y", pred_key="y_pred")],
        train_steps_per_epoch=train_steps_per_epoch)
    estimator.fit()
def fastestimator_run(epochs_pretrain=50,
                      epochs_finetune=10,
                      batch_size=512,
                      train_steps_per_epoch=None,
                      save_dir=None):
    """Run SimCLR pretraining on ciFAIR-10, then fine-tune the resulting classifier.

    Args:
        epochs_pretrain: Number of contrastive pretraining epochs.
        epochs_finetune: Number of supervised fine-tuning epochs.
        batch_size: Batch size used for both stages.
        train_steps_per_epoch: Forwarded to both stages' estimators.
        save_dir: Directory for pretraining checkpoints. If None, a fresh temporary directory is created for this
            call.
    """
    # Create the temp dir lazily. A `tempfile.mkdtemp()` default would run once at import time: it would create
    # a directory even if this function is never called, and reuse that one directory for every call.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()
    model_finetune = pretrain_model(epochs_pretrain, batch_size, train_steps_per_epoch, save_dir)
    finetune_model(model_finetune, epochs_finetune, batch_size, train_steps_per_epoch)
if __name__ == "__main__":
    # Script entry point: run pretraining followed by fine-tuning with the default settings.
    fastestimator_run()