Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions BackendBench/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from BackendBench.output import save_results
from BackendBench.suite import (
FactoTestSuite,
ModelSuite,
OpInfoTestSuite,
SmokeTestSuite,
TorchBenchTestSuite,
Expand Down Expand Up @@ -50,7 +51,7 @@ def setup_logging(log_level):
@click.option(
"--suite",
default="smoke",
type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]),
type=click.Choice(["smoke", "opinfo", "torchbench", "facto", "model"]),
help="Which suite to run",
)
@click.option(
Expand All @@ -63,7 +64,13 @@ def setup_logging(log_level):
"--ops",
default=None,
type=str,
help="Comma-separated list of ops to run",
help="Comma-separated list of ops to run (not supported for model suite)",
)
@click.option(
"--model-filter",
default=None,
type=str,
help="Comma-separated list of models to run (only for model suite)",
)
@click.option(
"--topn-inputs",
Expand Down Expand Up @@ -147,6 +154,7 @@ def cli(
suite,
backend,
ops,
model_filter,
topn_inputs,
llm_attempts,
llm_model,
Expand All @@ -166,9 +174,22 @@ def cli(
if check_overhead_dominated_ops:
raise ValueError("check-overhead-dominated-ops is only supported for torchbench suite")

if suite == "model":
if ops is not None:
raise ValueError(
"--ops filter is not supported for model suite. Use --model-filter instead"
)
# remove this in later PR as model suite is supported
raise NotImplementedError("Model suite is not supported yet")

if suite != "model" and model_filter is not None:
raise ValueError("--model-filter is only supported for model suite")

setup_logging(log_level)
if ops:
ops = ops.split(",")
if model_filter:
model_filter = model_filter.split(",")

suite = {
"smoke": lambda: SmokeTestSuite,
Expand All @@ -191,6 +212,7 @@ def cli(
torch.bfloat16,
filter=ops,
),
"model": lambda: ModelSuite(filter=model_filter),
}[suite]()

backend_name = backend
Expand Down
2 changes: 2 additions & 0 deletions BackendBench/suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .base import OpTest, Test, TestSuite
from .facto import FactoTestSuite
from .model import ModelSuite
from .opinfo import OpInfoTestSuite
from .smoke import randn, SmokeTestSuite
from .torchbench import TorchBenchOpTest, TorchBenchTestSuite
Expand All @@ -24,6 +25,7 @@
"OpTest",
"TestSuite",
"FactoTestSuite",
"ModelSuite",
"OpInfoTestSuite",
"SmokeTestSuite",
"randn",
Expand Down
112 changes: 112 additions & 0 deletions BackendBench/suite/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Model Suite for testing models defined in configs.
"""

import importlib.util
import json
import logging
import os
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


def load_models(
models_dir: str = "models", filter: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Load models using strict naming convention: folder_name/folder_name.py + folder_name.json

Args:
models_dir: Directory containing models (default: "models")
filter: Optional list of model names to load. If None, loads all models.

Returns:
List of dictionaries with keys:
- name: Model name (str)
- class: Model class (type)
- config: Configuration dictionary from JSON file
"""
models = []

if not os.path.exists(models_dir):
raise FileNotFoundError(f"Models directory not found: {models_dir}")

for model_name in os.listdir(models_dir):
model_dir = os.path.join(models_dir, model_name)
if not os.path.isdir(model_dir):
continue

# Skip if not in filter
if filter is not None and model_name not in filter:
continue

# Strict naming convention: folder_name/folder_name.py and folder_name/folder_name.json
model_file = os.path.join(model_dir, f"{model_name}.py")
config_file = os.path.join(model_dir, f"{model_name}.json")

# Check both files exist
if not os.path.exists(model_file):
raise FileNotFoundError(f"Model file not found: {model_file}")

if not os.path.exists(config_file):
raise FileNotFoundError(f"Config file not found: {config_file}")

try:
# Load config
with open(config_file, "r") as f:
config = json.load(f)

# Load model class dynamically
spec = importlib.util.spec_from_file_location(model_name, model_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find model class (must match model_name exactly)
if not hasattr(module, model_name):
raise RuntimeError(f"Model class '{model_name}' not found in {model_file}")

model_class = getattr(module, model_name)
if not (isinstance(model_class, type) and hasattr(model_class, "forward")):
raise RuntimeError(f"'{model_name}' in {model_file} is not a valid model class")

models.append({"name": model_name, "class": model_class, "config": config})
logger.info(f"Loaded model: {model_name}")

except Exception as e:
raise RuntimeError(f"Failed to load model {model_name}: {e}")

if filter is not None and len(models) == 0:
raise ValueError(f"No models found matching filter: {filter}")

return models


class ModelSuite:
"""Model Suite for end-to-end model testing."""

def __init__(
self,
name: str = "model",
filter: Optional[List[str]] = None,
):
"""Initialize ModelSuite.

Args:
name: Suite name (default: "model")
filter: Optional list of model names to load
"""
models_dir = os.path.join(os.path.dirname(__file__), "models")

# Load models
models = load_models(models_dir=models_dir, filter=filter)
logger.info(f"ModelSuite: Loaded {len(models)} models from {models_dir}")

# Store loaded models
self.models = models
self.name = name
80 changes: 80 additions & 0 deletions BackendBench/suite/models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Adding Models to BackendBench

## Quick Start

Models define operator lists and validate that custom backends work correctly in full model execution. Two files required:

```
BackendBench/suite/models/YourModel/
├── YourModel.py # nn.Module class
└── YourModel.json # Configuration
```

**Naming rule:** Directory name = File name = Class name (exact match, case-sensitive)

## Adding a Model

### 1. Create Directory and Files

```bash
cd BackendBench/suite/models
mkdir MyModel
cd MyModel
touch MyModel.py MyModel.json
```

### 2. Write Model Class (`MyModel.py`)

**Requirements:**
- Class name = filename (exact match)
- All `__init__` params need defaults
- Add a main() / runner if you are inclined for sanity checking
- Keep it simple - focus on specific operators you're testing
- Look in this directory for examples

### 3. Write Config (`MyModel.json`)

**Key Fields:**
- `model_config.init_args` - Args for `__init__()`, must match your defaults
- `ops.forward` / `ops.backward` - Aten operators to test (format: `"aten.<op>.default"`)
- `model_tests` - Test inputs as `"([], {kwarg: T([shape], dtype)})"` The format is further described [here](https://huggingface.co/datasets/GPUMODE/backendbench_tests#serialized-arguments-in-backendbench)
- Supported dtypes: `f32`, `f64`, `i32`, `i64`, `bool`, etc.
- `metadata.description` - What this model tests
- Look in this directory for examples

**Finding operator names:**
```python
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
output = model(x)
loss = output.sum()
loss.backward()

for event in prof.key_averages():
if "aten::" in event.key:
print(event.key)
```

### 4. Test Your Model

```bash
# Test standalone
cd BackendBench/suite/models/MyModel
python MyModel.py # Add main() for standalone testing

# Test with suite
python -m BackendBench.scripts.main \
--suite model \
--backend aten \
--model-filter MyModel

# Expected output:
# Model: MyModel
# Status: ✓ Passed (2/2 tests)
# ✓ small
# ✓ large
```

### 5: Validation
`test/test_model_ops_configs.py` and `test/test_model_ops_coverage.py` are tests that validate that all models are loadable / formatted correctly.
25 changes: 25 additions & 0 deletions BackendBench/suite/models/SmokeTestModel/SmokeTestModel.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"model_config": {
"init_args": {
"input_dim": 128,
"hidden_dim": 128,
"output_dim": 128
}
},
"ops": {
"forward": [
"aten.mm.default"
],
"backward": [
"aten.mm.default"
]
},
"model_tests": {
"small_batch": "([], {'x': T([2, 128], f32)})",
"medium_batch": "([], {'x': T([16, 128], f32)})",
"large_batch": "([], {'x': T([32, 128], f32)})"
},
"metadata": {
"description": "Smoke test model focused on matrix multiplication operations (mm) in forward and backward passes"
}
}
68 changes: 68 additions & 0 deletions BackendBench/suite/models/SmokeTestModel/SmokeTestModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Simple model that tests matrix multiplication operations using explicit
torch.mm calls.
"""

import torch
import torch.nn as nn


class SmokeTestModel(nn.Module):
"""
Model that uses explicit torch.mm operations to test aten.mm.default
in forward/backward.
"""

def __init__(
self,
input_dim: int = 128,
hidden_dim: int = 128,
output_dim: int = 128,
):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim

self.weight1 = nn.Parameter(torch.randn(input_dim, hidden_dim))
self.weight2 = nn.Parameter(torch.randn(hidden_dim, output_dim))
self.bias1 = nn.Parameter(torch.randn(hidden_dim))
self.bias2 = nn.Parameter(torch.randn(output_dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass: (x @ weight1 + bias1) -> relu -> (x @ weight2 + bias2)
"""
x = torch.mm(x, self.weight1) + self.bias1
x = torch.relu(x)
x = torch.mm(x, self.weight2) + self.bias2
return x


def main():
"""Demonstrate the model with a forward/backward pass."""
model = SmokeTestModel(input_dim=128, hidden_dim=128, output_dim=128)
batch_size = 4
input_tensor = torch.randn(batch_size, 128, requires_grad=True)

model.train()
output = model(input_tensor)
loss = output.sum()
loss.backward()

print("✓ Forward/backward pass completed")
print(f" Parameters: {sum(p.numel() for p in model.parameters())}")
print(f" Input: {input_tensor.shape} -> Output: {output.shape}")
grad_count = sum(1 for p in model.parameters() if p.grad is not None)
total_params = len(list(model.parameters()))
print(f" Gradients computed: {grad_count}/{total_params}")


if __name__ == "__main__":
main()
Loading