From d33ebe6d8623963e8591e1debf0d30d6b103c838 Mon Sep 17 00:00:00 2001 From: XinyuYe-Intel Date: Mon, 5 Dec 2022 10:31:34 +0800 Subject: [PATCH] Added distributed training support for distillation of MobileNetV2. (#166) Signed-off-by: Xinyu Ye --- .../distillation/eager/README.md | 11 +++ .../distillation/eager/main.py | 82 +++++++++++-------- .../distillation/eager/requirements.txt | 1 + 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md index 14841061fdc..d449d5f797b 100644 --- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md +++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/README.md @@ -8,4 +8,15 @@ pip install -r requirements.txt python train_without_distillation.py --epochs 200 --lr 0.1 --layers 40 --widen-factor 2 --name WideResNet-40-2 --tensorboard # for distillation of the teacher model WideResNet40-2 to the student model MobileNetV2-0.35 python main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9 +``` + +We also supported Distributed Data Parallel training on single node and multi nodes settings for distillation. To use Distributed Data Parallel to speedup training, the bash command needs a small adjustment. +
+For example, bash command will look like the following, where *``* is the address of the master node, it won't be necessary for single node case, *``* is the desired processes to use in current node, for node with GPU, usually set to number of GPUs in this node, for node without GPU and use CPU for training, it's recommended set to 1, *``* is the number of nodes to use, *``* is the rank of the current node, rank starts from 0 to *``*`-1`. +
+Also please note that to use CPU for training in each node with multi nodes settings, argument `--no_cuda` is mandatory. In multi nodes setting, following command needs to be launched in each node, and all the commands should be the same except for *``*, which should be integer from 0 to *``*`-1` assigned to each node. + +```bash +python -m torch.distributed.launch --master_addr= --nproc_per_node= --nnodes= --node_rank= \ + main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9 ``` \ No newline at end of file diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py index e7f4e56888b..3778162d968 100644 --- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py +++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/main.py @@ -10,6 +10,7 @@ import torchvision.datasets as datasets import torchvision.transforms as transforms +from accelerate import Accelerator from wideresnet import WideResNet # used for logging to TensorBoard @@ -60,6 +61,7 @@ help='loss weights of distillation, should be a list of length 2, ' 'and sum to 1.0, first for student targets loss weight, ' 'second for teacher student loss weight.') +parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.') parser.set_defaults(augment=True) def set_seed(seed): @@ -73,10 +75,13 @@ def set_seed(seed): def main(): global args, best_prec1 args, _ = parser.parse_known_args() + accelerator = Accelerator(cpu=args.no_cuda) + best_prec1 = 0 if args.seed is not None: set_seed(args.seed) - if args.tensorboard: configure("runs/%s"%(args.name)) + with accelerator.local_main_process_first(): + if args.tensorboard: configure("runs/%s"%(args.name)) # Data loading code normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], @@ -111,9 +116,9 @@ def main(): student_model = mobilenet.MobileNetV2(num_classes=10, width_mult=0.35) # get the number of model parameters - print('Number of teacher model parameters: {}'.format( + accelerator.print('Number of teacher model parameters: {}'.format( sum([p.data.nelement() for p in teacher_model.parameters()]))) - print('Number of student model parameters: {}'.format( + accelerator.print('Number of student model parameters: {}'.format( sum([p.data.nelement() for p in student_model.parameters()]))) kwargs = {'num_workers': 0, 'pin_memory': True} @@ -125,10 +130,10 @@ def main(): if args.loss_weights[1] > 0: from tqdm import tqdm def get_logits(teacher_model, train_dataset): - print("***** Getting logits of teacher model *****") - print(f" Num examples = {len(train_dataset) }") + accelerator.print("***** Getting logits of teacher model *****") + accelerator.print(f" Num examples = {len(train_dataset) }") logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy') - if not os.path.exists(logits_file): + if not os.path.exists(logits_file) and accelerator.is_local_main_process: teacher_model.eval() train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs) train_dataloader = tqdm(train_dataloader, desc="Evaluating") @@ -137,8 +142,8 @@ def get_logits(teacher_model, train_dataset): outputs = teacher_model(input) teacher_logits += [x for x in outputs.numpy()] np.save(logits_file, np.array(teacher_logits)) - else: - teacher_logits = np.load(logits_file) + accelerator.wait_for_everyone() + teacher_logits = np.load(logits_file) train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \ for l, tl in zip(train_dataset.targets, teacher_logits)] return train_dataset @@ -153,15 +158,15 @@ def get_logits(teacher_model, train_dataset): # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): - print("=> loading checkpoint '{}'".format(args.resume)) + accelerator.print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] student_model.load_state_dict(checkpoint['state_dict']) - print("=> loaded checkpoint '{}' (epoch {})" + accelerator.print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: - print("=> no checkpoint found at '{}'".format(args.resume)) + accelerator.print("=> no checkpoint found at '{}'".format(args.resume)) # define optimizer optimizer = torch.optim.SGD(student_model.parameters(), args.lr, @@ -169,13 +174,18 @@ def get_logits(teacher_model, train_dataset): weight_decay=args.weight_decay) # cosine learning rate - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, len(train_loader) * args.epochs // accelerator.num_processes + ) + + student_model, teacher_model, train_loader, val_loader, optimizer = \ + accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer) def train_func(model): - return train(train_loader, model, scheduler, distiller, best_prec1) + return train(train_loader, model, scheduler, distiller, best_prec1, accelerator) def eval_func(model): - return validate(val_loader, model, distiller) + return validate(val_loader, model, distiller, accelerator) from neural_compressor.experimental import Distillation, common from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss @@ -194,11 +204,12 @@ def eval_func(model): directory = "runs/%s/"%(args.name) os.makedirs(directory, exist_ok=True) + model._model = accelerator.unwrap_model(model.model) model.save(directory) # change to framework model for further use model = model.model -def train(train_loader, model, scheduler, distiller, best_prec1): +def train(train_loader, model, scheduler, distiller, best_prec1, accelerator): distiller.on_train_begin() for epoch in range(args.start_epoch, args.epochs): """Train for one epoch on the training set""" @@ -222,13 +233,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1): loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits) # measure accuracy and record loss + output = accelerator.gather(output) + target = accelerator.gather(target) prec1 = accuracy(output.data, target, topk=(1,))[0] - losses.update(loss.data.item(), input.size(0)) - top1.update(prec1.item(), input.size(0)) + losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes) + top1.update(prec1.item(), input.size(0)*accelerator.num_processes) # compute gradient and do SGD step distiller.optimizer.zero_grad() - loss.backward() + accelerator.backward(loss) # loss.backward() distiller.optimizer.step() scheduler.step() @@ -237,7 +250,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1): end = time.time() if i % args.print_freq == 0: - print('Epoch: [{0}][{1}/{2}]\t' + accelerator.print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' @@ -249,19 +262,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1): # remember best prec@1 and save checkpoint is_best = distiller.best_score > best_prec1 best_prec1 = max(distiller.best_score, best_prec1) - save_checkpoint({ - 'epoch': distiller._epoch_runned + 1, - 'state_dict': model.state_dict(), - 'best_prec1': best_prec1, - }, is_best) - # log to TensorBoard - if args.tensorboard: - log_value('train_loss', losses.avg, epoch) - log_value('train_acc', top1.avg, epoch) - log_value('learning_rate', scheduler._last_lr[0], epoch) + if accelerator.is_local_main_process: + save_checkpoint({ + 'epoch': distiller._epoch_runned + 1, + 'state_dict': model.state_dict(), + 'best_prec1': best_prec1, + }, is_best) + # log to TensorBoard + if args.tensorboard: + log_value('train_loss', losses.avg, epoch) + log_value('train_acc', top1.avg, epoch) + log_value('learning_rate', scheduler._last_lr[0], epoch) -def validate(val_loader, model, distiller): +def validate(val_loader, model, distiller, accelerator): """Perform validation on the validation set""" batch_time = AverageMeter() top1 = AverageMeter() @@ -276,6 +290,8 @@ def validate(val_loader, model, distiller): output = model(input) # measure accuracy + output = accelerator.gather(output) + target = accelerator.gather(target) prec1 = accuracy(output.data, target, topk=(1,))[0] top1.update(prec1.item(), input.size(0)) @@ -284,15 +300,15 @@ def validate(val_loader, model, distiller): end = time.time() if i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' + accelerator.print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( i, len(val_loader), batch_time=batch_time, top1=top1)) - print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) + accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) # log to TensorBoard - if args.tensorboard: + if accelerator.is_local_main_process and args.tensorboard: log_value('val_acc', top1.avg, distiller._epoch_runned) return top1.avg diff --git a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt index 8db2f310ef5..71252629880 100644 --- a/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt +++ b/examples/pytorch/image_recognition/MobileNetV2-0.35/distillation/eager/requirements.txt @@ -2,3 +2,4 @@ torch==1.5.0+cpu torchvision==0.6.0+cpu tensorboard_logger +accelerate \ No newline at end of file