diff --git a/configs/distill/mmcls/dfnd/README.md b/configs/distill/mmcls/dfnd/README.md new file mode 100644 index 000000000..cad3885cd --- /dev/null +++ b/configs/distill/mmcls/dfnd/README.md @@ -0,0 +1,31 @@ +# Learning Student Networks in the Wild (DFND) + +> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf) + + + +## Abstract + +Data-free learning for student networks is a new paradigm for solving users’ anxiety caused by the privacy problem of using original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken. Thus, the student network cannot achieve the comparable performance to that of the pre-trained teacher network especially on the large-scale image dataset. Different to previous works, we present to maximally utilize the massive available unlabeled data in the wild. Specifically, we first thoroughly analyze the output differences between teacher and student network on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for achieving the performance of the student network. In practice, an adaptation matrix is learned with the student network for correcting the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (DataFree Noisy Distillation) method is then verified on several benchmarks to demonstrate its superiority over state-of-theart data-free distillation methods. Experiments on various datasets demonstrate that the student networks learned by the proposed method can achieve comparable performance with those using the original dataset. + +pipeline + +## Results and models + +### Classification + +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | | +| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [student](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) | + +## Citation + +```latex +@inproceedings{chen2021learning, + title={Learning student networks in the wild}, + author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={6428--6437}, + year={2021} +} +``` diff --git a/configs/distill/mmcls/dfnd/dfnd.PNG b/configs/distill/mmcls/dfnd/dfnd.PNG new file mode 100644 index 000000000..070c33653 Binary files /dev/null and b/configs/distill/mmcls/dfnd/dfnd.PNG differ diff --git a/configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py b/configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py new file mode 100644 index 000000000..39633f37d --- /dev/null +++ b/configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py @@ -0,0 +1,100 @@ +_base_ = ['mmcls::_base_/default_runtime.py'] + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1) +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=128) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', scale=32), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +train_dataloader = dict( + batch_size=256, + num_workers=5, + dataset=dict( + type='ImageNet', + data_root='/cache/data/imagenet/', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), +) + +test_pipeline = [ + dict(type='PackClsInputs'), +] + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type='CIFAR10', + data_prefix='/cache/data/cifar', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +val_evaluator = dict(type='Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +teacher_ckpt = '/cache/models/resnet_model.pth' # noqa: E501 + +model = dict( + _scope_='mmrazor', + type='DFNDDistill', + calculate_student_loss=False, + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + val_data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # convert image from BGR to RGB + bgr_to_rgb=False), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False), + teacher_ckpt=teacher_ckpt, + distiller=dict( + type='ConfigurableDistiller', + student_recorders=dict( + fc=dict(type='ModuleOutputs', source='head.fc')), + teacher_recorders=dict( + fc=dict(type='ModuleOutputs', source='head.fc')), + distill_losses=dict( + loss_kl=dict( + type='DFNDLoss', + tau=4, + loss_weight=1, + num_classes=10, + batch_select=0.5)), + loss_forward_mappings=dict( + loss_kl=dict( + preds_S=dict(from_student=True, recorder='fc'), + preds_T=dict(from_student=False, recorder='fc'))))) + +find_unused_parameters = True + +val_cfg = dict(type='mmrazor.DFNDValLoop') diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 5fe2fd524..7f560db99 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop -from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop +from .distill_val_loop import (DFNDValLoop, SelfDistillValLoop, + SingleTeacherDistillValLoop) from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, @@ -15,5 +16,5 @@ 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', - 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop' + 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop', 'DFNDValLoop' ] diff --git a/mmrazor/engine/runner/distill_val_loop.py b/mmrazor/engine/runner/distill_val_loop.py index 0a86bbf4e..60a794f3d 100644 --- a/mmrazor/engine/runner/distill_val_loop.py +++ b/mmrazor/engine/runner/distill_val_loop.py @@ -125,3 +125,38 @@ def run(self): self.runner.call_hook('after_val_epoch', metrics=student_metrics) self.runner.call_hook('after_val') + + +@LOOPS.register_module() +class DFNDValLoop(SingleTeacherDistillValLoop): + """Validation loop for DFND. DFND requires different dataset for training + and validation. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) + if self.runner.distributed: + assert hasattr(self.runner.model.module, 'teacher') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.module.val_data_preprocessor + self.teacher = self.runner.model.module.teacher + self.teacher.data_preprocessor = data_preprocessor + + else: + assert hasattr(self.runner.model, 'teacher') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.val_data_preprocessor + self.teacher = self.runner.model.teacher + self.teacher.data_preprocessor = data_preprocessor diff --git a/mmrazor/models/algorithms/distill/configurable/__init__.py b/mmrazor/models/algorithms/distill/configurable/__init__.py index 8902f737c..437703477 100644 --- a/mmrazor/models/algorithms/distill/configurable/__init__.py +++ b/mmrazor/models/algorithms/distill/configurable/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .datafree_distillation import (DAFLDataFreeDistillation, DataFreeDistillation) +from .dfnd_distill import DFNDDistill from .fpn_teacher_distill import FpnTeacherDistill from .overhaul_feature_distillation import OverhaulFeatureDistillation from .self_distill import SelfDistill @@ -9,5 +10,5 @@ __all__ = [ 'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', - 'OverhaulFeatureDistillation' + 'OverhaulFeatureDistillation', 'DFNDDistill' ] diff --git a/mmrazor/models/algorithms/distill/configurable/dfnd_distill.py b/mmrazor/models/algorithms/distill/configurable/dfnd_distill.py new file mode 100644 index 000000000..bd8cc0847 --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/dfnd_distill.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.runner import load_checkpoint +from mmengine.structures import BaseDataElement +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODELS +from ...base import BaseAlgorithm, LossResults + + +@MODELS.register_module() +class DFNDDistill(BaseAlgorithm): + """``DFNDDistill`` algorithm for training student model in the wild dataset. + https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf + + Args: + distiller (dict): The config dict for built distiller. + teacher (dict | BaseModel): The config dict for teacher model or built + teacher model. + val_data_preprocessor (Union[Dict, nn.Module]): Data preprocessor for + evaluation dataset. Defaults to None. + teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None. + teacher_trainable (bool): Whether the teacher is trainable. Defaults + to False. + teacher_norm_eval (bool): Whether to set teacher's norm layers to eval + mode, namely, freeze running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to True. + student_trainable (bool): Whether the student is trainable. Defaults + to True. + calculate_student_loss (bool): Whether to calculate student loss + (original task loss) to update student model. Defaults to True. + teacher_module_inplace(bool): Whether to allow teacher module inplace + attribute True. Defaults to False. + """ + + def __init__(self, + distiller: dict, + teacher: Union[BaseModel, Dict], + val_data_preprocessor: Optional[Union[Dict, + nn.Module]] = None, + teacher_ckpt: Optional[str] = None, + teacher_trainable: bool = False, + teacher_norm_eval: bool = True, + student_trainable: bool = True, + calculate_student_loss: bool = True, + teacher_module_inplace: bool = False, + **kwargs) -> None: + super().__init__(**kwargs) + + self.distiller = MODELS.build(distiller) + + if isinstance(teacher, Dict): + teacher = MODELS.build(teacher) + + if not isinstance(teacher, BaseModel): + raise TypeError('teacher should be a `dict` or ' + f'`BaseModel` instance, but got ' + f'{type(teacher)}') + + self.teacher = teacher + + # Find all nn.Modules in the model that contain the 'inplace' attribute + # and set them to False. + self.teacher_module_inplace = teacher_module_inplace + if not self.teacher_module_inplace: + self.set_module_inplace_false(teacher, 'self.teacher') + + if teacher_ckpt: + _ = load_checkpoint(self.teacher, teacher_ckpt) + # avoid loaded parameters be overwritten + self.teacher._is_init = True + self.teacher_trainable = teacher_trainable + if not self.teacher_trainable: + for param in self.teacher.parameters(): + param.requires_grad = False + self.teacher_norm_eval = teacher_norm_eval + + # The student model will not calculate gradients and update parameters + # in some pretraining process. + self.student_trainable = student_trainable + + # The student loss will not be updated into ``losses`` in some + # pretraining process. + self.calculate_student_loss = calculate_student_loss + + # In ``ConfigurableDistller``, the recorder manager is just + # constructed, but not really initialized yet. + self.distiller.prepare_from_student(self.student) + self.distiller.prepare_from_teacher(self.teacher) + + # may be modified by stop distillation hook + self.distillation_stopped = False + if val_data_preprocessor is None: + val_data_preprocessor = dict(type='BaseDataPreprocessor') + if isinstance(val_data_preprocessor, nn.Module): + self.val_data_preprocessor = val_data_preprocessor + elif isinstance(val_data_preprocessor, dict): + self.val_data_preprocessor = MODELS.build(val_data_preprocessor) + else: + raise TypeError('val_data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(val_data_preprocessor)}') + + @property + def student(self) -> nn.Module: + """Alias for ``architecture``.""" + return self.architecture + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + + losses = dict() + + # If the `override_data` of a delivery is False, the delivery will + # record the origin data. + self.distiller.set_deliveries_override(False) + if self.teacher_trainable: + with self.distiller.teacher_recorders, self.distiller.deliveries: + teacher_losses = self.teacher( + batch_inputs, data_samples, mode='loss') + + losses.update(add_prefix(teacher_losses, 'teacher')) + else: + with self.distiller.teacher_recorders, self.distiller.deliveries: + with torch.no_grad(): + _ = self.teacher(batch_inputs, data_samples, mode='tensor') + + # If the `override_data` of a delivery is True, the delivery will + # override the origin data with the recorded data. + self.distiller.set_deliveries_override(True) + # Original task loss will not be used during some pretraining process. + if self.calculate_student_loss: + with self.distiller.student_recorders, self.distiller.deliveries: + student_losses = self.student( + batch_inputs, data_samples, mode='loss') + losses.update(add_prefix(student_losses, 'student')) + else: + with self.distiller.student_recorders, self.distiller.deliveries: + if self.student_trainable: + _ = self.student(batch_inputs, data_samples, mode='tensor') + else: + with torch.no_grad(): + _ = self.student( + batch_inputs, data_samples, mode='tensor') + + if not self.distillation_stopped: + # Automatically compute distill losses based on + # `loss_forward_mappings`. + # The required data already exists in the recorders. + distill_losses = self.distiller.compute_distill_losses() + losses.update(add_prefix(distill_losses, 'distill')) + + return losses + + def train(self, mode: bool = True) -> None: + """Set distiller's forward mode.""" + super().train(mode) + if mode and self.teacher_norm_eval: + for m in self.teacher.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def val_step(self, data: Union[tuple, dict, list]) -> list: + """Gets the predictions of given data. + + Calls ``self.val_data_preprocessor(data, False)`` and + ``self(inputs, data_sample, mode='predict')`` in order. Return the + predictions which will be passed to evaluator. + + Args: + data (dict or tuple or list): Data sampled from dataset. + + Returns: + list: The predictions of given data. + """ + data = self.val_data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore + + def test_step(self, data: Union[dict, tuple, list]) -> list: + """``BaseModel`` implements ``test_step`` the same as ``val_step``. + + Args: + data (dict or tuple or list): Data sampled from dataset. + + Returns: + list: The predictions of given data. + """ + data = self.val_data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 65e2108fd..4dd2e1f99 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -6,6 +6,7 @@ from .cwd import ChannelWiseDivergence from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss from .decoupled_kd import DKDLoss +from .dfnd_loss import DFNDLoss from .dist_loss import DISTLoss from .factor_transfer_loss import FTLoss from .fbkd_loss import FBKDLoss @@ -24,5 +25,5 @@ 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss', 'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss', 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss', - 'DISTLoss' + 'DISTLoss', 'DFNDLoss' ] diff --git a/mmrazor/models/losses/dfnd_loss.py b/mmrazor/models/losses/dfnd_loss.py new file mode 100644 index 000000000..85a40276e --- /dev/null +++ b/mmrazor/models/losses/dfnd_loss.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class DFNDLoss(nn.Module): + """Loss function for DFND. + https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf + + Args: + tau (float): Temperature coefficient. Defaults to 1.0. + reduction (str): Specifies the reduction to apply to the loss: + ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. + ``'none'``: no reduction will be applied, + ``'batchmean'``: the sum of the output will be divided by + the batchsize, + ``'sum'``: the output will be summed, + ``'mean'``: the output will be divided by the number of + elements in the output. + Default: ``'batchmean'`` + loss_weight (float): Weight of loss. Defaults to 1.0. + teacher_detach (bool): Whether to detach the teacher model prediction. + Will set to ``'False'`` in some data-free distillation algorithms. + Defaults to True. + num_classes (int): Number of classes. + teacher_acc (float): The performance of teacher network in the target + dataset. + batch_select (float): ratio of data in the wild dataset to participate + in training. + """ + + def __init__( + self, + tau: float = 1.0, + reduction: str = 'batchmean', + loss_weight: float = 1.0, + teacher_detach: bool = True, + num_classes: int = 1000, + teacher_acc: float = 0.95, + batch_select: float = 0.5, + ): + super(DFNDLoss, self).__init__() + self.tau = tau + self.loss_weight = loss_weight + self.teacher_detach = teacher_detach + + accept_reduction = {'none', 'batchmean', 'sum', 'mean'} + assert reduction in accept_reduction, \ + f'KLDivergence supports reduction {accept_reduction}, ' \ + f'but gets {reduction}.' + self.reduction = reduction + self.noisy_adaptation = torch.nn.Parameter( + torch.zeros(num_classes, num_classes - 1)) + self.teacher_acc = teacher_acc + self.num_classes = num_classes + self.nll_loss = torch.nn.NLLLoss() + self.ce_loss = torch.nn.CrossEntropyLoss(reduction='none') + self.batch_select = batch_select + + def noisy(self): + noise_adaptation_softmax = torch.nn.functional.softmax( + self.noisy_adaptation, dim=1) * (1 - self.teacher_acc) + noise_adaptation_layer = torch.zeros(self.num_classes, + self.num_classes).to( + self.noisy_adaptation.device) + tc = torch.FloatTensor([self.teacher_acc + ]).to(noise_adaptation_softmax.device) + for i in range(self.num_classes): + if i == 0: + noise_adaptation_layer[i] = \ + torch.cat([tc, noise_adaptation_softmax[i][i:]]) + if i == self.num_classes - 1: + noise_adaptation_layer[i] = \ + torch.cat([noise_adaptation_softmax[i][:i], tc]) + else: + noise_adaptation_layer[i] = \ + torch.cat([noise_adaptation_softmax[i][:i], tc, + noise_adaptation_softmax[i][i:]]) + return noise_adaptation_layer + + def forward(self, preds_S, preds_T): + """Forward computation. + + Args: + preds_S (torch.Tensor): The student model prediction with + shape (N, C, H, W) or shape (N, C). + preds_T (torch.Tensor): The teacher model prediction with + shape (N, C, H, W) or shape (N, C). + + Return: + torch.Tensor: The calculated loss value. + """ + if self.teacher_detach: + preds_T = preds_T.detach() + pred = preds_T.data.max(1)[1] + loss_t = self.ce_loss(preds_T, pred) + positive_loss_idx = loss_t.topk( + int(self.batch_select * preds_S.shape[0]), largest=False)[1] + softmax_pred_T = F.softmax(preds_T / self.tau, dim=1) + log_softmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1) + softmax_preds_S_adaptation = torch.matmul( + F.softmax(preds_S, dim=1), self.noisy()) + loss = (self.tau**2) * ( + torch.sum( + F.kl_div( + log_softmax_preds_S[positive_loss_idx], + softmax_pred_T[positive_loss_idx], + reduction='none')) / preds_S.shape[0]) + loss += self.nll_loss(torch.log(softmax_preds_S_adaptation), pred) + return self.loss_weight * loss