-
Notifications
You must be signed in to change notification settings - Fork 97
/
train.py
116 lines (100 loc) · 4.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import json
import argparse
from datetime import datetime
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
CSVLogger, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau)
import paz.processors as pr
from paz.abstract import ProcessingSequence
from paz.pipelines import AugmentDetection
from paz.optimization import MultiBoxLoss
from open_images import OpenImagesV6
from model import SSD512Custom
# Default dataset location, rooted at the user's home directory.
root_path = os.path.expanduser('~')
DEFAULT_DATA_PATH = os.path.join(root_path, 'hand_dataset/hand_dataset/')

description = 'Training script for single-shot object detection models'
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
parser.add_argument('--evaluation_frequency', default=10, type=int,
                    help='evaluation frequency')
parser.add_argument('--stop_patience', default=5, type=int,
                    help='Early stop patience')
parser.add_argument('--reduce_patience', default=2, type=int,
                    help='Reduce learning rate patience')
# NOTE: help texts below previously said "for SGD", but the script compiles
# the model with Adam; momentum is only consumed if SGD is swapped back in.
parser.add_argument('--learning_rate', default=0.0001, type=float,
                    help='Initial learning rate for the optimizer')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum (only used when training with SGD)')
parser.add_argument('--gamma_decay', default=0.1, type=float,
                    help='Gamma decay for learning rate scheduler')
parser.add_argument('--num_epochs', default=240, type=int,
                    help='Maximum number of epochs before finishing')
parser.add_argument('--AP_IOU', default=0.5, type=float,
                    help='Average precision IOU used for evaluation')
parser.add_argument('--save_path', default='experiments',
                    type=str, help='Path for writing model weights and logs')
parser.add_argument('--data_path', default=DEFAULT_DATA_PATH,
                    type=str, help='Path to the dataset root directory')
parser.add_argument('--scheduled_epochs', nargs='+', type=int,
                    default=[110, 152], help='Epoch learning rate reduction')
parser.add_argument('--run_label', default='RUN_00', type=str,
                    help='Label used to distinguish between different runs')
args = parser.parse_args()
# Load train/validation splits of Open Images V6 restricted to human hands.
# NOTE(review): this path is hard-coded; the --data_path CLI argument appears
# to be unused here -- confirm which dataset location is intended.
path = os.path.join(root_path, 'fiftyone/open-images-v6/')
class_names = ['background', 'Human hand']
data_managers, datasets = [], []
for split in (pr.TRAIN, pr.VAL):
    manager = OpenImagesV6(path, split, class_names)
    data_managers.append(manager)
    datasets.append(manager.load_data())
# Build the detector and compile it with the multibox objective.
num_classes = data_managers[0].num_classes
model = SSD512Custom(num_classes, trainable_base=True)
size = model.input_shape[1]  # square input resolution used downstream

# Adam with AMSGrad drives the multibox loss; the per-term metrics expose
# the localization and classification components separately during training.
multibox_loss = MultiBoxLoss()
model.compile(Adam(args.learning_rate, amsgrad=True),
              multibox_loss.compute_loss,
              {'boxes': [multibox_loss.localization,
                         multibox_loss.positive_classification,
                         multibox_loss.negative_classification]})
# Per-split augmentation pipelines matched to the model's prior (anchor) boxes.
augmentators = [
    AugmentDetection(model.prior_boxes, split, num_classes, size)
    for split in (pr.TRAIN, pr.VAL)]
# EXPERIMENTAL: drop RandomSampleCrop from the training pipeline.
augmentators[0].augment_boxes.pop(2)

# Batch generators pairing each split's data with its augmentation pipeline.
sequencers = [
    ProcessingSequence(processor, args.batch_size, data)
    for data, processor in zip(datasets, augmentators)]
# Create a unique experiment directory and persist the run configuration
# (hyper-parameters as JSON, model architecture as a text summary).
current_time = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
experiment_label = '_'.join([model.name, args.run_label, current_time])
experiment_path = os.path.join(args.save_path, experiment_label)
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(experiment_path, exist_ok=True)
with open(os.path.join(experiment_path, 'hyperparameters.json'), 'w') as filer:
    json.dump(vars(args), filer, indent=4)
with open(os.path.join(experiment_path, 'model_summary.txt'), 'w') as filer:
    model.summary(print_fn=lambda x: filer.write(x + '\n'))
# Callbacks: CSV history log, early stopping, LR reduction on plateau,
# and best-only weight checkpointing.
csv_logger = CSVLogger(os.path.join(experiment_path, 'optimization.log'))
early_stop = EarlyStopping(patience=args.stop_patience, verbose=1)
lr_plateau = ReduceLROnPlateau(patience=args.reduce_patience, verbose=1)
checkpoint = ModelCheckpoint(
    os.path.join(experiment_path, 'model.weights.h5'),
    verbose=1, save_best_only=True, save_weights_only=True)

# Optimize on the training sequencer, validating against the held-out split.
model.fit(sequencers[0],
          epochs=args.num_epochs,
          verbose=1,
          callbacks=[csv_logger, early_stop, lr_plateau, checkpoint],
          validation_data=sequencers[1])