Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Analyzer] fix analyzer tests #3197

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions tests/test_analyzer/test_fx/test_bias_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from packaging import version
from torch.utils.checkpoint import checkpoint

from colossalai.testing.utils import parameterize

try:
from colossalai._analyzer.fx import symbolic_trace
except:
Expand Down Expand Up @@ -56,9 +58,13 @@ def __init__(self, bias) -> None:
self.linear = LinearModel(3, 3, bias)
self.conv = ConvModel(3, 6, 3, bias)

def forward(self, x, select=0):
def forward(self, x, select=torch.Tensor([0])):
x = self.linear(x)
x = checkpoint(self.conv, x, select)
if select:
x = checkpoint(self.conv, x, 0)
else:
x = checkpoint(self.conv, x, 1)

return x


Expand All @@ -75,10 +81,10 @@ def forward(self, x):


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@pytest.mark.parametrize("select", [0, 1])
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
model = SiuModel(bias=bias)
x = torch.rand(shape)
Expand All @@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
concrete_args={'select': select},
trace_act_ckpt=True,
bias_addition_split=bias_addition_split)
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
if bias and bias_addition_split:
assert '+' in gm.code, 'bias addition should be split!'
else:
assert '+' not in gm.code, 'bias addition should not be split!'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("alpha", [1, 2])
@pytest.mark.parametrize("beta", [1, 2])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
model = AddmmModel(alpha=alpha, beta=beta)
x = torch.rand(shape)
Expand All @@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):


if __name__ == '__main__':
test_siu_model(True, True, (3, 3, 3))
test_siu_model()
test_addmm_model()
30 changes: 16 additions & 14 deletions tests/test_analyzer/test_fx/test_shape_prop.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models
from packaging import version

from colossalai.testing.utils import parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_trace
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl



@register_shape_impl(torch.nn.functional.linear)
def linear_impl(*args, **kwargs):
assert True
Expand All @@ -23,15 +24,15 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
if node.op in [
# 'call_module', # can apply to params
# 'call_function', # can apply to params
# 'call_method', # can apply to params
'call_module', # can apply to params
'call_function', # can apply to params
'call_method', # can apply to params
]:
assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tm_models)
def test_torchvision_shape_prop(m):
with MetaTensorMode():
model = m()
Expand All @@ -44,20 +45,21 @@ def test_torchvision_shape_prop(m):
_check_gm_validity(gm)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tmm_models)
def test_timm_shape_prop(m):
with MetaTensorMode():
model = m()
data = torch.rand(100, 3, 224, 224)
meta_args = {
"x": data,
}

gm = symbolic_trace(model, meta_args=meta_args)
shape_prop_pass(gm, data)
_check_gm_validity(gm)


if __name__ == "__main__":
test_torchvision_shape_prop(tm.resnet18)
test_timm_shape_prop(tmm.vgg11)
test_torchvision_shape_prop()
test_timm_shape_prop()
18 changes: 10 additions & 8 deletions tests/test_analyzer/test_fx/test_symbolic_profile.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models
from packaging import version

from colossalai.testing.utils import parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
from colossalai._analyzer._subclasses import MetaTensorMode
Expand All @@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
Expand All @@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
_check_gm_validity(gm)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
Expand All @@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):


if __name__ == "__main__":
test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
test_torchvision_profile()
test_timm_profile()
8 changes: 4 additions & 4 deletions tests/test_analyzer/test_fx/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@
tmm.dm_nfnet_f0,
tmm.eca_nfnet_l0,
tmm.efficientformer_l1,
tmm.ese_vovnet19b_dw,
# tmm.ese_vovnet19b_dw,
tmm.gmixer_12_224,
tmm.gmlp_b16_224,
tmm.hardcorenas_a,
# tmm.hardcorenas_a,
tmm.hrnet_w18_small,
tmm.inception_v3,
tmm.mixer_b16_224,
tmm.nf_ecaresnet101,
tmm.nf_regnet_b0,
# tmm.pit_b_224, # pretrained only
tmm.regnetv_040,
tmm.skresnet18,
# tmm.regnetv_040,
# tmm.skresnet18,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm.vgg11,
Expand Down
11 changes: 6 additions & 5 deletions tests/test_analyzer/test_subclasses/test_flop_tensor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from .zoo import tm_models, tmm_models
from packaging import version

from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except:
pass


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
x = torch.rand(2, 3, 224, 224)
Expand All @@ -37,7 +38,7 @@ def test_flop_count_module(m):
]


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
Expand All @@ -46,5 +47,5 @@ def test_flop_count_function(func, args, kwargs):


if __name__ == '__main__':
test_flop_count_module(tm.resnet18, torch.rand(2, 3, 224, 224))
test_flop_count_module(tm.resnet18)
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
7 changes: 4 additions & 3 deletions tests/test_analyzer/test_subclasses/test_meta_mode.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
import torch
import torch.distributed as dist
import torchvision.models as tm
from packaging import version

try:
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
pass
from .zoo import tm_models, tmm_models
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models


def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
Expand All @@ -28,7 +29,7 @@ def run_and_compare(model):
compare_all(x.grad, meta_x.grad)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
run_and_compare(m())
Expand Down
53 changes: 0 additions & 53 deletions tests/test_analyzer/test_subclasses/zoo.py

This file was deleted.