
08. Python API Reference

Yueming Hao edited this page Dec 19, 2025 · 1 revision


This page provides complete API documentation for TritonParse Python modules.

💡 Looking for environment variables? See the Environment Variables Reference for all configuration options. Environment variables affect the tritonparse.structured_logging module and are applied automatically when you call init().

📋 Module Overview

| Module | Purpose |
|--------|---------|
| `tritonparse.structured_logging` | Initialize trace collection |
| `tritonparse.utils` | Parse and process trace logs |
| `tritonparse.context_manager` | Simplified workflow with context manager |
| `tritonparse.reproducer.orchestrator` | Generate standalone reproducer scripts |
| `tritonparse.info` | Query kernel information from traces |

tritonparse.structured_logging

Core module for initializing trace collection during Triton kernel compilation and execution.

init()

Main initialization function with full control over tracing behavior.

```python
def init(
    trace_folder: Optional[str] = None,
    enable_trace_launch: bool = False,
    enable_more_tensor_information: bool = False,
    enable_sass_dump: Optional[bool] = False,
    enable_tensor_blob_storage: bool = False,
    tensor_storage_quota: Optional[int] = None,
) -> None
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `trace_folder` | `Optional[str]` | `None` | Directory for storing trace files. If `None`, uses the `TRITON_TRACE` environment variable or `/logs/`. |
| `enable_trace_launch` | `bool` | `False` | Enable launch event tracing to capture runtime parameters. |
| `enable_more_tensor_information` | `bool` | `False` | Collect tensor statistics (min, max, mean, std). |
| `enable_sass_dump` | `Optional[bool]` | `False` | Enable NVIDIA SASS assembly dump. **Warning:** slows compilation. |
| `enable_tensor_blob_storage` | `bool` | `False` | Save actual tensor data as blob files. |
| `tensor_storage_quota` | `Optional[int]` | `None` (effectively 100 GB) | Maximum total storage for tensor blobs, in bytes. |

Example - Basic initialization:

```python
import tritonparse.structured_logging

# Simple initialization
tritonparse.structured_logging.init("./logs/")
```

Example - Full initialization with all options:

```python
import tritonparse.structured_logging

tritonparse.structured_logging.init(
    trace_folder="./logs/",
    enable_trace_launch=True,
    enable_more_tensor_information=True,
    enable_tensor_blob_storage=True,
    tensor_storage_quota=10 * 1024 * 1024 * 1024,  # 10GB
)
```

Example - For torch.compile kernels:

```python
import os

# Required for TorchInductor kernel launch tracing
os.environ["TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK"] = "1"

import tritonparse.structured_logging

tritonparse.structured_logging.init(
    trace_folder="./logs/",
    enable_trace_launch=True,
    enable_more_tensor_information=True,
)
```

💡 Note: Arguments passed to init() take precedence over environment variables.


init_basic()

Minimal initialization that sets up the logging handler without registering the compilation listener.

```python
def init_basic(trace_folder: Optional[str] = None) -> None
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `trace_folder` | `Optional[str]` | `None` | Directory for storing trace files. |

When to use:

  • When you only need the logging infrastructure without automatic compilation tracing
  • For advanced use cases where you manually control event emission

Example:

```python
import tritonparse.structured_logging

tritonparse.structured_logging.init_basic("./logs/")
```

init_with_env()

Initialize TritonParse using environment variables.

```python
def init_with_env() -> None
```

Environment variables used:

  • TRITON_TRACE - Trace output directory
  • TRITON_TRACE_LAUNCH - Enable launch tracing ("1" to enable)

Example:

```python
import os

os.environ["TRITON_TRACE"] = "./logs/"
os.environ["TRITON_TRACE_LAUNCH"] = "1"

import tritonparse.structured_logging

tritonparse.structured_logging.init_with_env()
```

Example - Shell-based configuration:

```bash
export TRITON_TRACE="./logs/"
export TRITON_TRACE_LAUNCH="1"
python your_script.py
```

```python
# In your_script.py
import tritonparse.structured_logging

tritonparse.structured_logging.init_with_env()
```

clear_logging_config()

Reset all logging configurations to their default state.

```python
def clear_logging_config() -> None
```

What it resets:

  • Removes trace handler from logger
  • Clears global state variables
  • Resets Triton knobs (compilation listener, launch hooks)
  • Clears tensor blob manager

When to use:

  • Between test cases to ensure isolation
  • When you need to reinitialize with different settings

⚠️ Warning: Use with caution. This function is primarily intended for testing.

Example:

```python
import tritonparse.structured_logging

# Initialize
tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)

# ... run kernels ...

# Reset configuration
tritonparse.structured_logging.clear_logging_config()

# Can reinitialize with different settings
tritonparse.structured_logging.init("./other_logs/")
```

tritonparse.utils

Utilities for parsing and processing trace logs.

unified_parse()

Main function for parsing raw trace logs into structured format.

```python
def unified_parse(
    source: str,
    out: Optional[str] = None,
    overwrite: Optional[bool] = False,
    rank: Optional[int] = None,
    all_ranks: bool = False,
    verbose: bool = False,
    split_inductor_compilations: bool = True,
    skip_logger: bool = False,
    **kwargs,
) -> Optional[str]
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `source` | `str` | *required* | Input directory containing raw log files, or path to a single log file. |
| `out` | `Optional[str]` | `None` | Output directory for parsed files. If `None`, uses a temporary directory. |
| `overwrite` | `Optional[bool]` | `False` | Delete the existing output directory if it exists. |
| `rank` | `Optional[int]` | `None` | Specific rank to analyze (for multi-GPU traces). |
| `all_ranks` | `bool` | `False` | Analyze all ranks found in the logs. |
| `verbose` | `bool` | `False` | Enable verbose logging output. |
| `split_inductor_compilations` | `bool` | `True` | Split output by frame_id, compile_id, attempt_id, and compiled_autograd_id. |
| `skip_logger` | `bool` | `False` | Skip usage logging (internal use). |

Returns:

  • Optional[str]: Path to output directory, or URL in fbcode environments.

Example - Basic parsing:

```python
import tritonparse.utils

tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
)
```

Example - Full options:

```python
import tritonparse.utils

tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
    verbose=True,
    all_ranks=True,
)
```

Example - Parse specific rank:

```python
import tritonparse.utils

# Parse only rank 0
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    rank=0,
)
```

Example - Complete workflow:

```python
import tritonparse.structured_logging
import tritonparse.utils

# 1. Initialize logging
tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)

# 2. Run your kernels
# ... your kernel code ...

# 3. Parse the logs
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
)

# Output files are in ./parsed_output/
```

tritonparse.context_manager

Context manager for simplified trace-and-parse workflow.

TritonParseManager

A context manager that automatically handles initialization, trace collection, and parsing.

```python
class TritonParseManager:
    def __init__(
        self,
        enable_trace_launch: bool = False,
        split_inductor_compilations: bool = True,
        enable_tensor_blob_storage: bool = False,
        tensor_storage_quota: Optional[int] = None,
        **parse_kwargs,
    )
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_trace_launch` | `bool` | `False` | Enable launch event tracing. |
| `split_inductor_compilations` | `bool` | `True` | Split output by compilation IDs. |
| `enable_tensor_blob_storage` | `bool` | `False` | Save tensor blob data. |
| `tensor_storage_quota` | `Optional[int]` | `None` | Storage quota for tensor blobs. |
| `**parse_kwargs` | — | — | Additional arguments passed to `unified_parse()`. |

Attributes:

| Attribute | Type | Description |
|-----------|------|-------------|
| `dir_path` | `str` | Temporary directory for raw logs (available after `__enter__`). |
| `output_link` | `Optional[str]` | Path to parsed output (available after `__exit__`). |

Example - Basic usage:

```python
from tritonparse.context_manager import TritonParseManager

with TritonParseManager(enable_trace_launch=True) as manager:
    # Your kernel code here
    result = my_kernel(input_tensor)

# Logs are automatically parsed on exit
print(f"Parsed output: {manager.output_link}")
```

Example - With output directory:

```python
from tritonparse.context_manager import TritonParseManager

with TritonParseManager(
    enable_trace_launch=True,
    out="./parsed_output",
    overwrite=True,
) as manager:
    # Run kernels
    output = compiled_function(a, b)

print(f"Output saved to: {manager.output_link}")
```

Example - With tensor blob storage:

```python
from tritonparse.context_manager import TritonParseManager

with TritonParseManager(
    enable_trace_launch=True,
    enable_tensor_blob_storage=True,
    tensor_storage_quota=5 * 1024 * 1024 * 1024,  # 5GB
    out="./parsed_output",
) as manager:
    # Your kernels with tensor data preserved
    result = kernel(input_data)
```

Workflow:

  1. On __enter__: Creates temporary directory and calls init()
  2. Inside with block: Your code runs with tracing enabled
  3. On __exit__: Calls unified_parse() and clear_logging_config()

tritonparse.reproducer.orchestrator

Module for generating standalone Python scripts that reproduce kernel executions.

💡 See Reproducer Guide for comprehensive documentation including workflow overview, tensor reconstruction strategies, custom templates, and troubleshooting.

reproduce()

Generate a standalone reproducer script from a trace file.

```python
def reproduce(
    input_path: str,
    line_index: int = 0,
    out_dir: Optional[str] = None,
    template: str = "example",
    kernel_name: Optional[str] = None,
    launch_id: int = 0,
    kernel_import: Optional[str] = None,
    replacer=None,
    skip_logger: bool = False,
) -> Dict[str, str]
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `input_path` | `str` | *required* | Path to trace file (`.ndjson` or `.ndjson.gz`). |
| `line_index` | `int` | `0` | 0-based line index of the event to reproduce. Line 0 is typically compilation; lines 1+ are launches. |
| `out_dir` | `Optional[str]` | `None` | Output directory. If `None`, creates `repro_output/<kernel_name>/`. |
| `template` | `str` | `"example"` | Template name (`"example"`, `"tritonbench"`) or path to a custom template file. |
| `kernel_name` | `Optional[str]` | `None` | Target kernel name. If specified, finds the kernel by name instead of line index. |
| `launch_id` | `int` | `0` | Launch instance ID when using `kernel_name`. |
| `kernel_import` | `Optional[str]` | `None` | Custom import statement for the kernel. |
| `replacer` | — | `None` | Custom placeholder replacer (advanced use). |
| `skip_logger` | `bool` | `False` | Skip usage logging. |

Returns:

```python
{
    "repro_script": "/path/to/repro_<timestamp>.py",
    "repro_context": "/path/to/repro_context_<timestamp>.json"
}
```
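The line_index parameter is simply a 0-based line position in the ndjson trace. If you want to find that index by hand before calling reproduce() (for example, to pick one particular launch), a plain standard-library scan is enough. This is a sketch, not part of tritonparse; the predicate, and any event fields it inspects (such as "event_type" below), are assumptions about your trace's schema:

```python
import gzip
import json


def find_line_index(path, predicate):
    """Return the 0-based line index of the first event matching predicate.

    Works for both .ndjson and .ndjson.gz traces. The predicate receives
    the decoded JSON object for each line; which fields it inspects is up
    to you.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as f:
        for i, raw in enumerate(f):
            raw = raw.strip()
            if raw and predicate(json.loads(raw)):
                return i
    return None
```

The returned index can then be passed as `line_index=` to `reproduce()`.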

Example - Basic reproducer:

```python
from tritonparse.reproducer.orchestrator import reproduce

result = reproduce(
    input_path="./parsed_output/trace.ndjson.gz",
    line_index=1,  # First launch event (0 is compilation)
    out_dir="./repro_output",
)

print(f"Reproducer script: {result['repro_script']}")
print(f"Context file: {result['repro_context']}")
```

Example - Using kernel name:

```python
from tritonparse.reproducer.orchestrator import reproduce

result = reproduce(
    input_path="./trace.ndjson.gz",
    kernel_name="matmul_kernel",
    launch_id=0,  # First launch of this kernel
    out_dir="./repro_output",
)
```

Example - TritonBench template:

```python
from tritonparse.reproducer.orchestrator import reproduce

result = reproduce(
    input_path="./trace.ndjson.gz",
    line_index=1,
    template="tritonbench",
    out_dir="./benchmark_repro",
)
```

Example - Custom template:

```python
from tritonparse.reproducer.orchestrator import reproduce

result = reproduce(
    input_path="./trace.ndjson.gz",
    line_index=1,
    template="/path/to/my_template.py",
    out_dir="./custom_repro",
)
```

Available Templates:

| Template | Description |
|----------|-------------|
| `example` | Basic standalone reproducer with tensor reconstruction |
| `tritonbench` | TritonBench-compatible benchmark operator |

Custom Template Placeholders:

| Placeholder | Description |
|-------------|-------------|
| `{{KERNEL_IMPORT_PLACEHOLDER}}` | Kernel import statements |
| `{{KERNEL_INVOCATION_PLACEHOLDER}}` | Kernel launch code |
| `{{KERNEL_SYSPATH_PLACEHOLDER}}` | System path setup |
| `{{JSON_FILE_NAME_PLACEHOLDER}}` | Context JSON filename |
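A minimal custom template might look like the following sketch. Only the four placeholders are substituted by `reproduce()`; everything else here (the file name, the `main()` wrapper, loading the context JSON) is illustrative structure, not a requirement of tritonparse:

```
# my_template.py - hypothetical custom reproducer template.
{{KERNEL_SYSPATH_PLACEHOLDER}}
{{KERNEL_IMPORT_PLACEHOLDER}}

import json


def main():
    # Load the launch context written alongside the generated script
    with open("{{JSON_FILE_NAME_PLACEHOLDER}}") as f:
        context = json.load(f)

    {{KERNEL_INVOCATION_PLACEHOLDER}}


if __name__ == "__main__":
    main()
```

Pass its path via `template="/path/to/my_template.py"` as in the custom-template example above.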

tritonparse.info

Module for querying kernel information from trace files.

CLI: info command

Query kernel information from the command line.

```bash
tritonparseoss info <input_file> [options]
```

Arguments:

| Argument | Description |
|----------|-------------|
| `<input_file>` | Path to trace file (`.ndjson` or `.ndjson.gz`) |

Options:

| Option | Description |
|--------|-------------|
| `--kernel <name>` | Query a specific kernel by name |
| `--args-list` | Show detailed argument list |

Examples:

```bash
# List all kernels in trace
tritonparseoss info ./trace.ndjson.gz

# Query specific kernel
tritonparseoss info ./trace.ndjson.gz --kernel matmul_kernel

# Show argument details
tritonparseoss info ./trace.ndjson.gz --kernel matmul_kernel --args-list
```

Python API: info_command()

```python
from tritonparse.info.cli import info_command
```

```python
def info_command(
    input_path: str,
    kernel_name: Optional[str] = None,
    skip_logger: bool = False,
    args_list: bool = False,
)
```

Parameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `input_path` | `str` | *required* | Path to trace file. |
| `kernel_name` | `Optional[str]` | `None` | Specific kernel to query. |
| `skip_logger` | `bool` | `False` | Skip usage logging. |
| `args_list` | `bool` | `False` | Show argument list. |

🎯 Quick Reference

Common Patterns

Pattern 1: Simple Tracing

```python
import tritonparse.structured_logging
import tritonparse.utils

tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)
# ... run kernels ...
tritonparse.utils.unified_parse("./logs/", out="./parsed_output")
```

Pattern 2: Context Manager

```python
from tritonparse.context_manager import TritonParseManager

with TritonParseManager(enable_trace_launch=True, out="./output") as m:
    # ... run kernels ...

print(m.output_link)
```

Pattern 3: Generate Reproducer

```python
from tritonparse.reproducer.orchestrator import reproduce

result = reproduce("./trace.ndjson.gz", line_index=1, out_dir="./repro")
```

Module Import Summary

```python
# Initialization
import tritonparse.structured_logging
tritonparse.structured_logging.init(...)
tritonparse.structured_logging.init_with_env()
tritonparse.structured_logging.clear_logging_config()

# Parsing
import tritonparse.utils
tritonparse.utils.unified_parse(...)

# Context Manager
from tritonparse.context_manager import TritonParseManager

# Reproducer
from tritonparse.reproducer.orchestrator import reproduce

# Info (CLI primarily)
# tritonparseoss info <file> [options]
```
