-
Notifications
You must be signed in to change notification settings - Fork 88
/
train.py
384 lines (352 loc) · 18.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import os
import sys
import time
import random
import numpy as np
import copy
import scipy
import pickle
import builtins
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from config import configurations
from backbone.resnet import *
from backbone.resnet_irse import *
from backbone.mobilefacenet import *
from backbone.resattnet import *
from backbone.resnest import *
from backbone.ghostnet import *
from backbone.mobilenetv3 import *
from backbone.proxylessnas import *
from backbone.efficientnet import *
from backbone.densenet import *
from backbone.rexnetv1 import *
from backbone.mobilenext import *
from backbone.mobilenetv2 import *
from head.metrics import *
from loss.loss import *
from util.utils import *
from dataset.datasets import FaceDataset
from dataset.randaugment import RandAugment
from dataset.utils import *
from tensorboardX import SummaryWriter
from tqdm import tqdm
import apex
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
from util.flops_counter import *
from optimizer.lr_scheduler import *
from optimizer.optimizer import *
#from torchprofile import profile_macs
def set_seed(seed):
    """Seed every RNG used during training with the same *seed*.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU generator,
    the current CUDA device and all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def main():
    """Spawn one ``main_worker`` process per entry in ``cfg['GPU']``.

    Uses ``configurations[1]``. Before spawning, ``cfg['WORLD_SIZE']`` is
    multiplied by the number of local GPUs. The incoming value is presumably
    the number of nodes (TODO confirm against config.py), so the result is
    the total number of DDP processes.
    """
    cfg = configurations[1]
    procs_per_node = len(cfg['GPU'])
    node_count = cfg['WORLD_SIZE']
    cfg['WORLD_SIZE'] = procs_per_node * node_count
    mp.spawn(
        main_worker,
        args=(procs_per_node, cfg),
        nprocs=procs_per_node,
    )
def _build_train_transform(cfg):
    """Compose the training augmentation pipeline selected by flags in *cfg*.

    Resulting order (bracketed steps only when the named flag is truthy):
    [RANDAUGMENT] -> RandomHorizontalFlip -> [COLORJITTER] -> [CUTOUT]
    -> ToTensor -> Normalize(RGB_MEAN, RGB_STD) -> [RANDOM_ERASING].
    """
    transform_list = []
    if cfg['RANDAUGMENT']:
        transform_list.append(RandAugment(n=cfg['RANDAUGMENT_N'], m=cfg['RANDAUGMENT_M']))
    transform_list.append(transforms.RandomHorizontalFlip())
    if cfg['COLORJITTER']:
        transform_list.append(transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
    if cfg['CUTOUT']:
        transform_list.append(Cutout())
    transform_list.append(transforms.ToTensor())
    transform_list.append(transforms.Normalize(mean=cfg['RGB_MEAN'], std=cfg['RGB_STD']))
    # torchvision's RandomErasing works on tensors, so it must follow ToTensor.
    if cfg['RANDOM_ERASING']:
        transform_list.append(transforms.RandomErasing())
    return transforms.Compose(transform_list)


def _build_optimizer(name, params, lr, momentum):
    """Create the optimizer called *name* over the parameter groups *params*.

    Per-group keys in *params* (e.g. 'weight_decay') override the values
    passed to the constructor for that group.

    Raises:
        ValueError: if *name* is not a supported optimizer.
    """
    if name == 'sgd':
        return optim.SGD(params, lr=lr, momentum=momentum)
    if name == 'adam':
        return optim.Adam(params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    if name == 'lookahead':
        base_optimizer = optim.Adam(params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        return Lookahead(optimizer=base_optimizer, k=5, alpha=0.5)
    if name == 'radam':
        return RAdam(params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    if name == 'ranger':
        return Ranger(params, lr=lr, alpha=0.5, k=6)
    if name == 'adamp':
        return AdamP(params, lr=lr, betas=(0.9, 0.999), weight_decay=1e-2)
    if name == 'sgdp':
        return SGDP(params, lr=lr, weight_decay=1e-5, momentum=0.9, nesterov=True)
    raise ValueError("Unsupported OPTIMIZER: {!r}".format(name))


def _build_lr_scheduler(name, optimizer, cfg, batches_per_epoch):
    """Create the learning-rate scheduler called *name* for *optimizer*.

    'step' and 'multi_step' are stepped once per epoch by ``main_worker``.
    'cosine' (CosineWarmupLR) is stepped once per batch.

    Raises:
        ValueError: if *name* is not 'step', 'multi_step' or 'cosine'.
    """
    if name == 'step':
        return StepLR(optimizer, step_size=cfg['LR_STEP_SIZE'], gamma=cfg['LR_DECAT_GAMMA'])
    if name == 'multi_step':
        return MultiStepLR(optimizer, milestones=cfg['LR_DECAY_EPOCH'], gamma=cfg['LR_DECAT_GAMMA'])
    if name == 'cosine':
        return CosineWarmupLR(optimizer, batches=batches_per_epoch, epochs=cfg['NUM_EPOCH'],
                              base_lr=cfg['LR'], target_lr=cfg['LR_END'],
                              warmup_epochs=cfg['WARMUP_EPOCH'], warmup_lr=cfg['WARMUP_LR'])
    raise ValueError("Unsupported LR_SCHEDULER: {!r}".format(name))


def main_worker(gpu, ngpus_per_node, cfg):
    """Train the configured backbone + classification head as one DDP rank.

    Spawned once per local GPU by ``main`` through ``mp.spawn``. The function:
    - builds the data pipeline, model, optimizer, LR scheduler and loss from *cfg*;
    - optionally resumes from checkpoints;
    - trains for epochs ``[cfg['START_EPOCH'], cfg['NUM_EPOCH'])``;
    - every ``EVAL_FREQ`` batches, evaluates on LFW / CFP_FP / AgeDB-30 /
      VGG2_FP;
    - on the first process of each node, exports a TorchScript trace of the
      backbone to ``MODEL_ROOT``.

    Args:
        gpu: local process index on this node; also used as the CUDA device.
        ngpus_per_node: number of processes spawned on this node.
        cfg: configuration dict. ``GPU`` and ``RANK`` (and ``START_EPOCH`` when
            resuming) are overwritten in place.

    Raises:
        ValueError: on an unsupported INPUT_SIZE, OPTIMIZER or LR_SCHEDULER.
    """
    # NOTE(review): the local index replaces cfg['GPU'] and is used as the CUDA
    # device id. The ids listed in cfg['GPU'] therefore only set the process
    # count; presumably CUDA_VISIBLE_DEVICES selects the physical GPUs.
    cfg['GPU'] = gpu
    set_seed(cfg['SEED'])  # same seed on every rank, for reproducible results
    torch.backends.cudnn.benchmark = True
    # NOTE(review): benchmark=True lets cuDNN pick kernels by timing, which can
    # undermine the determinism requested on the next line.
    torch.backends.cudnn.deterministic = True
    if gpu != 0:
        # Silence print() on every process except local index 0. **kwargs are
        # accepted so calls such as print(..., flush=True) do not raise.
        def print_pass(*args, **kwargs):
            pass
        builtins.print = print_pass
    # The incoming RANK is presumably the node rank; convert it to a global rank.
    cfg['RANK'] = cfg['RANK'] * ngpus_per_node + gpu
    dist.init_process_group(backend=cfg['DIST_BACKEND'], init_method=cfg["DIST_URL"],
                            world_size=cfg['WORLD_SIZE'], rank=cfg['RANK'])

    # ======= data =======
    batch_size = int(cfg['BATCH_SIZE'])
    per_batch_size = batch_size // ngpus_per_node  # BATCH_SIZE is split across local processes
    workers = int(cfg['NUM_WORKERS'])  # DataLoader worker processes per rank
    DATA_ROOT = cfg['DATA_ROOT']  # parent root where the train/val/test data are stored
    VAL_DATA_ROOT = cfg['VAL_DATA_ROOT']
    RECORD_DIR = cfg['RECORD_DIR']
    DROP_LAST = cfg['DROP_LAST']
    OPTIMIZER = cfg['OPTIMIZER']
    LR_SCHEDULER = cfg['LR_SCHEDULER']
    NUM_EPOCH = cfg['NUM_EPOCH']
    USE_APEX = cfg['USE_APEX']
    EVAL_FREQ = cfg['EVAL_FREQ']
    SYNC_BN = cfg['SYNC_BN']
    print("=" * 60)
    print("Overall Configurations:")
    print(cfg)
    print("=" * 60)

    train_transform = _build_train_transform(cfg)
    print("=" * 60)
    print(train_transform)
    print("Train Transform Generated")
    print("=" * 60)

    dataset_train = FaceDataset(DATA_ROOT, RECORD_DIR, train_transform)
    # The DistributedSampler shards and shuffles the data per rank, so the
    # DataLoader itself must not shuffle.
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=per_batch_size, shuffle=False, num_workers=workers,
        pin_memory=True, sampler=train_sampler, drop_last=DROP_LAST)
    # NOTE(review): SAMPLE_NUMS is not used below. The call is kept in case it
    # has side effects on the dataset.
    SAMPLE_NUMS = dataset_train.get_sample_num_of_each_class()
    NUM_CLASS = len(train_loader.dataset.classes)
    print("Number of Training Classes: {}".format(NUM_CLASS))
    (lfw, cfp_fp, agedb_30, vgg2_fp,
     lfw_issame, cfp_fp_issame, agedb_30_issame, vgg2_fp_issame) = get_val_data(VAL_DATA_ROOT)

    # ======= model & loss & optimizer =======
    BACKBONE_DICT = {'MobileFaceNet': MobileFaceNet,
                     'ResNet_50': ResNet_50, 'ResNet_101': ResNet_101, 'ResNet_152': ResNet_152,
                     'IR_50': IR_50, 'IR_100': IR_100, 'IR_101': IR_101, 'IR_152': IR_152, 'IR_185': IR_185, 'IR_200': IR_200,
                     'IR_SE_50': IR_SE_50, 'IR_SE_100': IR_SE_100, 'IR_SE_101': IR_SE_101, 'IR_SE_152': IR_SE_152, 'IR_SE_185': IR_SE_185, 'IR_SE_200': IR_SE_200,
                     'AttentionNet_IR_56': AttentionNet_IR_56, 'AttentionNet_IRSE_56': AttentionNet_IRSE_56, 'AttentionNet_IR_92': AttentionNet_IR_92, 'AttentionNet_IRSE_92': AttentionNet_IRSE_92,
                     'ResNeSt_50': resnest50, 'ResNeSt_101': resnest101, 'ResNeSt_100': resnest100,
                     'GhostNet': GhostNet, 'MobileNetV3': MobileNetV3, 'ProxylessNAS': proxylessnas, 'EfficientNet': efficientnet,
                     'DenseNet': densenet, 'ReXNetV1': ReXNetV1, 'MobileNeXt': MobileNeXt, 'MobileNetV2': MobileNetV2,
                     }  # HRNet variants (HRNet_W30 ... HRNet_W64) are currently not registered.
    BACKBONE_NAME = cfg['BACKBONE_NAME']
    INPUT_SIZE = cfg['INPUT_SIZE']
    # Explicit check instead of assert, so it is not stripped under `python -O`.
    if list(INPUT_SIZE) != [112, 112]:
        raise ValueError("INPUT_SIZE must be [112, 112], got {}".format(INPUT_SIZE))
    backbone = BACKBONE_DICT[BACKBONE_NAME](INPUT_SIZE)
    print("=" * 60)
    print(backbone)
    print("{} Backbone Generated".format(BACKBONE_NAME))
    print("=" * 60)

    HEAD_DICT = {'Softmax': Softmax, 'ArcFace': ArcFace, 'Combined': Combined, 'CosFace': CosFace, 'SphereFace': SphereFace,
                 'Am_softmax': Am_softmax, 'CurricularFace': CurricularFace, 'ArcNegFace': ArcNegFace, 'SVX': SVXSoftmax,
                 'AirFace': AirFace, 'QAMFace': QAMFace, 'CircleLoss': CircleLoss,
                 }
    HEAD_NAME = cfg['HEAD_NAME']
    EMBEDDING_SIZE = cfg['EMBEDDING_SIZE']  # feature dimension
    head = HEAD_DICT[HEAD_NAME](in_features=EMBEDDING_SIZE, out_features=NUM_CLASS)
    print("Params: ", count_model_params(backbone))
    print("Flops:", count_model_flops(backbone))
    print("=" * 60)
    print(head)
    print("{} Head Generated".format(HEAD_NAME))
    print("=" * 60)

    # -------------------- optimizer --------------------
    # Separate BN parameters from the rest. Only the first param group below sets
    # weight_decay explicitly. The BN group inherits the optimizer's constructor
    # value (0 for sgd/adam/lookahead/radam).
    if "IR" in BACKBONE_NAME:
        backbone_paras_only_bn, backbone_paras_wo_bn = separate_irse_bn_paras(backbone)
    else:
        backbone_paras_only_bn, backbone_paras_wo_bn = separate_resnet_bn_paras(backbone)
    torch.cuda.set_device(cfg['GPU'])
    backbone.cuda(cfg['GPU'])
    head.cuda(cfg['GPU'])
    LR = cfg['LR']  # initial LR
    WEIGHT_DECAY = cfg['WEIGHT_DECAY']
    MOMENTUM = cfg['MOMENTUM']
    params = [{'params': backbone_paras_wo_bn + list(head.parameters()), 'weight_decay': WEIGHT_DECAY},
              {'params': backbone_paras_only_bn}]
    optimizer = _build_optimizer(OPTIMIZER, params, LR, MOMENTUM)
    scheduler = _build_lr_scheduler(LR_SCHEDULER, optimizer, cfg, len(train_loader))
    print("=" * 60)
    print(optimizer)
    print("Optimizer Generated")
    print("=" * 60)

    # -------------------- loss --------------------
    LOSS_NAME = cfg['LOSS_NAME']
    LOSS_DICT = {'Softmax': nn.CrossEntropyLoss(),
                 'LabelSmooth': LabelSmoothCrossEntropyLoss(classes=NUM_CLASS),
                 'Focal': FocalLoss(),
                 'HM': HardMining(),
                 'Softplus': nn.Softplus()}
    loss = LOSS_DICT[LOSS_NAME].cuda(gpu)
    print("=" * 60)
    print(loss)
    print("{} Loss Generated".format(LOSS_NAME))
    print("=" * 60)

    # -------------------- optional resume --------------------
    BACKBONE_RESUME_ROOT = cfg['BACKBONE_RESUME_ROOT']  # backbone checkpoint to resume from
    HEAD_RESUME_ROOT = cfg['HEAD_RESUME_ROOT']  # head/optimizer checkpoint to resume from
    IS_RESUME = cfg['IS_RESUME']
    if IS_RESUME:
        print("=" * 60)
        loc = 'cuda:{}'.format(cfg['GPU'])  # map checkpoint tensors onto this rank's device
        if os.path.isfile(BACKBONE_RESUME_ROOT):
            print("Loading Backbone Checkpoint '{}'".format(BACKBONE_RESUME_ROOT))
            backbone.load_state_dict(torch.load(BACKBONE_RESUME_ROOT, map_location=loc))
            if os.path.isfile(HEAD_RESUME_ROOT):
                print("Loading Head Checkpoint '{}'".format(HEAD_RESUME_ROOT))
                checkpoint = torch.load(HEAD_RESUME_ROOT, map_location=loc)
                cfg['START_EPOCH'] = checkpoint['EPOCH']
                head.load_state_dict(checkpoint['HEAD'])
                optimizer.load_state_dict(checkpoint['OPTIMIZER'])
                del checkpoint
                # NOTE(review): the LR scheduler is not advanced to START_EPOCH here.
        else:
            print("No Checkpoint Found at '{}' and '{}'. Please Have a Check or Continue to Train from Scratch".format(BACKBONE_RESUME_ROOT, HEAD_RESUME_ROOT))
        print("=" * 60)

    # Plain copy of the backbone taken before SyncBN conversion, AMP and DDP
    # wrapping; it is used only for the TorchScript export below.
    ori_backbone = copy.deepcopy(backbone)
    if SYNC_BN:
        backbone = apex.parallel.convert_syncbn_model(backbone)
    if USE_APEX:
        [backbone, head], optimizer = amp.initialize([backbone, head], optimizer, opt_level='O2')
        backbone = DDP(backbone)
        head = DDP(head)
    else:
        backbone = torch.nn.parallel.DistributedDataParallel(backbone, device_ids=[cfg['GPU']])
        head = torch.nn.parallel.DistributedDataParallel(head, device_ids=[cfg['GPU']])

    # checkpoint and tensorboard dirs
    MODEL_ROOT = cfg['MODEL_ROOT']  # where checkpoints are written
    LOG_ROOT = cfg['LOG_ROOT']  # where train/val logs are written
    os.makedirs(MODEL_ROOT, exist_ok=True)
    os.makedirs(LOG_ROOT, exist_ok=True)
    writer = SummaryWriter(LOG_ROOT)  # buffers intermediate results for TensorBoard

    # -------------------- train --------------------
    DISP_FREQ = 100  # print training stats every DISP_FREQ batches
    for epoch in range(cfg['START_EPOCH'], cfg['NUM_EPOCH']):
        train_sampler.set_epoch(epoch)  # give each epoch a different shuffle
        batch = 0  # batch index within the epoch
        backbone.train()
        head.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        # print() is already silenced on non-zero local ranks; do the same for tqdm.
        for inputs, labels in tqdm(train_loader, disable=(gpu != 0)):
            if LR_SCHEDULER == 'cosine':
                scheduler.step()  # CosineWarmupLR is a per-batch schedule
            # compute output
            start_time = time.time()
            inputs = inputs.cuda(cfg['GPU'], non_blocking=True)
            labels = labels.cuda(cfg['GPU'], non_blocking=True)
            if cfg['MIXUP']:
                inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, cfg['GPU'], cfg['MIXUP_PROB'], cfg['MIXUP_ALPHA'])
            elif cfg['CUTMIX']:
                # NOTE(review): CutMix reuses MIXUP_ALPHA; confirm that no separate alpha is intended.
                inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, cfg['GPU'], cfg['CUTMIX_PROB'], cfg['MIXUP_ALPHA'])
            features = backbone(inputs)
            outputs = head(features, labels)
            if cfg['MIXUP'] or cfg['CUTMIX']:
                lossx = mixup_criterion(loss, outputs, labels_a, labels_b, lam)
            elif HEAD_NAME == 'CircleLoss':
                lossx = loss(outputs).mean()
            else:
                lossx = loss(outputs, labels)
            duration = time.time() - start_time
            if (batch + 1) % DISP_FREQ == 0 and batch != 0:
                print("batch inference time", duration)

            # compute gradient and do the optimizer step
            optimizer.zero_grad()
            if USE_APEX:
                with amp.scale_loss(lossx, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                lossx.backward()
            optimizer.step()

            # Measure accuracy and record loss. For CircleLoss, accuracy is
            # computed from the backbone features rather than the head outputs.
            if HEAD_NAME != 'CircleLoss':
                prec1, prec5 = accuracy(outputs.detach(), labels, topk=(1, 5))
            else:
                prec1, prec5 = accuracy(features.detach(), labels, topk=(1, 5))
            losses.update(lossx.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # display training loss & acc every DISP_FREQ batches
            if (batch + 1) % DISP_FREQ == 0 or batch == 0:
                print("=" * 60)
                print('Epoch {}/{} Batch {}/{}\t'
                      'Training Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Training Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Training Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch + 1, cfg['NUM_EPOCH'], batch + 1, len(train_loader),
                          loss=losses, top1=top1, top5=top5))
                print("=" * 60)

            # periodic validation & checkpoint export
            if (batch + 1) % EVAL_FREQ == 0:
                lr = optimizer.param_groups[0]['lr']
                print("Current lr", lr)
                print("=" * 60)
                print("Perform Evaluation on LFW, CFP_FP, AgeD and VGG2_FP, and Save Checkpoints...")
                accuracy_lfw, best_threshold_lfw, roc_curve_lfw = perform_val(EMBEDDING_SIZE, per_batch_size, backbone, lfw, lfw_issame)
                buffer_val(writer, "LFW", accuracy_lfw, best_threshold_lfw, roc_curve_lfw, epoch + 1)
                accuracy_cfp_fp, best_threshold_cfp_fp, roc_curve_cfp_fp = perform_val(EMBEDDING_SIZE, per_batch_size, backbone, cfp_fp, cfp_fp_issame)
                buffer_val(writer, "CFP_FP", accuracy_cfp_fp, best_threshold_cfp_fp, roc_curve_cfp_fp, epoch + 1)
                accuracy_agedb_30, best_threshold_agedb_30, roc_curve_agedb_30 = perform_val(EMBEDDING_SIZE, per_batch_size, backbone, agedb_30, agedb_30_issame)
                buffer_val(writer, "AgeDB", accuracy_agedb_30, best_threshold_agedb_30, roc_curve_agedb_30, epoch + 1)
                accuracy_vgg2_fp, best_threshold_vgg2_fp, roc_curve_vgg2_fp = perform_val(EMBEDDING_SIZE, per_batch_size, backbone, vgg2_fp, vgg2_fp_issame)
                buffer_val(writer, "VGGFace2_FP", accuracy_vgg2_fp, best_threshold_vgg2_fp, roc_curve_vgg2_fp, epoch + 1)
                print("Epoch {}/{}, Evaluation: LFW Acc: {}, CFP_FP Acc: {}, AgeDB Acc: {}, VGG2_FP Acc: {}".format(epoch + 1, NUM_EPOCH, accuracy_lfw, accuracy_cfp_fp, accuracy_agedb_30, accuracy_vgg2_fp))
                print("=" * 60)
                print("=" * 60)
                print("Save Checkpoint...")
                if cfg['RANK'] % ngpus_per_node == 0:
                    # State-dict checkpoints (the format read by the IS_RESUME path)
                    # are currently disabled. Only a TorchScript trace of the
                    # backbone is written:
                    # torch.save(backbone.module.state_dict(), os.path.join(MODEL_ROOT, "Backbone_{}_Epoch_{}_Time_{}_checkpoint.pth".format(BACKBONE_NAME, epoch + 1, get_time())))
                    # save_dict = {'EPOCH': epoch + 1,
                    #              'HEAD': head.module.state_dict(),
                    #              'OPTIMIZER': optimizer.state_dict()}
                    # torch.save(save_dict, os.path.join(MODEL_ROOT, "Head_{}_Epoch_{}_Time_{}_checkpoint.pth".format(HEAD_NAME, epoch + 1, get_time())))
                    ori_backbone.load_state_dict(backbone.module.state_dict())
                    ori_backbone.eval()
                    x = torch.randn(1, 3, *INPUT_SIZE).cuda()
                    traced_cell = torch.jit.trace(ori_backbone, x)
                    torch.jit.save(traced_cell, os.path.join(MODEL_ROOT, "Epoch_{}_Time_{}_checkpoint.pth".format(epoch + 1, get_time())))
                sys.stdout.flush()
                # Evaluation happens mid-epoch. Make sure the remaining batches
                # train in training mode even if perform_val switched the backbone
                # to eval mode (not visible here).
                backbone.train()
            batch += 1

        if LR_SCHEDULER != 'cosine':
            # Epoch-level schedulers step after the epoch's optimizer updates
            # (PyTorch >= 1.1 convention). Stepping first would skip the initial
            # LR and apply each decay one epoch early.
            scheduler.step()

        epoch_loss = losses.avg
        epoch_acc = top1.avg
        print("=" * 60)
        print('Epoch: {}/{}\t' 'Training Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Training Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
              'Training Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                  epoch + 1, cfg['NUM_EPOCH'], loss=losses, top1=top1, top5=top5))
        sys.stdout.flush()
        print("=" * 60)
        if cfg['RANK'] % ngpus_per_node == 0:
            writer.add_scalar("Training_Loss", epoch_loss, epoch + 1)
            writer.add_scalar("Training_Accuracy", epoch_acc, epoch + 1)
            writer.add_scalar("Top1", top1.avg, epoch + 1)
            writer.add_scalar("Top5", top5.avg, epoch + 1)
    writer.close()  # flush pending TensorBoard events
if __name__ == '__main__':
    # Script entry point: spawn one training worker per configured GPU.
    main()