From daebfc1ae585f9fe37de77a001a0a101ff8520e9 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:26:59 +0000 Subject: [PATCH 1/2] [ModelSuite] Add Toy Models Summary: Here we introduce model suite (model.py). The idea here to start and codify the ideas from jiannanWang/BackendBenchExamples. Specifically this PR adds some example models / configs which are to be loaded + a Readme. (It may be useful to look at the PR above this as well since it's the model loading logic). This PR adds two toy models to model suite SmokeTestModel - This is simple model that uses aten.ops.mm as we can implement a correct version of this op ToyCoreOpsModel - This is a model which explicitly calls the backwards passes which are both in torchbench + core. Test Plan: the test infra is in the pr above, so tests passing on the PR above should be sufficient here ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/suite/models/README.md | 80 +++++++++++++++++ .../models/SmokeTestModel/SmokeTestModel.json | 25 ++++++ .../models/SmokeTestModel/SmokeTestModel.py | 68 +++++++++++++++ .../ToyCoreOpsModel/ToyCoreOpsModel.json | 34 ++++++++ .../models/ToyCoreOpsModel/ToyCoreOpsModel.py | 87 +++++++++++++++++++ 5 files changed, 294 insertions(+) create mode 100644 BackendBench/suite/models/README.md create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json create mode 100644 BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py create mode 100644 BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json create mode 100644 BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py diff --git a/BackendBench/suite/models/README.md b/BackendBench/suite/models/README.md new file mode 100644 index 00000000..57e707dc --- /dev/null +++ b/BackendBench/suite/models/README.md @@ -0,0 +1,80 @@ +# Adding Models to BackendBench + +## Quick Start + +Models define operator lists and validate that custom backends work correctly in full model execution. Two files required: + +``` +BackendBench/suite/models/YourModel/ +├── YourModel.py # nn.Module class +└── YourModel.json # Configuration +``` + +**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive) + +## Adding a Model + +### 1. Create Directory and Files + +```bash +cd BackendBench/suite/models +mkdir MyModel +cd MyModel +touch MyModel.py MyModel.json +``` + +### 2. Write Model Class (`MyModel.py`) + +**Requirements:** +- Class name = filename (exact match) +- All `__init__` params need defaults +- Add a main() / runner if you are inclined for sanity checking +- Keep it simple - focus on specific operators you're testing +- Look in this directory for examples + +### 3. Write Config (`MyModel.json`) + +**Key Fields:** +- `model_config.init_args` - Args for `__init__()`, must match your defaults +- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten..default"`) +- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench) + - Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc. +- `metadata.description` - What this model tests +- Look in this directory for examples + +**Finding operator names:** +```python +from torch.profiler import profile, ProfilerActivity + +with profile(activities=[ProfilerActivity.CPU]) as prof: + output = model(x) + loss = output.sum() + loss.backward() + +for event in prof.key_averages(): + if "aten::" in event.key: + print(event.key) +``` + +### 4. Test Your Model + +```bash +# Test standalone +cd BackendBench/suite/models/MyModel +python MyModel.py # Add main() for standalone testing + +# Test with suite +python -m BackendBench.scripts.main \ + --suite model \ + --backend aten \ + --model-filter MyModel + +# Expected output: +# Model: MyModel +# Status: ✓ Passed (2/2 tests) +# ✓ small +# ✓ large +``` + +### 5: Validation +`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly. diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json new file mode 100644 index 00000000..b7d286ae --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json @@ -0,0 +1,25 @@ +{ + "model_config": { + "init_args": { + "input_dim": 128, + "hidden_dim": 128, + "output_dim": 128 + } + }, + "ops": { + "forward": [ + "aten.mm.default" + ], + "backward": [ + "aten.mm.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 128], f32)})", + "medium_batch": "([], {'x': T([16, 128], f32)})", + "large_batch": "([], {'x': T([32, 128], f32)})" + }, + "metadata": { + "description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes" + } +} diff --git a/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py new file mode 100644 index 00000000..3bf627e4 --- /dev/null +++ b/BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple model that tests matrix multiplication operations using explicit +torch.mm calls. +""" + +import torch +import torch.nn as nn + + +class SmokeTestModel(nn.Module): + """ + Model that uses explicit torch.mm operations to test aten.mm.default + in forward/backward. + """ + + def __init__( + self, + input_dim: int = 128, + hidden_dim: int = 128, + output_dim: int = 128, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + + self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim)) + self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim)) + self.bias1 = nn.Parameter(torch.randn(hidden_dim)) + self.bias2 = nn.Parameter(torch.randn(output_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2) + """ + x = torch.mm(x, self.weight1) + self.bias1 + x = torch.relu(x) + x = torch.mm(x, self.weight2) + self.bias2 + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128) + batch_size = 4 + input_tensor = torch.randn(batch_size, 128, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + + +if __name__ == "__main__": + main() diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json new file mode 100644 index 00000000..1586273e --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.json @@ -0,0 +1,34 @@ +{ + "model_config": { + "init_args": { + "in_channels": 3, + "hidden_channels": 32, + "out_channels": 8, + "num_groups": 8 + } + }, + "ops": { + "forward": [ + "aten.convolution.default", + "aten.native_group_norm.default", + "aten.max_pool2d_with_indices.default", + "aten.avg_pool2d.default", + "aten._adaptive_avg_pool2d.default" + ], + "backward": [ + "aten.convolution_backward.default", + "aten.native_group_norm_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d_backward.default" + ] + }, + "model_tests": { + "small_batch": "([], {'x': T([2, 3, 32, 32], f32)})", + "medium_batch": "([], {'x': T([4, 3, 64, 64], f32)})", + "large_input": "([], {'x': T([2, 3, 128, 128], f32)})" + }, + "metadata": { + "description": "Core operations model testing fundamental operators: convolution, group norm, max pool with indices, avg pool, adaptive avg pool" + } +} \ No newline at end of file diff --git a/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py new file mode 100644 index 00000000..410e4c4f --- /dev/null +++ b/BackendBench/suite/models/ToyCoreOpsModel/ToyCoreOpsModel.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +CNN model that triggers core PyTorch backward operators: +- convolution_backward +- native_group_norm_backward +- max_pool2d_with_indices_backward +- avg_pool2d_backward +- _adaptive_avg_pool2d_backward +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ToyCoreOpsModel(nn.Module): + """CNN that uses conv, group norm, max pool, avg pool, and adaptive avg pool.""" + + def __init__( + self, + in_channels: int = 3, + hidden_channels: int = 32, + out_channels: int = 8, + num_groups: int = 8, + ): + super().__init__() + + if hidden_channels % num_groups != 0: + raise ValueError( + f"hidden_channels ({hidden_channels}) must be divisible by " + f"num_groups ({num_groups})" + ) + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.num_groups = num_groups + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm1 = nn.GroupNorm(num_groups, hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1) + self.group_norm2 = nn.GroupNorm(num_groups, hidden_channels) + self.conv_out = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through: Conv->GroupNorm->ReLU->MaxPool->Conv-> + GroupNorm->ReLU->AvgPool->AdaptiveAvgPool->Conv + Output is always (batch, out_channels, 4, 4) regardless of + input size. + """ + x = F.relu(self.group_norm1(self.conv1(x))) + x, _ = F.max_pool2d(x, kernel_size=2, return_indices=True) + x = F.relu(self.group_norm2(self.conv2(x))) + x = F.avg_pool2d(x, kernel_size=2) + x = F.adaptive_avg_pool2d(x, output_size=(4, 4)) + x = self.conv_out(x) + return x + + +def main(): + """Demonstrate the model with a forward/backward pass.""" + model = ToyCoreOpsModel(in_channels=3, hidden_channels=32, out_channels=8, num_groups=8) + batch_size = 2 + input_tensor = torch.randn(batch_size, 3, 64, 64, requires_grad=True) + + model.train() + output = model(input_tensor) + loss = output.sum() + loss.backward() + + print("✓ Forward/backward pass completed") + print(f" Parameters: {sum(p.numel() for p in model.parameters())}") + print(f" Input: {input_tensor.shape} -> Output: {output.shape}") + grad_count = sum(1 for p in model.parameters() if p.grad is not None) + total_params = len(list(model.parameters())) + print(f" Gradients computed: {grad_count}/{total_params}") + return model + + +if __name__ == "__main__": + main() From 6e6334b6bd3d7d2cb3c0db57e828588b9a55f222 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 2 Oct 2025 08:27:51 +0000 Subject: [PATCH 2/2] [ModelSuite] Add model loading infrastructure ### Model Registration This PR creates a way of adding models to the suite and automatically validates them through CI. It also loads the models as well. The way these models are added is detailed in this readme. The tl;dir is we use a format similar to kernelbench and SakanaAI/robust-kbench where we pair model code with a config. Importantly the configs contain initialization code, forward pass arguments (both in a similar format to torchbench), and a list of ops in the forward and backwards passes. These ops are fairly important as they are what we want to point out to the researcher when they are optimizing a model. There is a README.md to help folks setup proper model code / configs. We also further verify these registrations are correct through CI. Specifically we run test/test_model_ops_configs.py to ensure the configs are formatted correctly. ### Small Things - Added a --model-filter to the CLI as it will be needed to support filtering in model suite as it chooses things to test based on the model not set of ops ### Testing New tests are added so pytest resolves things here ### Future work with Model Suite https://github.com/meta-pytorch/BackendBench/issues/181 --- BackendBench/scripts/main.py | 26 +++- BackendBench/suite/__init__.py | 2 + BackendBench/suite/model.py | 112 +++++++++++++++++ test/test_model_ops_configs.py | 221 +++++++++++++++++++++++++++++++++ test/test_model_suite.py | 53 ++++++++ 5 files changed, 412 insertions(+), 2 deletions(-) create mode 100644 BackendBench/suite/model.py create mode 100644 test/test_model_ops_configs.py create mode 100644 test/test_model_suite.py diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 479e5805..2240e2d1 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -19,6 +19,7 @@ from BackendBench.output import save_results from BackendBench.suite import ( FactoTestSuite, + ModelSuite, OpInfoTestSuite, SmokeTestSuite, TorchBenchTestSuite, @@ -50,7 +51,7 @@ def setup_logging(log_level): @click.option( "--suite", default="smoke", - type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]), + type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]), help="Which suite to run", ) @click.option( @@ -63,7 +64,13 @@ def setup_logging(log_level): "--ops", default=None, type=str, - help="Comma-separated list of ops to run", + help="Comma-separated list of ops to run (not supported for model suite)", +) +@click.option( + "--model-filter", + default=None, + type=str, + help="Comma-separated list of models to run (only for model suite)", ) @click.option( "--topn-inputs", @@ -147,6 +154,7 @@ def cli( suite, backend, ops, + model_filter, topn_inputs, llm_attempts, llm_model, @@ -166,9 +174,22 @@ def cli( if check_overhead_dominated_ops: raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite") + if suite == "model": + if ops is not None: + raise ValueError( + "--ops filter is not supported for model suite. Use --model-filter instead" + ) + # remove this in later PR as model suite is supported + raise NotImplementedError("Model suite is not supported yet") + + if suite != "model" and model_filter is not None: + raise ValueError("--model-filter is only supported for model suite") + setup_logging(log_level) if ops: ops = ops.split(",") + if model_filter: + model_filter = model_filter.split(",") suite = { "smoke": lambda: SmokeTestSuite, @@ -191,6 +212,7 @@ def cli( torch.bfloat16, filter=ops, ), + "model": lambda: ModelSuite(filter=model_filter), }[suite]() backend_name = backend diff --git a/BackendBench/suite/__init__.py b/BackendBench/suite/__init__.py index 410a5d6e..e9c332a9 100644 --- a/BackendBench/suite/__init__.py +++ b/BackendBench/suite/__init__.py @@ -15,6 +15,7 @@ from .base import OpTest, Test, TestSuite from .facto import FactoTestSuite +from .model import ModelSuite from .opinfo import OpInfoTestSuite from .smoke import randn, SmokeTestSuite from .torchbench import TorchBenchOpTest, TorchBenchTestSuite @@ -24,6 +25,7 @@ "OpTest", "TestSuite", "FactoTestSuite", + "ModelSuite", "OpInfoTestSuite", "SmokeTestSuite", "randn", diff --git a/BackendBench/suite/model.py b/BackendBench/suite/model.py new file mode 100644 index 00000000..12f98256 --- /dev/null +++ b/BackendBench/suite/model.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Model Suite for testing models defined in configs. +""" + +import importlib.util +import json +import logging +import os +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +def load_models( + models_dir: str = "models", filter: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """Load models using strict naming convention: folder_name/folder_name.py + folder_name.json + + Args: + models_dir: Directory containing models (default: "models") + filter: Optional list of model names to load. If None, loads all models. + + Returns: + List of dictionaries with keys: + - name: Model name (str) + - class: Model class (type) + - config: Configuration dictionary from JSON file + """ + models = [] + + if not os.path.exists(models_dir): + raise FileNotFoundError(f"Models directory not found: {models_dir}") + + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + # Skip if not in filter + if filter is not None and model_name not in filter: + continue + + # Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json + model_file = os.path.join(model_dir, f"{model_name}.py") + config_file = os.path.join(model_dir, f"{model_name}.json") + + # Check both files exist + if not os.path.exists(model_file): + raise FileNotFoundError(f"Model file not found: {model_file}") + + if not os.path.exists(config_file): + raise FileNotFoundError(f"Config file not found: {config_file}") + + try: + # Load config + with open(config_file, "r") as f: + config = json.load(f) + + # Load model class dynamically + spec = importlib.util.spec_from_file_location(model_name, model_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find model class (must match model_name exactly) + if not hasattr(module, model_name): + raise RuntimeError(f"Model class '{model_name}' not found in {model_file}") + + model_class = getattr(module, model_name) + if not (isinstance(model_class, type) and hasattr(model_class, "forward")): + raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class") + + models.append({"name": model_name, "class": model_class, "config": config}) + logger.info(f"Loaded model: {model_name}") + + except Exception as e: + raise RuntimeError(f"Failed to load model {model_name}: {e}") + + if filter is not None and len(models) == 0: + raise ValueError(f"No models found matching filter: {filter}") + + return models + + +class ModelSuite: + """Model Suite for end-to-end model testing.""" + + def __init__( + self, + name: str = "model", + filter: Optional[List[str]] = None, + ): + """Initialize ModelSuite. + + Args: + name: Suite name (default: "model") + filter: Optional list of model names to load + """ + models_dir = os.path.join(os.path.dirname(__file__), "models") + + # Load models + models = load_models(models_dir=models_dir, filter=filter) + logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}") + + # Store loaded models + self.models = models + self.name = name diff --git a/test/test_model_ops_configs.py b/test/test_model_ops_configs.py new file mode 100644 index 00000000..8b3f3dbe --- /dev/null +++ b/test/test_model_ops_configs.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit test to verify that ModelSuite's operator filter correctly matches +the operators defined in model configs. + +This test validates that: +1. load_models correctly loads model configs from the models directory +2. load_model_ops extracts the correct set of operators from model configs +3. TorchBenchTestSuite initialized with those operators has matching optests +4. JSON config files have proper format with required fields +""" + +import json +import os +import unittest +from typing import Any, Dict, List, Set + +from BackendBench.suite.model import load_models +from BackendBench.suite.torchbench import TorchBenchTestSuite + + +def load_model_ops(models: List[Dict[str, Any]]) -> Set[str]: + """Extract unique set of operators from model configs. + + Args: + models: List of model dictionaries with 'name', 'class', and 'config' keys + + Returns: + Set of operator names defined across all model configs + """ + model_ops = set() + for model in models: + config_ops = model["config"].get("ops") + if not config_ops: + raise ValueError(f"Model {model['name']} has no 'ops' field in config") + assert "forward" in config_ops, f"Model {model['name']} has no 'forward' field in config" + assert "backward" in config_ops, f"Model {model['name']} has no 'backward' field in config" + ops_list = config_ops["forward"] + config_ops["backward"] + + model_ops.update(ops_list) + return model_ops + + +class TestModelOpsConfigs(unittest.TestCase): + """Test that model ops filter correctly initializes TorchBenchTestSuite.""" + + def test_model_ops_match_suite_optests(self): + """Test that suite's optests match the operators from model configs.""" + # Get the models directory path (same as ModelSuite does) + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load models using load_models + models = load_models(models_dir=models_dir) + + # Verify we loaded at least one model + self.assertGreater(len(models), 0, "Should load at least one model") + + # Extract operators from model configs using load_model_ops + model_ops = load_model_ops(models) + + # Verify we have operators + self.assertGreater(len(model_ops), 0, "Should have at least one operator") + + # Create filter list from model ops + ops_filter = list(model_ops) + + # Initialize TorchBenchTestSuite with the filter + suite = TorchBenchTestSuite( + name="test_model_ops", + filename=None, # Use default HuggingFace dataset + filter=ops_filter, + topn=None, + ) + + # Get the set of operators in the suite's optests + suite_ops = set(suite.optests.keys()) + + # The suite_ops should be a subset of model_ops because: + # - model_ops is the filter we requested + # - suite_ops contains only those operators that exist in the TorchBench dataset + # - Not all operators in model configs may be in the dataset + self.assertTrue( + suite_ops.issubset(model_ops), + f"Suite operators {suite_ops} should be subset of model operators {model_ops}", + ) + + # Verify that suite actually has some operators + self.assertGreater( + len(suite_ops), 0, "Suite should contain at least one operator from model configs" + ) + + def test_json_configs_have_required_fields(self): + """Test that all JSON config files have proper format with required fields.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Load all models + models = load_models(models_dir=models_dir) + + for model in models: + model_name = model["name"] + config = model["config"] + + # Check required top-level fields + self.assertIn("ops", config, f"Model {model_name}: config must have 'ops' field") + self.assertIn( + "model_tests", config, f"Model {model_name}: config must have 'model_tests' field" + ) + + # Validate 'ops' field - can be list or dict + config_ops = config["ops"] + self.assertGreater( + len(config_ops["forward"] + config_ops["backward"]), + 0, + f"Model {model_name}: 'ops' list must not be empty", + ) + for op in config_ops["forward"] + config_ops["backward"]: + self.assertIsInstance( + op, str, f"Model {model_name}: each op in 'ops' must be a string" + ) + self.assertIsInstance( + config_ops["forward"], + list, + f"Model {model_name}: 'ops.forward' must be a list", + ) + for op in config_ops["forward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.forward' must be a string", + ) + self.assertIsInstance( + config_ops["backward"], + list, + f"Model {model_name}: 'ops.backward' must be a list", + ) + for op in config_ops["backward"]: + self.assertIsInstance( + op, + str, + f"Model {model_name}: each op in 'ops.backward' must be a string", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + + # Validate 'model_tests' field + self.assertIsInstance( + config["model_tests"], + dict, + f"Model {model_name}: 'model_tests' must be a dictionary", + ) + self.assertGreater( + len(config["model_tests"]), + 0, + f"Model {model_name}: 'model_tests' must not be empty", + ) + for test_name, test_args in config["model_tests"].items(): + self.assertIsInstance( + test_name, str, f"Model {model_name}: test names must be strings" + ) + self.assertIsInstance( + test_args, str, f"Model {model_name}: test args must be strings" + ) + + # Check optional but recommended fields + if "model_config" in config: + self.assertIsInstance( + config["model_config"], + dict, + f"Model {model_name}: 'model_config' must be a dictionary if present", + ) + + def test_json_files_are_valid_json(self): + """Test that all JSON config files are valid JSON and can be parsed.""" + models_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "BackendBench", "suite", "models" + ) + + # Find all JSON files in the models directory + for model_name in os.listdir(models_dir): + model_dir = os.path.join(models_dir, model_name) + if not os.path.isdir(model_dir): + continue + + json_file = os.path.join(model_dir, f"{model_name}.json") + if not os.path.exists(json_file): + continue + + # Try to parse the JSON file + with open(json_file, "r") as f: + try: + config = json.load(f) + self.assertIsInstance( + config, + dict, + f"JSON file {json_file} must contain a dictionary at top level", + ) + except json.JSONDecodeError as e: + self.fail(f"JSON file {json_file} is not valid JSON: {e}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_model_suite.py b/test/test_model_suite.py new file mode 100644 index 00000000..12ddaf1b --- /dev/null +++ b/test/test_model_suite.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for Model Suite: Filtered TorchBench operators from model tracing + +This test suite validates: +1. Model loading from toy_models directory +2. Operator extraction via model tracing +3. ModelSuite creates filtered TorchBench suite +""" + +import logging +import unittest + +from BackendBench.suite.model import load_models + +# Setup logging +logging.basicConfig(level=logging.WARNING) + + +class TestModelLoading(unittest.TestCase): + """Test model loading functionality.""" + + def test_load_models(self): + """Test that models can be loaded from directory.""" + models = load_models(models_dir="BackendBench/suite/models") + self.assertGreater(len(models), 0, "Should load at least one model") + + # Verify model structure + for model in models: + self.assertIn("name", model) + self.assertIn("class", model) + self.assertIn("config", model) + + def test_load_specific_model(self): + """Test loading a specific model by name.""" + models = load_models(models_dir="BackendBench/suite/models", filter=["ToyCoreOpsModel"]) + self.assertEqual(len(models), 1) + self.assertEqual(models[0]["name"], "ToyCoreOpsModel") + + def test_invalid_filter(self): + """Test that invalid filter raises error.""" + with self.assertRaises(ValueError): + load_models(models_dir="BackendBench/suite/models", filter=["nonexistent"]) + + +if __name__ == "__main__": + # Run tests + unittest.main(verbosity=2)