-
Notifications
You must be signed in to change notification settings - Fork 177
/
train.py
128 lines (99 loc) · 5.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import tensorflow.keras as nn
import math
import argparse
from configuration import IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS, \
EPOCHS, BATCH_SIZE, save_model_dir, save_every_n_epoch
from prepare_data import generate_datasets, load_and_preprocess_image
from models import get_model
def print_model_summary(network):
    """Build ``network`` for the configured image size and print its summary.

    The shape handed to ``network.build`` is
    ``(None, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)``; the leading ``None`` is
    presumably the (unspecified) batch axis.

    Args:
        network: object exposing ``build(input_shape=...)`` and ``summary()``
            (the model returned by ``get_model`` in the training script).
    """
    input_shape = (None, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)
    network.build(input_shape=input_shape)
    network.summary()
def process_features(features, data_augmentation):
    """Turn one dataset element into an ``(images, labels)`` pair.

    Args:
        features: mapping with ``'image_raw'`` and ``'label'`` entries, each
            of which supports ``.numpy()``.
        data_augmentation: forwarded unchanged to
            ``load_and_preprocess_image`` for every image.

    Returns:
        A tuple ``(images, labels)``: ``images`` is the per-image results of
        ``load_and_preprocess_image`` stacked along a new axis 0, and
        ``labels`` is ``features['label']`` converted with ``.numpy()``.
    """
    raw_images = features['image_raw'].numpy()
    # Each raw entry is preprocessed on its own, then stacked into one tensor.
    processed = [
        load_and_preprocess_image(raw, data_augmentation=data_augmentation)
        for raw in raw_images
    ]
    images = tf.stack(processed, axis=0)
    labels = features['label'].numpy()
    return images, labels
# Command-line interface: a single integer option ``--idx`` (default 0).
# The training entry point passes its value to get_model().
parser = argparse.ArgumentParser()
parser.add_argument("--idx", type=int, default=0)
if __name__ == '__main__':
    # Request memory growth on every visible GPU. A RuntimeError raised by
    # TensorFlow while doing so is printed and the script carries on.
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for device in gpus:
                tf.config.experimental.set_memory_growth(device, True)
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as err:
            print(err)

    args = parser.parse_args()

    # Datasets and their sample counts (only the train/valid splits are used
    # below; the test split is returned by generate_datasets() as well).
    (train_dataset, valid_dataset, test_dataset,
     train_count, valid_count, test_count) = generate_datasets()

    # Model selected by the --idx command-line option.
    model = get_model(args.idx)
    print_model_summary(network=model)

    # Loss, optimizer and the running metrics reported during training.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = nn.optimizers.Adam(learning_rate=1e-3)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    # @tf.function  (decorator left disabled, as in the original)
    def train_step(image_batch, label_batch):
        """Run one optimisation step and update the training metrics."""
        with tf.GradientTape() as tape:
            predictions = model(image_batch, training=True)
            loss = loss_object(y_true=label_batch, y_pred=predictions)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, variables))

        train_loss.update_state(values=loss)
        train_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # @tf.function  (decorator left disabled, as in the original)
    def valid_step(image_batch, label_batch):
        """Evaluate one batch (no weight update) and update the validation metrics."""
        predictions = model(image_batch, training=False)
        batch_loss = loss_object(label_batch, predictions)

        valid_loss.update_state(values=batch_loss)
        valid_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # Training loop: one pass over the training data, then one over the
    # validation data, per epoch. Epoch numbers are reported zero-based.
    for epoch in range(EPOCHS):
        for step, batch in enumerate(train_dataset, start=1):
            images, labels = process_features(batch, data_augmentation=True)
            train_step(images, labels)
            # Running (epoch-so-far) loss/accuracy after this step.
            step_report = "Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(
                epoch,
                EPOCHS,
                step,
                math.ceil(train_count / BATCH_SIZE),
                train_loss.result().numpy(),
                train_accuracy.result().numpy(),
            )
            print(step_report)

        for batch in valid_dataset:
            valid_images, valid_labels = process_features(batch, data_augmentation=False)
            valid_step(valid_images, valid_labels)

        epoch_report = ("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
                        "valid loss: {:.5f}, valid accuracy: {:.5f}").format(
            epoch,
            EPOCHS,
            train_loss.result().numpy(),
            train_accuracy.result().numpy(),
            valid_loss.result().numpy(),
            valid_accuracy.result().numpy(),
        )
        print(epoch_report)

        # Start the next epoch with empty accumulators.
        for metric in (train_loss, train_accuracy, valid_loss, valid_accuracy):
            metric.reset_states()

        # Periodic checkpoint (this condition is also true for epoch 0).
        if epoch % save_every_n_epoch == 0:
            checkpoint_path = save_model_dir + "epoch-{}".format(epoch)
            model.save_weights(filepath=checkpoint_path, save_format='tf')

    # Final weights after the last epoch.
    final_path = save_model_dir + "model"
    model.save_weights(filepath=final_path, save_format='tf')

    # Optional exports, kept disabled as in the original:
    # save the whole model
    # tf.saved_model.save(model, save_model_dir)

    # convert to tensorflow lite format
    # model._set_inputs(inputs=tf.random.normal(shape=(1, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)))
    # converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # tflite_model = converter.convert()
    # open("converted_model.tflite", "wb").write(tflite_model)