diff --git a/.circleci/test.yml b/.circleci/test.yml index 92ac230c9..5da20de36 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -26,7 +26,6 @@ jobs: command: | pip install interrogate interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 80 mmrazor - build_cpu: parameters: # The python version must match available image tags in @@ -37,8 +36,6 @@ jobs: type: string torchvision: type: string - mmcv: - type: string docker: - image: cimg/python:<< parameters.python >> resource_class: large @@ -58,20 +55,21 @@ jobs: name: Install PyTorch command: | python -V - python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html - when: condition: - equal: [ "3.9.0", << parameters.python >> ] + equal: ["3.9.0", << parameters.python >>] steps: - run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake - run: name: Install mmrazor dependencies command: | - python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main - python -m pip install << parameters.mmcv >> - python -m pip install git+ssh://git@github.com/open-mmlab/mmclassification.git@dev-1.x - python -m pip install git+ssh://git@github.com/open-mmlab/mmdetection.git@dev-3.x - python -m pip install git+ssh://git@github.com/open-mmlab/mmsegmentation.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmengine.git@main + pip install -U openmim + mim install 'mmcv >= 2.0.0rc1' + pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x pip install -r requirements.txt - run: name: Build and install @@ -80,10 +78,9 @@ jobs: - run: name: Run unittests command: | - python -m coverage run --branch --source mmrazor -m pytest tests/ - python -m coverage xml - python -m coverage report -m - + coverage run --branch --source mmrazor -m pytest tests/ + coverage xml + coverage report -m build_cuda: parameters: torch: @@ -94,8 +91,6 @@ jobs: cudnn: type: integer default: 7 - mmcv: - type: string machine: image: ubuntu-2004-cuda-11.4:202110-01 # docker_layer_caching: true @@ -103,13 +98,13 @@ jobs: steps: - checkout - run: - # CLoning repos in VM since Docker doesn't have access to the private key + # Cloning repos in VM since Docker doesn't have access to the private key name: Clone Repos command: | - git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmengine.git /home/circleci/mmengine - git clone -b dev-3.x --depth 1 ssh://git@github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification - git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation + git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine + git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection + git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification + git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation - run: name: Build Docker image command: | @@ -117,10 +112,10 @@ jobs: docker run --gpus all -t -d -v /home/circleci/project:/mmrazor -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -v /home/circleci/mmclassification:/mmclassification -v /home/circleci/mmsegmentation:/mmsegmentation -w /mmrazor --name mmrazor mmrazor:gpu - run: name: Install mmrazor dependencies - # pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html command: | docker exec mmrazor pip install -e /mmengine - docker exec mmrazor pip install << parameters.mmcv >> + docker exec mmrazor pip install -U openmim + docker exec mmrazor mim install 'mmcv >= 2.0.0rc1' docker exec mmrazor pip install -e /mmdetection docker exec mmrazor pip install -e /mmclassification docker exec mmrazor pip install -e /mmsegmentation @@ -132,7 +127,7 @@ jobs: - run: name: Run unittests command: | - docker exec mmrazor python -m pytest tests/ + docker exec mmrazor pytest tests/ workflows: pr_stage_lint: @@ -144,10 +139,10 @@ workflows: branches: ignore: - dev-1.x + - 1.x pr_stage_test: when: - not: - << pipeline.parameters.lint_only >> + not: << pipeline.parameters.lint_only >> jobs: - lint: name: lint @@ -159,16 +154,14 @@ workflows: name: minimum_version_cpu torch: 1.6.0 torchvision: 0.7.0 - python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.6.0/mmcv_full-2.0.0rc1-cp36-cp36m-manylinux1_x86_64.whl + python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images requires: - lint - build_cpu: name: maximum_version_cpu - torch: 1.9.0 - torchvision: 0.10.0 + torch: 1.12.1 + torchvision: 0.13.1 python: 3.9.0 - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.9.0/mmcv_full-2.0.0rc1-cp39-cp39-manylinux1_x86_64.whl requires: - minimum_version_cpu - hold: @@ -181,20 +174,17 @@ workflows: # Use double quotation mark to explicitly specify its type # as string instead of number cuda: "10.2" - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu102/torch1.8.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl requires: - hold merge_stage_test: when: - not: - << pipeline.parameters.lint_only >> + not: << pipeline.parameters.lint_only >> jobs: - build_cuda: name: minimum_version_gpu torch: 1.6.0 # Use double quotation mark to explicitly specify its type # as string instead of number - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu101/torch1.6.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl cuda: "10.1" filters: branches: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..0083e4e13 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,201 @@ +name: build + +on: + push: + paths-ignore: + - "README.md" + - "README_zh-CN.md" + - "model-index.yml" + - "configs/**" + - "docs/**" + - ".dev_scripts/**" + + pull_request: + paths-ignore: + - "README.md" + - "README_zh-CN.md" + - "docs/**" + - "demo/**" + - ".dev_scripts/**" + - ".circleci/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_linux: + runs-on: ubuntu-18.04 + strategy: + matrix: + python-version: [3.7] + torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0] + include: + - torch: 1.6.0 + torch_version: 1.6 + torchvision: 0.7.0 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + python-version: 3.8 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + python-version: 3.8 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + python-version: 3.8 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + python-version: 3.8 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + python-version: 3.8 + - torch: 1.12.0 + torch_version: 1.12 + torchvision: 0.13.0 + - torch: 1.12.0 + torch_version: 1.12 + torchvision: 0.13.0 + python-version: 3.8 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip + run: | + pip install pip --upgrade + pip install wheel + - name: Install PyTorch + run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install MMEngine + run: pip install git+https://github.com/open-mmlab/mmengine.git@main + - name: Install MMCV + run: | + pip install -U openmim + mim install 'mmcv >= 2.0.0rc1' + - name: Install MMCls + run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + - name: Install MMDet + run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + - name: Install MMSeg + run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + - name: Install other dependencies + run: pip install -r requirements.txt + - name: Build and install + run: rm -rf .eggs && pip install -e . + - name: Run unittests and generate coverage report + run: | + coverage run --branch --source mmrazor -m pytest tests/ + coverage xml + coverage report -m + # Upload coverage report for python3.8 && pytorch1.12.0 cpu + - name: Upload coverage to Codecov + if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}} + uses: codecov/codecov-action@v2 + with: + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false + + + + test_cuda: + runs-on: ubuntu-18.04 + container: + image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel + strategy: + matrix: + python-version: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip + run: pip install pip --upgrade + - name: Fetch GPG keys + run: | + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub + - name: Install Python-dev + run: apt-get update && apt-get install -y python${{matrix.python-version}}-dev + if: ${{matrix.python-version != 3.9}} + - name: Install system dependencies + run: | + apt-get update + apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev + - name: Install mmrazor dependencies + run: | + pip install git+https://github.com/open-mmlab/mmengine.git@main + pip install -U openmim + mim install 'mmcv >= 2.0.0rc1' + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + pip install -r requirements.txt + - name: Build and install + run: | + python setup.py check -m -s + TORCH_CUDA_ARCH_LIST=7.0 pip install -e . + + # test_windows: + # runs-on: ${{ matrix.os }} + # strategy: + # matrix: + # os: [windows-2022] + # python: [3.7] + # platform: [cpu] + # steps: + # - uses: actions/checkout@v2 + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@v2 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Upgrade pip + # run: | + # pip install pip --upgrade + # pip install wheel + # - name: Install lmdb + # run: pip install lmdb + # - name: Install PyTorch + # run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html + # - name: Install mmrazor dependencies + # run: | + # pip install git+https://github.com/open-mmlab/mmengine.git@main + # pip install -U openmim + # mim install 'mmcv >= 2.0.0rc1' + # pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + # pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + # pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + # pip install -r requirements.txt + # - name: Build and install + # run: | + # pip install -e . + # - name: Run unittests and generate coverage report + # run: | + # pytest tests/ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..36422d008 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,27 @@ +name: lint + +on: [push, pull_request] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + - name: Install pre-commit hook + run: | + pip install pre-commit + pre-commit install + - name: Linting + run: pre-commit run --all-files + - name: Check docstring coverage + run: | + pip install interrogate + interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmrazor diff --git a/configs/nas/mmcls/darts/metafile.yml b/configs/nas/mmcls/darts/metafile.yml index 4b0515b5b..9594e6765 100644 --- a/configs/nas/mmcls/darts/metafile.yml +++ b/configs/nas/mmcls/darts/metafile.yml @@ -24,5 +24,5 @@ Models: Metrics: Top 1 Accuracy: 97.32 Top 5 Accuracy: 99.94 - Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py + Config: configs/nas/darts/darts_subnet_1xb96_cifar10_2.0.py Weights: https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth diff --git a/mmrazor/datasets/crd_dataset_wrapper.py b/mmrazor/datasets/crd_dataset_wrapper.py index 308bc1e4c..aa62f383b 100644 --- a/mmrazor/datasets/crd_dataset_wrapper.py +++ b/mmrazor/datasets/crd_dataset_wrapper.py @@ -74,7 +74,7 @@ def _parse_fullset_contrast_info(self) -> None: # e.g. [2, 3, 5]. num_classes: int = self.num_classes # type: ignore if num_classes is None: - num_classes = len(self.dataset.CLASSES) + num_classes = max(self.dataset.get_gt_labels()) + 1 if not self.dataset.test_mode: # type: ignore # Parse info. diff --git a/mmrazor/models/task_modules/tracer/parsers.py b/mmrazor/models/task_modules/tracer/parsers.py index efcfc8613..c342da716 100644 --- a/mmrazor/models/task_modules/tracer/parsers.py +++ b/mmrazor/models/task_modules/tracer/parsers.py @@ -118,13 +118,20 @@ def parse_cat(tracer, grad_fn, module2name, param2module, cur_path, >>> # ``out`` is obtained by concatenating two tensors """ parents = grad_fn.next_functions + concat_id = '_'.join([str(id(p)) for p in parents]) + concat_id_list = [str(id(p)) for p in parents] + concat_id_list.sort() + concat_id = '_'.join(concat_id_list) + name = f'concat_{concat_id}' + + visited[name] = True sub_path_lists = list() - for i, parent in enumerate(parents): + for _, parent in enumerate(parents): sub_path_list = PathList() tracer.backward_trace(parent, module2name, param2module, Path(), sub_path_list, visited, shared_module) sub_path_lists.append(sub_path_list) - cur_path.append(PathConcatNode('CatNode', sub_path_lists)) + cur_path.append(PathConcatNode(name, sub_path_lists)) result_paths.append(copy.deepcopy(cur_path)) cur_path.pop(-1) diff --git a/requirements/optional.txt b/requirements/optional.txt index 609cc3925..32f7d6fd0 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,3 @@ albumentations>=0.3.2 scipy -timm +# timm diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py index 1e6031a97..1eaf72ec8 100644 --- a/tests/test_datasets/test_datasets.py +++ b/tests/test_datasets/test_datasets.py @@ -6,6 +6,7 @@ from unittest import TestCase import numpy as np +from mmcls.registry import DATASETS as CLS_DATASETS from mmrazor.registry import DATASETS from mmrazor.utils import register_all_modules @@ -15,7 +16,8 @@ class Test_CRD_CIFAR10(TestCase): - DATASET_TYPE = 'CRD_CIFAR10' + ORI_DATASET_TYPE = 'CIFAR10' + DATASET_TYPE = 'CRDDataset' @classmethod def setUpClass(cls) -> None: @@ -24,10 +26,11 @@ def setUpClass(cls) -> None: tmpdir = tempfile.TemporaryDirectory() cls.tmpdir = tmpdir data_prefix = tmpdir.name - cls.DEFAULT_ARGS = dict( + cls.ORI_DEFAULT_ARGS = dict( data_prefix=data_prefix, pipeline=[], test_mode=False) + cls.DEFAULT_ARGS = dict(neg_num=1, percent=0.5) - dataset_class = DATASETS.get(cls.DATASET_TYPE) + dataset_class = CLS_DATASETS.get(cls.ORI_DATASET_TYPE) base_folder = osp.join(data_prefix, dataset_class.base_folder) os.mkdir(base_folder) @@ -65,25 +68,16 @@ def test_initialize(self): dataset_class = DATASETS.get(self.DATASET_TYPE) # Test overriding metainfo by `metainfo` argument - cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}} + ori_cfg = { + **self.ORI_DEFAULT_ARGS, 'metainfo': { + 'classes': ('bus', 'car') + }, + 'type': self.ORI_DATASET_TYPE, + '_scope_': 'mmcls' + } + cfg = {'dataset': ori_cfg, **self.DEFAULT_ARGS} dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - - # Test overriding metainfo by `classes` argument - cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']} - dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - - classes_file = osp.join(ASSETS_ROOT, 'classes.txt') - cfg = {**self.DEFAULT_ARGS, 'classes': classes_file} - dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1}) - - # Test invalid classes - cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)} - with self.assertRaisesRegex(ValueError, "type "): - dataset_class(**cfg) + self.assertEqual(dataset.dataset.CLASSES, ('bus', 'car')) @classmethod def tearDownClass(cls): @@ -91,4 +85,4 @@ def tearDownClass(cls): class Test_CRD_CIFAR100(Test_CRD_CIFAR10): - DATASET_TYPE = 'CRD_CIFAR100' + ORI_DATASET_TYPE = 'CIFAR100' diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 46aa671df..69e211aad 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -6,7 +6,7 @@ import numpy as np import torch from mmcls.structures import ClsDataSample -from mmengine.data import LabelData +from mmengine.structures import LabelData from mmrazor.datasets.transforms import PackCRDClsInputs @@ -34,7 +34,7 @@ def setUp(self): 'img': rng.rand(300, 400), 'gt_label': rng.randint(3, ), # TODO. - 'contrast_sample_idxs': rng.randint() + 'contrast_sample_idxs': rng.randint(3, ) } self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'flip') @@ -44,13 +44,13 @@ def test_transform(self): results = transform(copy.deepcopy(self.results1)) self.assertIn('inputs', results) self.assertIsInstance(results['inputs'], torch.Tensor) - self.assertIn('data_sample', results) - self.assertIsInstance(results['data_sample'], ClsDataSample) + self.assertIn('data_samples', results) + self.assertIsInstance(results['data_samples'], ClsDataSample) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertIsInstance(data_sample.gt_label, LabelData) def test_repr(self): transform = PackCRDClsInputs(meta_keys=self.meta_keys) self.assertEqual( - repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})') + repr(transform), f'PackCRDClsInputs(meta_keys={self.meta_keys})') diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py index 929840148..9f6dfc902 100644 --- a/tests/test_models/test_algorithms/test_dsnas.py +++ b/tests/test_models/test_algorithms/test_dsnas.py @@ -170,19 +170,17 @@ def setUpClass(cls) -> None: os.environ['MASTER_PORT'] = '12345' # initialize the process group - if torch.cuda.is_available(): - backend = 'nccl' - cls.device = 'cuda' - else: - backend = 'gloo' + backend = 'nccl' if torch.cuda.is_available() else 'gloo' dist.init_process_group(backend, rank=0, world_size=1) def prepare_model(self, device_ids=None) -> Dsnas: - model = ToyDiffModule().to(self.device) - mutator = DiffModuleMutator().to(self.device) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + model = ToyDiffModule() + mutator = DiffModuleMutator() mutator.prepare_from_supernet(model) - algo = Dsnas(model, mutator) + algo = Dsnas(model, mutator).to(self.device) return DsnasDDP( module=algo, find_unused_parameters=True, device_ids=device_ids) @@ -199,24 +197,19 @@ def test_init(self) -> None: @patch('mmengine.logging.message_hub.MessageHub.get_info') def test_dsnasddp_train_step(self, mock_get_info) -> None: - model = ToyDiffModule() - mutator = DiffModuleMutator() - mutator.prepare_from_supernet(model) + ddp_model = self.prepare_model() mock_get_info.return_value = 2 - algo = Dsnas(model, mutator) - ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) data = self._prepare_fake_data() optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG) loss = ddp_model.train_step(data, optim_wrapper) self.assertIsNotNone(loss) - algo = Dsnas(model, mutator) - ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + ddp_model = self.prepare_model() optim_wrapper_dict = OptimWrapperDict( - architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), - mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01))) loss = ddp_model.train_step(data, optim_wrapper_dict) self.assertIsNotNone(loss) diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index 4328f7865..37fea2baf 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -2,7 +2,7 @@ from unittest import TestCase import torch -from mmengine.data import BaseDataElement +from mmengine.structures import BaseDataElement from mmrazor import digit_version from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss,