-
Notifications
You must be signed in to change notification settings - Fork 509
/
callbacks.py
112 lines (91 loc) · 3.25 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Created by Matthias Mueller - Intel Intelligent Systems Lab - 2020
import os
import tensorflow as tf
def checkpoint_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10):
    """Build a ModelCheckpoint that writes epoch-numbered full-model checkpoints.

    Args:
        checkpoint_path: Directory where ``cp-NNNN.ckpt`` files are written.
        steps_per_epoch: If negative, save once per epoch; otherwise used with
            ``num_epochs`` to compute a batch-based save frequency.
        num_epochs: Epoch interval between saves when ``steps_per_epoch >= 0``.

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` instance.
    """
    # Negative steps_per_epoch means "save at every epoch boundary";
    # otherwise save every num_epochs worth of batches.
    if steps_per_epoch < 0:
        save_freq = "epoch"
    else:
        save_freq = int(num_epochs * steps_per_epoch)
    return tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_path, "cp-{epoch:04d}.ckpt"),
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq=save_freq,
    )
def checkpoint_last_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10):
    """Build a ModelCheckpoint that keeps overwriting a single 'last' checkpoint.

    Args:
        checkpoint_path: Directory where ``cp-last.ckpt`` is written.
        steps_per_epoch: If negative, save once per epoch; otherwise used with
            ``num_epochs`` to compute a batch-based save frequency.
        num_epochs: Epoch interval between saves when ``steps_per_epoch >= 0``.

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` instance.
    """
    target = os.path.join(checkpoint_path, "cp-last.ckpt")
    # Fixed filename (no epoch placeholder), so each save replaces the previous.
    frequency = "epoch" if steps_per_epoch < 0 else int(num_epochs * steps_per_epoch)
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=target,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq=frequency,
    )
    return callback
def checkpoint_best_train_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10):
    """Build a ModelCheckpoint that keeps only the best model by training loss.

    Args:
        checkpoint_path: Directory where ``cp-best-train.ckpt`` is written.
        steps_per_epoch: If negative, save once per epoch; otherwise used with
            ``num_epochs`` to compute a batch-based save frequency.
        num_epochs: Epoch interval between saves when ``steps_per_epoch >= 0``.

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` instance.
    """
    options = dict(
        filepath=os.path.join(checkpoint_path, "cp-best-train.ckpt"),
        # Monitors training loss (not validation), and only overwrites the
        # checkpoint when the monitored value improves.
        monitor="loss",
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode="auto",
    )
    if steps_per_epoch < 0:
        options["save_freq"] = "epoch"
    else:
        options["save_freq"] = int(num_epochs * steps_per_epoch)
    return tf.keras.callbacks.ModelCheckpoint(**options)
def checkpoint_best_val_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10):
    """Build a ModelCheckpoint that keeps only the best model by validation loss.

    Args:
        checkpoint_path: Directory where ``cp-best-val.ckpt`` is written.
        steps_per_epoch: If negative, save once per epoch; otherwise used with
            ``num_epochs`` to compute a batch-based save frequency.
        num_epochs: Epoch interval between saves when ``steps_per_epoch >= 0``.

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` instance.
    """
    if steps_per_epoch < 0:
        # Save at every epoch boundary.
        save_freq = "epoch"
    else:
        # Save every num_epochs worth of batches.
        save_freq = int(num_epochs * steps_per_epoch)
    best_val_path = os.path.join(checkpoint_path, "cp-best-val.ckpt")
    return tf.keras.callbacks.ModelCheckpoint(
        filepath=best_val_path,
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode="auto",
        save_freq=save_freq,
    )
def tensorboard_cb(log_path):
    """Build a TensorBoard callback that logs the graph and images to log_path.

    Args:
        log_path: Directory for TensorBoard event files.

    Returns:
        A ``tf.keras.callbacks.TensorBoard`` instance.
    """
    settings = {
        "log_dir": log_path,
        "histogram_freq": 0,       # histograms disabled
        "write_graph": True,
        "write_images": True,
        "update_freq": "epoch",    # flush metrics once per epoch
        "profile_batch": 2,        # profile the second batch
        "embeddings_freq": 0,      # embedding visualization disabled
        "embeddings_metadata": None,
    }
    return tf.keras.callbacks.TensorBoard(**settings)
def logger_cb(log_path, append=False):
    """Build a CSVLogger that writes per-epoch metrics to ``log.csv``.

    Args:
        log_path: Directory in which ``log.csv`` is created.
        append: If True, append to an existing log instead of overwriting.

    Returns:
        A ``tf.keras.callbacks.CSVLogger`` instance.
    """
    csv_file = os.path.join(log_path, "log.csv")
    return tf.keras.callbacks.CSVLogger(csv_file, append=append)
def early_stopping_cb():
    """Build an EarlyStopping callback on validation loss.

    Training stops after 20 consecutive epochs without any improvement in
    ``val_loss``; the final (not best) weights are kept.

    Returns:
        A ``tf.keras.callbacks.EarlyStopping`` instance.
    """
    return tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=0,
        patience=20,
        verbose=0,
        mode="auto",
        baseline=None,
        restore_best_weights=False,
    )
def reduce_lr_cb():
    """Build a ReduceLROnPlateau callback on validation loss.

    Multiplies the learning rate by 0.3 after 2 stagnant epochs, with a
    floor of 1e-4.

    Returns:
        A ``tf.keras.callbacks.ReduceLROnPlateau`` instance.
    """
    return tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.3,
        patience=2,
        min_lr=0.0001,
    )
def lr_schedule_cb():
    """Build a LearningRateScheduler driven by the module-level ``scheduler``.

    Returns:
        A ``tf.keras.callbacks.LearningRateScheduler`` instance.
    """
    return tf.keras.callbacks.LearningRateScheduler(schedule=scheduler)
def scheduler(epoch, lr=None):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    Modern ``tf.keras`` invokes the schedule as ``schedule(epoch, lr)`` and
    only falls back to the single-argument form via a TypeError probe, so the
    current rate is accepted (and ignored) as an optional second argument.
    Calling ``scheduler(epoch)`` still works as before.

    Args:
        epoch: Zero-based epoch index.
        lr: Current learning rate supplied by Keras; unused.

    Returns:
        2e-4 for epochs 0-9, 1e-4 for epochs 10-19, 5e-5 from epoch 20 on.
    """
    if epoch < 10:
        return 0.0002
    if epoch < 20:
        return 0.0001
    return 0.00005