Skip to content

Commit

Permalink
fix torch1.13 ci
Browse files Browse the repository at this point in the history
  • Loading branch information
humu789 committed Jan 16, 2023
1 parent c5f54c1 commit 68caf38
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions tests/test_models/test_task_modules/test_custom_tracer.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch.nn as nn
import pytest
import torch
from mmcls.models.backbones.resnet import ResLayer
from mmengine.config import Config
from mmengine.registry import MODELS
from torch.fx import GraphModule
from torch.fx._symbolic_trace import Graph

try:
from torch.fx import GraphModule
from torch.fx._symbolic_trace import Graph
except ImportError:
from mmrazor.utils import get_placeholder
GraphModule = get_placeholder('torch>=1.13')
Graph = get_placeholder('torch>=1.13')

from mmrazor import digit_version
from mmrazor.models.task_modules import (CustomTracer, UntracedMethodRegistry,
build_graphmodule,
custom_symbolic_trace)
from mmrazor.models.task_modules.tracer.fx.custom_tracer import \
_prepare_module_dict


class ToyModel(nn.Module):
class ToyModel(torch.nn.Module):

def __init__(self):
super.__init__()
Expand All @@ -35,13 +43,19 @@ def forward(self, x):
class testUntracedMethodRgistry(TestCase):

def test_init(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

method = ToyModel.get_loss
method_registry = UntracedMethodRegistry(method)
assert hasattr(method_registry, 'method')
assert hasattr(method_registry, 'method_dict')
assert len(method_registry.method_dict) == 0

def test_registry_method(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

model = ToyModel
method = ToyModel.get_loss
method_registry = UntracedMethodRegistry(method)
Expand All @@ -63,6 +77,9 @@ def setUp(self):
self.skipped_module_classes = [ResLayer]

def test_init(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

# init without skipped_methods
tracer = CustomTracer()
assert hasattr(tracer, 'skipped_methods')
Expand All @@ -84,6 +101,9 @@ def test_init(self):
CustomTracer(skipped_methods='_get_loss')

def test_trace(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

# test trace with skipped_methods
model = MODELS.build(self.cfg.model)
UntracedMethodRegistry.method_dict = dict()
Expand Down Expand Up @@ -129,6 +149,9 @@ def test_trace(self):
assert skip_flag


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_custom_symbolic_trace():
cfg = Config.fromfile(
'tests/data/test_models/test_task_modules/mmcls_cfg.py')
Expand All @@ -139,6 +162,9 @@ def test_custom_symbolic_trace():
assert isinstance(graph_module, GraphModule)


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_build_graphmodule():
skipped_methods = ['mmcls.models.heads.ClsHead._get_predictions']
cfg = Config.fromfile(
Expand All @@ -154,5 +180,5 @@ def test_build_graphmodule():
modules = dict(model.named_modules())
module_dict = _prepare_module_dict(model, graph_predict)
for k, v in module_dict.items():
assert isinstance(v, nn.Module)
assert isinstance(v, torch.nn.Module)
assert not isinstance(v, modules[k].__class__)

0 comments on commit 68caf38

Please sign in to comment.