diff --git a/configs/distill/mmcls/crd/README.md b/configs/distill/mmcls/crd/README.md new file mode 100644 index 000000000..0f02f365e --- /dev/null +++ b/configs/distill/mmcls/crd/README.md @@ -0,0 +1,30 @@ +# CONTRASTIVE REPRESENTATION DISTILLATION + +> [CONTRASTIVE REPRESENTATION DISTILLATION](https://arxiv.org/abs/1910.10699) + +## Abstract + +Often we wish to transfer representational knowledge from one neural network to another. Examples include distilling a large network into a smaller one, transferring knowledge from one sensory modality to a second, or ensembling a collection of models into a single estimator. Knowledge distillation, the standard approach to these problems, minimizes the KL divergence between the probabilistic outputs of a teacher and student network. We demonstrate that this objective ignores important structural knowledge of the teacher network. This motivates an alternative objective by which we train a student to capture significantly more information in the teacher’s representation of the data. We formulate this objective as contrastive learning. Experiments demonstrate that our resulting new objective outperforms knowledge distillation and other cutting-edge distillers on a variety of knowledge transfer tasks, including single model compression, ensemble distillation, and cross-modal transfer. 
Our method sets a new state-of-the-art in many transfer tasks, and sometimes even outperforms the teacher network when combined with knowledge distillation.[Original code](http://github.com/HobbitLong/RepDistiller) + +![pipeline](../../../../docs/en/imgs/model_zoo/crd/pipeline.jpg) + +## Citation + +```latex +@article{tian2019contrastive, + title={Contrastive representation distillation}, + author={Tian, Yonglong and Krishnan, Dilip and Isola, Phillip}, + journal={arXiv preprint arXiv:1910.10699}, + year={2019} +} +``` + +## Results and models + +| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | +| ------- | --------- | --------- | --------- | --------- | ------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | ResNet-18 | ResNet-50 | 94.79 | 99.86 | [config](crd_neck_r50_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) \|[model](<>) \| [log](<>) | + +## Acknowledgement + +Shout out to @chengshuang18 for his special contribution. 
diff --git a/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py b/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py new file mode 100644 index 000000000..4e36e9a2a --- /dev/null +++ b/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py @@ -0,0 +1,108 @@ +_base_ = [ + 'mmcls::_base_/datasets/cifar10_bs16.py', + 'mmcls::_base_/schedules/cifar10_bs128.py', + 'mmcls::_base_/default_runtime.py' +] + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet50_8xb16_cifar10.py', pretrained=True), + teacher_ckpt='resnet50_b16x8_cifar10_20210528-f54bfad9.pth', + distiller=dict( + type='ConfigurableDistiller', + student_recorders=dict( + neck=dict(type='ModuleOutputs', source='neck.gap'), + data_samples=dict(type='ModuleInputs', source='')), + teacher_recorders=dict( + neck=dict(type='ModuleOutputs', source='neck.gap')), + distill_losses=dict(loss_crd=dict(type='CRDLoss', loss_weight=0.8)), + connectors=dict( + loss_crd_stu=dict(type='CRDConnector', dim_in=512, dim_out=128), + loss_crd_tea=dict(type='CRDConnector', dim_in=2048, dim_out=128)), + loss_forward_mappings=dict( + loss_crd=dict( + s_feats=dict( + from_student=True, + recorder='neck', + connector='loss_crd_stu'), + t_feats=dict( + from_student=False, + recorder='neck', + connector='loss_crd_tea'), + data_samples=dict( + from_student=True, recorder='data_samples', data_idx=1))))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') + +# change `CIFAR10` dataset to `CRDDataset` dataset. 
+dataset_type = 'CIFAR10' +train_pipeline = [ + dict(_scope_='mmcls', type='RandomCrop', crop_size=32, padding=4), + dict(_scope_='mmcls', type='RandomFlip', prob=0.5, direction='horizontal'), + dict(_scope_='mmrazor', type='PackCRDClsInputs'), +] + +test_pipeline = [ + dict(_scope_='mmrazor', type='PackCRDClsInputs'), +] + +ori_train_dataset = dict( + _scope_='mmcls', + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline) + +crd_train_dataset = dict( + _scope_='mmrazor', + type='CRDDataset', + dataset=ori_train_dataset, + neg_num=16384, + sample_mode='exact', + percent=1.0) + +ori_test_dataset = dict( + _scope_='mmcls', + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline) + +crd_test_dataset = dict( + _scope_='mmrazor', + type='CRDDataset', + dataset=ori_test_dataset, + neg_num=16384, + sample_mode='exact', + percent=1.0) + +train_dataloader = dict( + _delete_=True, + batch_size=16, + num_workers=2, + dataset=crd_train_dataset, + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + _delete_=True, + batch_size=16, + num_workers=2, + dataset=crd_test_dataset, + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) diff --git a/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py b/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py new file mode 100644 index 000000000..c7cb74c39 --- /dev/null +++ b/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py @@ -0,0 +1,49 @@ +# dataset settings +dataset_type = 'CIFAR10' +preprocess_cfg = dict( + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='RandomCrop', crop_size=32, padding=4), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + 
dict(type='PackClsInputs'), +] + +neg_num = 16384 +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline, + neg_num=neg_num), + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline, + neg_num=neg_num), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmrazor/datasets/__init__.py b/mmrazor/datasets/__init__.py new file mode 100644 index 000000000..5cfa79460 --- /dev/null +++ b/mmrazor/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .crd_dataset_wrapper import CRDDataset +from .transforms import PackCRDClsInputs + +__all__ = ['PackCRDClsInputs', 'CRDDataset'] diff --git a/mmrazor/datasets/crd_dataset_wrapper.py b/mmrazor/datasets/crd_dataset_wrapper.py new file mode 100644 index 000000000..308bc1e4c --- /dev/null +++ b/mmrazor/datasets/crd_dataset_wrapper.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import Any, Dict, List, Union + +import numpy as np +from mmengine.dataset.base_dataset import BaseDataset, force_full_init + +from mmrazor.registry import DATASETS + + +@DATASETS.register_module() +class CRDDataset: + """A wrapper of `CRD` dataset. + + Suitable for image classification datasets like CIFAR. Following + the sampling strategy in the `paper `_, + in each epoch, each data sample has contrast information. + Contrast information for an image is indices of negetive data samples. 
+ Note: + ``CRDDataset`` should not inherit from ``BaseDataset`` + since ``get_subset`` and ``get_subset_`` could produce ambiguous + meaning sub-dataset which conflicts with original dataset. If you + want to use a sub-dataset of ``CRDDataset``, you should set + ``indices`` arguments for wrapped dataset which inherit from + ``BaseDataset``. + Args: + dataset (BaseDataset or dict): The dataset to be repeated. + neg_num (int): number of negetive data samples. + percent (float): sampling percentage. + lazy_init (bool, optional): whether to load annotation during + instantiation. Defaults to False + num_classes (int, optional): Number of classes. Defaults to None. + sample_mode (str, optional): Data sampling mode. Defaults to 'exact'. + """ + + def __init__(self, + dataset: Union[BaseDataset, dict], + neg_num: int, + percent: float, + lazy_init: bool = False, + num_classes: int = None, + sample_mode: str = 'exact') -> None: + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`BaseDataset` instance, but got {type(dataset)}') + self._metainfo = self.dataset.metainfo + + self._fully_initialized = False + + # CRD unique attributes. + self.num_classes = num_classes + self.neg_num = neg_num + self.sample_mode = sample_mode + self.percent = percent + + if not lazy_init: + self.full_init() + + def _parse_fullset_contrast_info(self) -> None: + """parse contrast information of the whole dataset.""" + assert self.sample_mode in [ + 'exact', 'random' + ], ('`sample_mode` must in [`exact`, `random`], ' + f'but get `{self.sample_mode}`') + + # Handle special occasion: + # if dataset's ``CLASSES`` is not list of consecutive integers, + # e.g. [2, 3, 5]. 
+ num_classes: int = self.num_classes # type: ignore + if num_classes is None: + num_classes = len(self.dataset.CLASSES) + + if not self.dataset.test_mode: # type: ignore + # Parse info. + self.gt_labels = self.dataset.get_gt_labels() + self.num_samples: int = self.dataset.__len__() + + self.cls_positive: List[List[int]] = [[] + for _ in range(num_classes) + ] # type: ignore + for i in range(self.num_samples): + self.cls_positive[self.gt_labels[i]].append(i) + + self.cls_negative: List[List[int]] = [[] + for i in range(num_classes) + ] # type: ignore + for i in range(num_classes): # type: ignore + for j in range(num_classes): # type: ignore + if j == i: + continue + self.cls_negative[i].extend(self.cls_positive[j]) + + self.cls_positive = [ + np.asarray(self.cls_positive[i]) + for i in range(num_classes) # type: ignore + ] + self.cls_negative = [ + np.asarray(self.cls_negative[i]) + for i in range(num_classes) # type: ignore + ] + + if 0 < self.percent < 1: + n = int(len(self.cls_negative[0]) * self.percent) + self.cls_negative = [ + np.random.permutation(self.cls_negative[i])[0:n] + for i in range(num_classes) # type: ignore + ] + + self.cls_positive = np.asarray(self.cls_positive) + self.cls_negative = np.asarray(self.cls_negative) + + @property + def metainfo(self) -> dict: + """Get the meta information of the repeated dataset. + + Returns: + dict: The meta information of repeated dataset. 
+ """ + return copy.deepcopy(self._metainfo) + + def _get_contrast_info(self, data: Dict, idx: int) -> Dict: + """Get contrast information for each data sample.""" + if self.sample_mode == 'exact': + pos_idx = idx + elif self.sample_mode == 'random': + pos_idx = np.random.choice(self.cls_positive[self.gt_labels[idx]], + 1) + pos_idx = pos_idx[0] # type: ignore + else: + raise NotImplementedError(self.sample_mode) + replace = True if self.neg_num > \ + len(self.cls_negative[self.gt_labels[idx]]) else False + neg_idx = np.random.choice( + self.cls_negative[self.gt_labels[idx]], + self.neg_num, + replace=replace) + contrast_sample_idxs = np.hstack((np.asarray([pos_idx]), neg_idx)) + data['contrast_sample_idxs'] = contrast_sample_idxs + return data + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._parse_fullset_contrast_info() + + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> Dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + Returns: + dict: The idx-th annotation of the dataset. + """ + data_info = self.dataset.get_data_info(idx) # type: ignore + if not self.dataset.test_mode: # type: ignore + data_info = self._get_contrast_info(data_info, idx) + return data_info + + def prepare_data(self, idx) -> Any: + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + data_info = self.get_data_info(idx) + return self.dataset.pipeline(data_info) + + def __getitem__(self, idx: int) -> dict: + """Get the idx-th image and data information of dataset after + ``self.pipeline``, and ``full_init`` will be called if the dataset has + not been fully initialized. 
+ + During training phase, if ``self.pipeline`` get ``None``, + ``self._rand_another`` will be called until a valid image is fetched or + the maximum limit of refetech is reached. + + Args: + idx (int): The index of self.data_list. + + Returns: + dict: The idx-th image and data information of dataset after + ``self.pipeline``. + """ + # Performing full initialization by calling `__getitem__` will consume + # extra memory. If a dataset is not fully initialized by setting + # `lazy_init=True` and then fed into the dataloader. Different workers + # will simultaneously read and parse the annotation. It will cost more + # time and memory, although this may work. Therefore, it is recommended + # to manually call `full_init` before dataset fed into dataloader to + # ensure all workers use shared RAM from master process. + if not self._fully_initialized: + warnings.warn( + 'Please call `full_init()` method manually to accelerate ' + 'the speed.') + self.full_init() + + if self.dataset.test_mode: + data = self.prepare_data(idx) + if data is None: + raise Exception('Test time pipline should not get `None` ' + 'data_sample') + return data + + for _ in range(self.dataset.max_refetch + 1): + data = self.prepare_data(idx) + # Broken images or random augmentations may cause the returned data + # to be None + if data is None: + idx = self.dataset._rand_another() + continue + return data + + raise Exception( + f'Cannot find valid image after {self.dataset.max_refetch}! ' + 'Please check your image path and pipeline') + + @force_full_init + def __len__(self): + return len(self.dataset) + + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. 
If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') diff --git a/mmrazor/datasets/transforms/__init__.py b/mmrazor/datasets/transforms/__init__.py new file mode 100644 index 000000000..cb1bebc46 --- /dev/null +++ b/mmrazor/datasets/transforms/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .formatting import PackCRDClsInputs + +__all__ = ['PackCRDClsInputs'] diff --git a/mmrazor/datasets/transforms/formatting.py b/mmrazor/datasets/transforms/formatting.py new file mode 100644 index 000000000..d2ba63ddc --- /dev/null +++ b/mmrazor/datasets/transforms/formatting.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from mmcls.datasets.transforms.formatting import PackClsInputs, to_tensor + from mmcls.structures import ClsDataSample +except ImportError: + from mmrazor.utils import get_placeholder + PackClsInputs = get_placeholder('mmcls') + to_tensor = get_placeholder('mmcls') + ClsDataSample = get_placeholder('mmcls') + +import warnings +from typing import Any, Dict, Generator + +import numpy as np +import torch + +from mmrazor.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class PackCRDClsInputs(PackClsInputs): + + def transform(self, results: Dict) -> Dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. 
+ + Returns: + dict: + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`ClsDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + packed_results['inputs'] = to_tensor(img) + else: + warnings.warn( + 'Cannot get "img" in the input dict of `PackClsInputs`,' + 'please make sure `LoadImageFromFile` has been added ' + 'in the data pipeline or images have been loaded in ' + 'the dataset.') + + data_sample = ClsDataSample() + if 'gt_label' in results: + gt_label = results['gt_label'] + data_sample.set_gt_label(gt_label) + + if 'sample_idx' in results: + # transfer `sample_idx` to Tensor + self.meta_keys: Generator[Any, None, None] = ( + key for key in self.meta_keys if key != 'sample_idx') + value = results['sample_idx'] + if isinstance(value, int): + value = torch.tensor(value).to(torch.long) + data_sample.set_data(dict(sample_idx=value)) + + if 'contrast_sample_idxs' in results: + value = results['contrast_sample_idxs'] + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + data_sample.set_data(dict(contrast_sample_idxs=value)) + + img_meta = {k: results[k] for k in self.meta_keys if k in results} + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py index 962282d64..c12aa60d7 100644 --- a/mmrazor/models/architectures/connectors/__init__.py +++ b/mmrazor/models/architectures/connectors/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .byot_connector import BYOTConnector from .convmodule_connector import ConvModuleConncetor +from .crd_connector import CRDConnector from .factor_transfer_connectors import Paraphraser, Translator from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector from .ofd_connector import OFDTeacherConnector @@ -9,5 +10,5 @@ __all__ = [ 'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector', 'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector', - 'TorchNNConnector', 'OFDTeacherConnector' + 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector' ] diff --git a/mmrazor/models/architectures/connectors/crd_connector.py b/mmrazor/models/architectures/connectors/crd_connector.py new file mode 100644 index 000000000..48648c75d --- /dev/null +++ b/mmrazor/models/architectures/connectors/crd_connector.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS +from .base_connector import BaseConnector + + +@MODELS.register_module() +class CRDConnector(BaseConnector): + """Connector with linear layer. + + Args: + dim_in (int, optional): input channels. Defaults to 1024. + dim_out (int, optional): output channels. Defaults to 128. + """ + + def __init__(self, + dim_in: int = 1024, + dim_out: int = 128, + **kwargs) -> None: + super(CRDConnector, self).__init__(**kwargs) + self.linear = nn.Linear(dim_in, dim_out) + self.l2norm = Normalize(2) + + def forward_train(self, x: torch.Tensor) -> torch.Tensor: + x = x.view(x.size(0), -1) + x = self.linear(x) + x = self.l2norm(x) + return x + + +class Normalize(nn.Module): + """normalization layer. + + Args: + power (int, optional): power. Defaults to 2. + """ + + def __init__(self, power: int = 2) -> None: + super(Normalize, self).__init__() + self.power = power + + def forward(self, x: torch.Tensor) -> torch.Tensor: + norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) + out = x.div(norm) + return out diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index c0a751ec8..a145ba914 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ab_loss import ABLoss from .at_loss import ATLoss +from .crd_loss import CRDLoss from .cwd import ChannelWiseDivergence from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss from .decoupled_kd import DKDLoss @@ -18,5 +19,5 @@ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss', 'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss', - 'L1Loss', 'FBKDLoss' + 'L1Loss', 'FBKDLoss', 'CRDLoss' ] diff --git a/mmrazor/models/losses/crd_loss.py b/mmrazor/models/losses/crd_loss.py new file mode 100644 index 000000000..4ca81aaf5 --- /dev/null +++ b/mmrazor/models/losses/crd_loss.py @@ -0,0 +1,271 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Union + +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class CRDLoss(nn.Module): + """Variate CRD Loss, ICLR 2020. + + https://arxiv.org/abs/1910.10699 + Args: + loss_weight (float, optional): loss weight. Defaults to 1.0. + temperature (float, optional): temperature. Defaults to 0.07. + neg_num (int, optional): number of negative samples. Defaults to 16384. + sample_n (int, optional): number of total samples. Defaults to 50000. + dim_out (int, optional): output channels. Defaults to 128. + momentum (float, optional): momentum. Defaults to 0.5. + eps (double, optional): eps. Defaults to 1e-7. 
+ """ + + def __init__(self, + loss_weight: float = 1.0, + temperature=0.07, + neg_num=16384, + sample_n=50000, + dim_out=128, + momentum=0.5, + eps=1e-7): + super().__init__() + self.loss_weight = loss_weight + self.eps = eps + + self.contrast = ContrastMemory(dim_out, sample_n, neg_num, temperature, + momentum) + self.criterion_s_t = ContrastLoss(sample_n, eps=self.eps) + + def forward(self, s_feats, t_feats, data_samples): + input_data = data_samples[0] + assert 'sample_idx' in input_data, \ + 'you should pass a dict with key `sample_idx` in mimic function.' + assert isinstance( + input_data.sample_idx, torch.Tensor + ), f'`sample_idx` must be a tensor, but get {type(input_data.sample_idx)}' # noqa: E501 + + sample_idxs = torch.stack( + [sample.sample_idx for sample in data_samples]) + if 'contrast_sample_idxs' in input_data: + assert isinstance( + input_data.contrast_sample_idxs, torch.Tensor + ), f'`contrast_sample_idxs` must be a tensor, but get {type(input_data.contrast_sample_idxs)}' # noqa: E501 + contrast_sample_idxs = torch.stack( + [sample.contrast_sample_idxs for sample in data_samples]) + else: + contrast_sample_idxs = None + out_s, out_t = self.contrast(s_feats, t_feats, sample_idxs, + contrast_sample_idxs) + s_loss = self.criterion_s_t(out_s) + t_loss = self.criterion_s_t(out_t) + loss = s_loss + t_loss + return loss + + +class ContrastLoss(nn.Module): + """contrastive loss, corresponding to Eq (18) + + Args: + n_data (int): number of data + eps (float, optional): eps. Defaults to 1e-7. 
+ """ + + def __init__(self, n_data: int, eps: float = 1e-7): + super(ContrastLoss, self).__init__() + self.n_data = n_data + self.eps = eps + + def forward(self, x): + bsz = x.shape[0] + m = x.size(1) - 1 + + # noise distribution + Pn = 1 / float(self.n_data) + + # loss for positive pair + P_pos = x.select(1, 0) + log_D1 = torch.div(P_pos, P_pos.add(m * Pn + self.eps)).log_() + + # loss for neg_sample negative pair + P_neg = x.narrow(1, 1, m) + log_D0 = torch.div(P_neg.clone().fill_(m * Pn), + P_neg.add(m * Pn + self.eps)).log_() + + loss = -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz + + return loss + + +class ContrastMemory(nn.Module): + """memory buffer that supplies large amount of negative samples. + + https://github.com/HobbitLong/RepDistiller/blob/master/crd/memory.py + + Args: + dim_out (int, optional): output channels. Defaults to 128. + n_sample (int, optional): number of total samples. + Defaults to 50000. + neg_sample (int, optional): number of negative samples. + Defaults to 16384. + T (float, optional): temperature. Defaults to 0.07. + momentum (float, optional): momentum. Defaults to 0.5. + """ + + def __init__(self, + dim_out: int, + n_sample: int, + neg_sample: int, + T: float = 0.07, + momentum: float = 0.5): + super(ContrastMemory, self).__init__() + self.n_sample = n_sample + self.unigrams = torch.ones(self.n_sample) + self.multinomial = AliasMethod(self.unigrams) + # self.multinomial.cuda() + self.neg_sample = neg_sample + + self.register_buffer('params', + torch.tensor([neg_sample, T, -1, -1, momentum])) + stdv = 1. 
/ math.sqrt(dim_out / 3) + self.register_buffer( + 'memory_v1', + torch.rand(n_sample, dim_out).mul_(2 * stdv).add_(-stdv)) + self.register_buffer( + 'memory_v2', + torch.rand(n_sample, dim_out).mul_(2 * stdv).add_(-stdv)) + + def forward(self, + feat_s: torch.Tensor, + feat_t: torch.Tensor, + idx: torch.Tensor, + sample_idx: Union[None, torch.Tensor] = None) -> torch.Tensor: + neg_sample = int(self.params[0].item()) + T = self.params[1].item() + Z_s = self.params[2].item() + Z_t = self.params[3].item() + + momentum = self.params[4].item() + bsz = feat_s.size(0) + n_sample = self.memory_v1.size(0) + dim_out = self.memory_v1.size(1) + + # original score computation + if sample_idx is None: + sample_idx = self.multinomial.draw(bsz * (self.neg_sample + 1))\ + .view(bsz, -1) + sample_idx.select(1, 0).copy_(idx.data) + # sample + weight_s = torch.index_select(self.memory_v1, 0, + sample_idx.view(-1)).detach() + weight_s = weight_s.view(bsz, neg_sample + 1, dim_out) + out_t = torch.bmm(weight_s, feat_t.view(bsz, dim_out, 1)) + out_t = torch.exp(torch.div(out_t, T)) + # sample + weight_t = torch.index_select(self.memory_v2, 0, + sample_idx.view(-1)).detach() + weight_t = weight_t.view(bsz, neg_sample + 1, dim_out) + out_s = torch.bmm(weight_t, feat_s.view(bsz, dim_out, 1)) + out_s = torch.exp(torch.div(out_s, T)) + + # set Z if haven't been set yet + if Z_s < 0: + self.params[2] = out_s.mean() * n_sample + Z_s = self.params[2].clone().detach().item() + print('normalization constant Z_s is set to {:.1f}'.format(Z_s)) + if Z_t < 0: + self.params[3] = out_t.mean() * n_sample + Z_t = self.params[3].clone().detach().item() + print('normalization constant Z_t is set to {:.1f}'.format(Z_t)) + + # compute out_s, out_t + out_s = torch.div(out_s, Z_s).contiguous() + out_t = torch.div(out_t, Z_t).contiguous() + + # update memory + with torch.no_grad(): + l_pos = torch.index_select(self.memory_v1, 0, idx.view(-1)) + l_pos.mul_(momentum) + l_pos.add_(torch.mul(feat_s, 1 - momentum)) + 
l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) + updated_v1 = l_pos.div(l_norm) + self.memory_v1.index_copy_(0, idx, updated_v1) + + ab_pos = torch.index_select(self.memory_v2, 0, idx.view(-1)) + ab_pos.mul_(momentum) + ab_pos.add_(torch.mul(feat_t, 1 - momentum)) + ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) + updated_v2 = ab_pos.div(ab_norm) + self.memory_v2.index_copy_(0, idx, updated_v2) + + return out_s, out_t + + +class AliasMethod(object): + """ + From: https://hips.seas.harvard.edu/blog/2013/03/03/ + the-alias-method-efficient-sampling-with-many-discrete-outcomes/ + + Args: + probs (torch.Tensor): probility vector. + """ + + def __init__(self, probs: torch.Tensor) -> None: + + if probs.sum() > 1: + probs.div_(probs.sum()) + neg_sample = len(probs) + self.prob = torch.zeros(neg_sample) + self.alias = torch.LongTensor([0] * neg_sample) + + # Sort the data into the outcomes with probabilities + # that are larger and smaller than 1/neg_sample. + smaller = [] + larger = [] + for kk, prob in enumerate(probs): + self.prob[kk] = neg_sample * prob + if self.prob[kk] < 1.0: + smaller.append(kk) + else: + larger.append(kk) + + # Loop though and create little binary mixtures that + # appropriately allocate the larger outcomes over the + # overall uniform mixture. 
+ while len(smaller) > 0 and len(larger) > 0: + small = smaller.pop() + large = larger.pop() + + self.alias[small] = large + self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] + + if self.prob[large] < 1.0: + smaller.append(large) + else: + larger.append(large) + + for last_one in smaller + larger: + self.prob[last_one] = 1 + + def cuda(self): + """To cuda device.""" + self.prob = self.prob.cuda() + self.alias = self.alias.cuda() + + def draw(self, N: int) -> torch.Tensor: + """Draw N samples from multinomial.""" + neg_sample = self.alias.size(0) + + kk = torch.zeros( + N, dtype=torch.long, + device=self.prob.device).random_(0, neg_sample) + prob = self.prob.index_select(0, kk) + alias = self.alias.index_select(0, kk) + # b is whether a random number is greater than q + b = torch.bernoulli(prob) + oq = kk.mul(b.long()) + oj = alias.mul((1 - b).long()) + + return oq + oj diff --git a/mmrazor/utils/placeholder.py b/mmrazor/utils/placeholder.py index 622e687d1..c81979000 100644 --- a/mmrazor/utils/placeholder.py +++ b/mmrazor/utils/placeholder.py @@ -5,8 +5,10 @@ def get_placeholder(string: str) -> object: Args: string (str): the dependency's name, i.e. `mmcls` + Raises: ImportError: raise it when the dependency is not installed properly. + Returns: object: PlaceHolder instance. """ diff --git a/mmrazor/utils/setup_env.py b/mmrazor/utils/setup_env.py index 392658f84..385be8624 100644 --- a/mmrazor/utils/setup_env.py +++ b/mmrazor/utils/setup_env.py @@ -61,6 +61,7 @@ def register_all_modules(init_default_scope: bool = True) -> None: Defaults to True. 
""" # noqa + import mmrazor.datasets # noqa: F401,F403 import mmrazor.engine # noqa: F401,F403 import mmrazor.models # noqa: F401,F403 import mmrazor.structures # noqa: F401,F403 diff --git a/tests/data/dataset/a/1.JPG b/tests/data/dataset/a/1.JPG new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/ann.json b/tests/data/dataset/ann.json new file mode 100644 index 000000000..a55539329 --- /dev/null +++ b/tests/data/dataset/ann.json @@ -0,0 +1,28 @@ +{ + "metainfo": { + "categories": [ + { + "category_name": "first", + "id": 0 + }, + { + "category_name": "second", + "id": 1 + } + ] + }, + "data_list": [ + { + "img_path": "a/1.JPG", + "gt_label": 0 + }, + { + "img_path": "b/2.jpeg", + "gt_label": 1 + }, + { + "img_path": "b/subb/2.jpeg", + "gt_label": 1 + } + ] +} diff --git a/tests/data/dataset/ann.txt b/tests/data/dataset/ann.txt new file mode 100644 index 000000000..f929e873b --- /dev/null +++ b/tests/data/dataset/ann.txt @@ -0,0 +1,3 @@ +a/1.JPG 0 +b/2.jpeg 1 +b/subb/3.jpg 1 diff --git a/tests/data/dataset/b/2.jpeg b/tests/data/dataset/b/2.jpeg new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/b/subb/3.jpg b/tests/data/dataset/b/subb/3.jpg new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/classes.txt b/tests/data/dataset/classes.txt new file mode 100644 index 000000000..c012a51e6 --- /dev/null +++ b/tests/data/dataset/classes.txt @@ -0,0 +1,2 @@ +bus +car diff --git a/tests/data/dataset/multi_label_ann.json b/tests/data/dataset/multi_label_ann.json new file mode 100644 index 000000000..5cd8a84d0 --- /dev/null +++ b/tests/data/dataset/multi_label_ann.json @@ -0,0 +1,28 @@ +{ + "metainfo": { + "categories": [ + { + "category_name": "first", + "id": 0 + }, + { + "category_name": "second", + "id": 1 + } + ] + }, + "data_list": [ + { + "img_path": "a/1.JPG", + "gt_label": [0] + }, + { + "img_path": "b/2.jpeg", + "gt_label": [1] + }, + { + "img_path": "b/subb/2.jpeg", + "gt_label": 
[0, 1] + } + ] +} diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py new file mode 100644 index 000000000..1e6031a97 --- /dev/null +++ b/tests/test_datasets/test_datasets.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import pickle +import tempfile +from unittest import TestCase + +import numpy as np + +from mmrazor.registry import DATASETS +from mmrazor.utils import register_all_modules + +register_all_modules() +ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset')) + + +class Test_CRD_CIFAR10(TestCase): + DATASET_TYPE = 'CRD_CIFAR10' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + data_prefix = tmpdir.name + cls.DEFAULT_ARGS = dict( + data_prefix=data_prefix, pipeline=[], test_mode=False) + + dataset_class = DATASETS.get(cls.DATASET_TYPE) + base_folder = osp.join(data_prefix, dataset_class.base_folder) + os.mkdir(base_folder) + + cls.fake_imgs = np.random.randint( + 0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8) + cls.fake_labels = np.random.randint(0, 10, size=(6, )) + cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + batch1 = dict( + data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist()) + with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f: + f.write(pickle.dumps(batch1)) + + batch2 = dict( + data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist()) + with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f: + f.write(pickle.dumps(batch2)) + + test_batch = dict( + data=cls.fake_imgs[4:], fine_labels=cls.fake_labels[4:].tolist()) + with open(osp.join(base_folder, 'test_batch'), 'wb') as f: + f.write(pickle.dumps(test_batch)) + + meta = {dataset_class.meta['key']: cls.fake_classes} + meta_filename = dataset_class.meta['filename'] + with open(osp.join(base_folder, meta_filename), 'wb') as f: + f.write(pickle.dumps(meta)) + + 
dataset_class.train_list = [['data_batch_1', None], + ['data_batch_2', None]] + dataset_class.test_list = [['test_batch', None]] + dataset_class.meta['md5'] = None + + def test_initialize(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test overriding metainfo by `metainfo` argument + cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + + # Test overriding metainfo by `classes` argument + cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + + classes_file = osp.join(ASSETS_ROOT, 'classes.txt') + cfg = {**self.DEFAULT_ARGS, 'classes': classes_file} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1}) + + # Test invalid classes + cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)} + with self.assertRaisesRegex(ValueError, "type "): + dataset_class(**cfg) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup() + + +class Test_CRD_CIFAR100(Test_CRD_CIFAR10): + DATASET_TYPE = 'CRD_CIFAR100' diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py new file mode 100644 index 000000000..46aa671df --- /dev/null +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +import unittest + +import numpy as np +import torch +from mmcls.structures import ClsDataSample +from mmengine.data import LabelData + +from mmrazor.datasets.transforms import PackCRDClsInputs + + +class TestPackClsInputs(unittest.TestCase): + + def setUp(self): + """Set up the fake packing inputs used in every test method. 
+ + TestCase calls functions in this order: setUp() -> testMethod() + -> tearDown() -> cleanUp() + """ + data_prefix = osp.join(osp.dirname(__file__), '../../data') + img_path = osp.join(data_prefix, 'color.jpg') + rng = np.random.RandomState(0) + self.results1 = { + 'sample_idx': 1, + 'img_path': img_path, + 'ori_height': 300, + 'ori_width': 400, + 'height': 600, + 'width': 800, + 'scale_factor': 2.0, + 'flip': False, + 'img': rng.rand(300, 400), + 'gt_label': rng.randint(3, ), + # fake indices of contrastive samples for this image + 'contrast_sample_idxs': rng.randint(0, 6, size=(2, )) + } + self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip') + + def test_transform(self): + transform = PackCRDClsInputs(meta_keys=self.meta_keys) + results = transform(copy.deepcopy(self.results1)) + self.assertIn('inputs', results) + self.assertIsInstance(results['inputs'], torch.Tensor) + self.assertIn('data_sample', results) + self.assertIsInstance(results['data_sample'], ClsDataSample) + + data_sample = results['data_sample'] + self.assertIsInstance(data_sample.gt_label, LabelData) + + def test_repr(self): + transform = PackCRDClsInputs(meta_keys=self.meta_keys) + self.assertEqual( + repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})') diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py index 56abd0c42..5d44efd75 100644 --- a/tests/test_models/test_architectures/test_connectors/test_connectors.py +++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py @@ -3,7 +3,7 @@ import torch -from mmrazor.models import (BYOTConnector, ConvModuleConncetor, +from mmrazor.models import (BYOTConnector, ConvModuleConncetor, CRDConnector, FBKDStudentConnector, FBKDTeacherConnector, Paraphraser, TorchFunctionalConnector, TorchNNConnector, Translator) @@ -40,6 +40,23 @@ def test_convmodule_connector(self): with self.assertRaises(AssertionError): _ = 
ConvModuleConncetor(**convmodule_connector_cfg) + def test_crd_connector(self): + dim_out = 128 + crd_stu_connector = CRDConnector( + **dict(dim_in=1 * 5 * 5, dim_out=dim_out)) + + crd_tea_connector = CRDConnector( + **dict(dim_in=3 * 5 * 5, dim_out=dim_out)) + + assert crd_stu_connector.linear.in_features == 1 * 5 * 5 + assert crd_stu_connector.linear.out_features == dim_out + assert crd_tea_connector.linear.in_features == 3 * 5 * 5 + assert crd_tea_connector.linear.out_features == dim_out + + s_output = crd_stu_connector.forward_train(self.s_feat) + t_output = crd_tea_connector.forward_train(self.t_feat) + assert s_output.size() == t_output.size() + def test_ft_connector(self): stu_connector = Translator(**dict(in_channel=1, out_channel=2)) diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index c3ab16949..4328f7865 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -2,11 +2,12 @@ from unittest import TestCase import torch +from mmengine.data import BaseDataElement from mmrazor import digit_version -from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FBKDLoss, - FTLoss, InformationEntropyLoss, KDSoftCELoss, - OFDLoss, OnehotLikeLoss) +from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss, + FBKDLoss, FTLoss, InformationEntropyLoss, + KDSoftCELoss, OFDLoss, OnehotLikeLoss) class TestLosses(TestCase): @@ -69,6 +70,34 @@ def test_ab_loss(self): self.normal_test_2d(ab_loss) self.normal_test_3d(ab_loss) + def _mock_crd_data_sample(self, sample_idx_list): + data_samples = [] + for _idx in sample_idx_list: + data_sample = BaseDataElement() + data_sample.set_data(dict(sample_idx=_idx)) + data_samples.append(data_sample) + return data_samples + + def test_crd_loss(self): + crd_loss = CRDLoss(**dict(neg_num=5, sample_n=10, dim_out=6)) + sample_idx_list = 
torch.tensor(list(range(5))) + data_samples = self._mock_crd_data_sample(sample_idx_list) + loss = crd_loss.forward(self.feats_1d, self.feats_1d, data_samples) + self.assertTrue(loss.numel() == 1) + + # test the calculation + s_feat_0 = torch.randn((5, 6)) + t_feat_0 = torch.randn((5, 6)) + crd_loss_num_0 = crd_loss.forward(s_feat_0, t_feat_0, data_samples) + assert crd_loss_num_0 != torch.tensor(0.0) + + s_feat_1 = torch.randn((5, 6)) + t_feat_1 = torch.rand((5, 6)) + sample_idx_list_1 = torch.tensor(list(range(5))) + data_samples_1 = self._mock_crd_data_sample(sample_idx_list_1) + crd_loss_num_1 = crd_loss.forward(s_feat_1, t_feat_1, data_samples_1) + assert crd_loss_num_1 != torch.tensor(0.0) + def test_dkd_loss(self): dkd_loss_cfg = dict(loss_weight=1.0) dkd_loss = DKDLoss(**dkd_loss_cfg)