Commit 1dd8969
swap fx tests with corresponding non-fx tests
alexander-soare committed Nov 25, 2021
1 parent a2af974 commit 1dd8969
Showing 1 changed file with 100 additions and 118 deletions.
tests/test_models.py: 218 changes (100 additions, 118 deletions)
@@ -76,39 +76,85 @@ def _get_input_size(model=None, model_name='', target=None):
     return input_size
 
 
+def _create_fx_model(model, train=False):
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode.
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names.
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    eval_return_nodes = [eval_nodes[-1]]
+    train_return_nodes = [train_nodes[-1]]
+    if train:
+        tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+        graph = tracer.trace(model)
+        graph_nodes = list(reversed(graph.nodes))
+        output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+        graph_node_names = [n.name for n in graph_nodes]
+        output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+        train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    fx_model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    return fx_model
+
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
-def test_model_forward(model_name, batch_size):
-    """Run a single forward pass with each model"""
+def test_model_forward_fx(model_name, batch_size):
+    """
+    Symbolically trace each model and run a single forward pass through the resulting GraphModule.
+    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module.
+    """
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
-    if max(input_size) > MAX_FWD_SIZE:
+    input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
+    if max(input_size) > MAX_FWD_FX_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
-    inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    with torch.no_grad():
+        inputs = torch.randn((batch_size, *input_size))
+        outputs = model(inputs)
+        if isinstance(outputs, tuple):
+            outputs = torch.cat(outputs)
+
+        model = _create_fx_model(model)
+        fx_outputs = tuple(model(inputs).values())
+        if isinstance(fx_outputs, tuple):
+            fx_outputs = torch.cat(fx_outputs)
 
+    assert torch.all(fx_outputs == outputs)
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('model_name', list_models(
+    exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
-def test_model_backward(model_name, batch_size):
-    """Run a single backward pass with each model"""
-    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
-    if max(input_size) > MAX_BWD_SIZE:
+def test_model_backward_fx(model_name, batch_size):
+    """Symbolically trace each model and run a single backward pass through the resulting GraphModule"""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
+    if max(input_size) > MAX_BWD_FX_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
     model = create_model(model_name, pretrained=False, num_classes=42)
-    num_params = sum([x.numel() for x in model.parameters()])
     model.train()
+    num_params = sum([x.numel() for x in model.parameters()])
+    if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
+        pytest.skip("Skipping FX backward test on model with more than 100M params.")
 
-    inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    model = _create_fx_model(model, train=True)
+    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
     outputs.mean().backward()
@@ -259,12 +305,23 @@ def test_model_features_pretrained(model_name, batch_size):
 ]
 
 
+# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+EXCLUDE_FX_JIT_FILTERS = [
+    'deit_*_distilled_patch16_224',
+    'levit*',
+    'pit_*_distilled_224',
+]
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
-    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
+    'model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
-def test_model_forward_torchscript(model_name, batch_size):
-    """Run a single forward pass with each model"""
+def test_model_forward_fx_torchscript(model_name, batch_size):
+    """Symbolically trace each model, script it, and run a single forward pass"""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
     input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
     if max(input_size) > MAX_JIT_SIZE:
         pytest.skip("Fixed input size model > limit.")
@@ -273,8 +330,11 @@ def test_model_forward_torchscript(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    model = torch.jit.script(model)
-    outputs = model(torch.randn((batch_size, *input_size)))
+    model = torch.jit.script(_create_fx_model(model))
+    with torch.no_grad():
+        outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
+        if isinstance(outputs, tuple):
+            outputs = torch.cat(outputs)
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
@@ -310,103 +370,39 @@ def test_model_forward_features(model_name, batch_size):
         assert not torch.isnan(o).any()
 
 
-def _create_fx_model(model, train=False):
-    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode.
-    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
-    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names.
-    train_nodes, eval_nodes = get_graph_node_names(
-        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-
-    eval_return_nodes = [eval_nodes[-1]]
-    train_return_nodes = [train_nodes[-1]]
-    if train:
-        tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
-        graph = tracer.trace(model)
-        graph_nodes = list(reversed(graph.nodes))
-        output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
-        graph_node_names = [n.name for n in graph_nodes]
-        output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
-        train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
-
-    fx_model = create_feature_extractor(
-        model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
-        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-    return fx_model
-
-
-EXCLUDE_FX_FILTERS = []
-# not enough memory to run fx on more models than other tests
-if 'GITHUB_ACTIONS' in os.environ:
-    EXCLUDE_FX_FILTERS += [
-        # 'beit_large*',
-        # 'mixer_l*',
-        # '*nfnet_f2*',
-        # '*resnext101_32x32d',
-        # 'resnetv2_152x2*',
-        # 'resmlp_big*',
-        # 'resnetrs270',
-        # 'swin_large*',
-        # 'vgg*',
-        # 'vit_large*',
-        # 'vit_base_patch8*',
-        # 'xcit_large*',
-    ]
-
-
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
-def test_model_forward_fx(model_name, batch_size):
-    """
-    Symbolically trace each model and run a single forward pass through the resulting GraphModule.
-    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module.
-    """
-    if not has_fx_feature_extraction:
-        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
-
+def test_model_forward(model_name, batch_size):
+    """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
-    if max(input_size) > MAX_FWD_FX_SIZE:
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
-    with torch.no_grad():
-        inputs = torch.randn((batch_size, *input_size))
-        outputs = model(inputs)
-        if isinstance(outputs, tuple):
-            outputs = torch.cat(outputs)
-
-        model = _create_fx_model(model)
-        fx_outputs = tuple(model(inputs).values())
-        if isinstance(fx_outputs, tuple):
-            fx_outputs = torch.cat(fx_outputs)
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
 
-    assert torch.all(fx_outputs == outputs)
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(
-    exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
-def test_model_backward_fx(model_name, batch_size):
-    """Symbolically trace each model and run a single backward pass through the resulting GraphModule"""
-    if not has_fx_feature_extraction:
-        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
-
-    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
-    if max(input_size) > MAX_BWD_FX_SIZE:
+def test_model_backward(model_name, batch_size):
+    """Run a single backward pass with each model"""
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
+    if max(input_size) > MAX_BWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
     model = create_model(model_name, pretrained=False, num_classes=42)
-    model.train()
     num_params = sum([x.numel() for x in model.parameters()])
-    if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
-        pytest.skip("Skipping FX backward test on model with more than 100M params.")
+    model.train()
 
-    model = _create_fx_model(model, train=True)
-    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
     outputs.mean().backward()
@@ -419,24 +415,12 @@ def test_model_backward_fx(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
-# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
-EXCLUDE_FX_JIT_FILTERS = [
-    'deit_*_distilled_patch16_224',
-    'levit*',
-    'pit_*_distilled_224',
-] + EXCLUDE_FX_FILTERS
-
-
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
-    'model_name', list_models(
-        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
+    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
-def test_model_forward_fx_torchscript(model_name, batch_size):
-    """Symbolically trace each model, script it, and run a single forward pass"""
-    if not has_fx_feature_extraction:
-        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
-
+def test_model_forward_torchscript(model_name, batch_size):
+    """Run a single forward pass with each model"""
     input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
     if max(input_size) > MAX_JIT_SIZE:
         pytest.skip("Fixed input size model > limit.")
@@ -445,11 +429,9 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    model = torch.jit.script(_create_fx_model(model))
-    with torch.no_grad():
-        outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
-        if isinstance(outputs, tuple):
-            outputs = torch.cat(outputs)
+    model = torch.jit.script(model)
+    outputs = model(torch.randn((batch_size, *input_size)))
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
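
Note on the torchvision FX API that the _create_fx_model helper above builds on: the sketch below (not part of the commit) shows the bare get_graph_node_names / create_feature_extractor flow, using torchvision's resnet18 as a stand-in model. The timm tests additionally pass tracer_kwargs with timm's _leaf_modules and _autowrap_functions so timm-specific ops trace cleanly, and use train_return_nodes/eval_return_nodes instead of the plain return_nodes shown here.

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = resnet18().eval()

# Node names can differ between train and eval mode (e.g. dropout, aux heads),
# so both lists are returned; the last name corresponds to the final output.
train_nodes, eval_nodes = get_graph_node_names(model)

# The extractor is a GraphModule whose forward returns {node name: tensor}.
fx_model = create_feature_extractor(model, return_nodes=[eval_nodes[-1]])

with torch.no_grad():
    out = tuple(fx_model(torch.randn(1, 3, 224, 224)).values())[0]
print(out.shape)  # torch.Size([1, 1000])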

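For context on EXCLUDE_FX_JIT_FILTERS: the excluded distilled models branch on torch.jit.is_scripting(). TorchScript resolves that call to True at script time and prunes the other branch, but FX tracing resolves it to False at trace time, so the branch meant for scripting disappears from the traced graph before scripting ever sees it. A minimal sketch of the pattern (illustrative only; a simplified stand-in for the deit/levit/pit distilled heads, not timm code):

import torch

class DistilledHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(8, 4)
        self.head_dist = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor):
        x_cls, x_dist = self.head(x), self.head_dist(x)
        if self.training and not torch.jit.is_scripting():
            # eager training: return both heads for the distillation loss
            return x_cls, x_dist
        # inference or scripted: average the two classifier predictions
        return (x_cls + x_dist) / 2

head = DistilledHead().train()
print(len(head(torch.randn(2, 8))))        # eager train mode: 2 outputs
scripted = torch.jit.script(head)          # scripting prunes the tuple branch
print(scripted(torch.randn(2, 8)).shape)   # torch.Size([2, 4])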