[unittest] supported conditional testing based on env var (#1701)
polish code
FrankLeeeee committed Oct 13, 2022
1 parent 8283e95 commit 0e52f3d
Showing 10 changed files with 36 additions and 10 deletions.
17 changes: 17 additions & 0 deletions colossalai/testing/pytest_wrapper.py
@@ -0,0 +1,17 @@
import pytest
import os


def run_on_environment_flag(name: str):
    """
    Conditionally run a test based on an environment variable. If the variable is set to '1',
    the test is executed; otherwise the test is skipped. The variable defaults to '0'.
    """
    assert isinstance(name, str)
    flag = os.environ.get(name.upper(), '0')

    reason = f'Environment variable {name} is {flag}'
    if flag == '1':
        return pytest.mark.skipif(False, reason=reason)
    else:
        return pytest.mark.skipif(True, reason=reason)
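
For context, a minimal usage sketch of the new decorator (the test name and file below are illustrative, not part of this commit): it is applied like any other pytest marker, and the gated test is enabled by exporting the corresponding environment variable before invoking pytest.

from colossalai.testing.pytest_wrapper import run_on_environment_flag


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_example():
    # Runs only when AUTO_PARALLEL=1 is set in the environment,
    # e.g. `AUTO_PARALLEL=1 pytest test_example.py`; otherwise pytest
    # reports the test as skipped with the reason string built above.
    assert True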
3 changes: 2 additions & 1 deletion colossalai/testing/utils.py
@@ -193,11 +193,12 @@ def test_something():
"""

def _wrap_func(f):

def _execute_by_gpu_num(*args, **kwargs):
num_avail_gpu = torch.cuda.device_count()
if num_avail_gpu >= min_gpus:
f(*args, **kwargs)

return _execute_by_gpu_num

return _wrap_func
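
The snippet above is the GPU-count guard in colossalai/testing/utils.py: the wrapped test body only runs when torch.cuda.device_count() reaches min_gpus. A minimal sketch of how such a decorator is typically applied, assuming the public name skip_if_not_enough_gpus (the decorator name and test below are illustrative; only the inner wrappers are visible in this diff):

import torch
from colossalai.testing import skip_if_not_enough_gpus  # assumed export; not shown in this diff


@skip_if_not_enough_gpus(min_gpus=4)
def test_needs_four_gpus():
    # The wrapper compares torch.cuda.device_count() against min_gpus and
    # only calls this body when enough GPUs are available.
    assert torch.cuda.device_count() >= 4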

@@ -7,6 +7,7 @@
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class ConvModel(nn.Module):
@@ -22,7 +23,7 @@ def forward(self, condition, x, y):
return output


@pytest.mark.skip("temporarily skipped")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -18,6 +18,7 @@
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class ConvModel(nn.Module):
@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
assert output.equal(origin_output)


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
@@ -12,6 +12,7 @@
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class ConvModel(nn.Module):
@@ -33,7 +34,7 @@ def forward(self, x):
return x


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_solver():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -15,12 +15,13 @@
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag

BATCH_SIZE = 8
SEQ_LENGHT = 8


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
@@ -15,6 +15,7 @@
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class MLP(torch.nn.Module):
@@ -34,7 +35,7 @@ def forward(self, x):
return x


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
@@ -5,6 +5,7 @@
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class BMMTensorMethodModule(nn.Module):
@@ -19,7 +20,7 @@ def forward(self, x1, x2):
return torch.bmm(x1, x2)


@pytest.mark.skip
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):

@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list


@pytest.mark.skip
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()
@@ -6,9 +6,10 @@
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
import pytest
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
@@ -15,9 +15,10 @@
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@pytest.mark.skip("for higher testing speed")
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
