diff --git a/tritonparse/ir_analysis.py b/tritonparse/ir_analysis.py new file mode 100644 index 0000000..20addd3 --- /dev/null +++ b/tritonparse/ir_analysis.py @@ -0,0 +1,72 @@ +import logging + +from .sourcemap_utils import load_ir_contents + + +logger = logging.getLogger("IRAnalysis") + + +def process_amd_bufferop(ir_content: str, io_keys: list[str]) -> dict[str, int]: + def make_key(prefix: str) -> str: + return f"{prefix}_count" + + io_keys = [(make_key(prefix), prefix) for prefix in io_keys] + output: dict[str, int] = {} + for dict_key, _ in io_keys: + output[dict_key] = 0 + if ir_content: + for line in ir_content.split("\n"): + for dict_key, code_key in io_keys: + if code_key in line: + output[dict_key] += 1 + return output + + +def process_amd_ttgir_bufferops( + key: str, + file_content: dict[str, str], + file_path: dict[str, str], +) -> dict[str, int]: + ir_content = load_ir_contents(key, file_content, file_path) + # TODO: Add atomics + io_keys = ["tt.load", "tt.store", "amdgpu.buffer_load", "amdgpu.buffer_store"] + return process_amd_bufferop(ir_content, io_keys) + + +def process_amd_gcn_bufferops( + key: str, + file_content: dict[str, str], + file_path: dict[str, str], +) -> dict[str, int]: + ir_content = load_ir_contents(key, file_content, file_path) + # TODO: Add atomics + io_keys = ["global_load_", "global_store_", "buffer_load_", "buffer_store_"] + return process_amd_bufferop(ir_content, io_keys) + + +def _generate_ir_analysis(entry: dict): + payload = entry.setdefault("payload", {}) + file_content = payload.get("file_content", {}) + file_path = payload.get("file_path", {}) + + # Find the IR file keys + ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) + amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) + # Skip if no IR files found + if not (ttgir_key or amdgcn_key): + logger.debug("No AMD IR found") + return {} + ir_analysis = {} + if amdgcn_key: + ttgir_bufferops_info = process_amd_ttgir_bufferops( 
ttgir_key, file_content, file_path + ) + gcn_bufferops_info = process_amd_gcn_bufferops( + amdgcn_key, file_content, file_path + ) + # Only record analyses that produced counts + if ttgir_bufferops_info: + ir_analysis["amd_ttgir_bufferops_count"] = ttgir_bufferops_info + if gcn_bufferops_info: + ir_analysis["amd_gcn_bufferops_count"] = gcn_bufferops_info + return {"ir_analysis": ir_analysis} diff --git a/tritonparse/sourcemap_utils.py b/tritonparse/sourcemap_utils.py index 52c94f7..e2e9b75 100644 --- a/tritonparse/sourcemap_utils.py +++ b/tritonparse/sourcemap_utils.py @@ -1,5 +1,8 @@ +import logging from typing import Any, Dict, List +logger = logging.getLogger("SourceMapping") + def get_file_extension(filename: str) -> str: """ @@ -70,3 +73,22 @@ def _to_ranges(indices: List[int]) -> List[Dict[str, int]]: ranges.append({"start": start, "end": end}) return ranges + + +def load_ir_contents( + key: str, + file_content: dict[str, str], + file_path: dict[str, str], +) -> str: + if not key: + return "" + logger.debug(f"Processing {key}") + ir_content = file_content.get(key, None) + if not ir_content: + ir_file_path = file_path.get(key, None) + if not ir_file_path: + logger.warning(f"No content found for {key}") + return "" + with open(ir_file_path, "r") as f: + ir_content = f.read() + return ir_content diff --git a/tritonparse/trace_processor.py b/tritonparse/trace_processor.py index 92b25b2..b321e6b 100644 --- a/tritonparse/trace_processor.py +++ b/tritonparse/trace_processor.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List from .event_diff import _generate_launch_diff +from .ir_analysis import _generate_ir_analysis from .ir_parser import ( extract_code_locations, @@ -13,7 +14,7 @@ extract_ptx_amdgcn_mappings, ) from .mapper import create_bidirectional_mapping, create_python_mapping -from .sourcemap_utils import get_file_extension +from .sourcemap_utils import get_file_extension, load_ir_contents logger = logging.getLogger("SourceMapping") @@ -84,19 +85,9 
@@ def process_ir( file_path: Dict[str, str], other_mappings: List[Any] = None, ): - # Generate source mappings for each IR type - # the key should be the full file name with extension for the IR files - if not key: - return {} - logger.debug(f"Processing {key}") - ir_content = file_content.get(key, None) + ir_content = load_ir_contents(key, file_content, file_path) if not ir_content: - ir_file_path = file_path.get(key, None) - if not ir_file_path: - logger.warning(f"No content found for {key}") - return {} - with open(ir_file_path, "r") as f: - ir_content = f.read() + return {} mapping = generate_source_mappings(ir_content, key.split(".")[1], other_mappings) logger.debug(f"Generated source mapping for {key}") return mapping @@ -307,6 +298,13 @@ def parse_single_file( json.dumps(launch_event, separators=(",", ":")) + "\n" ) + if compilation_event: + ir_analysis_event = _generate_ir_analysis(compilation_event) + if ir_analysis_event: + all_output_lines[output_file].append( + json.dumps(ir_analysis_event, separators=(",", ":")) + "\n" + ) + if compilation_event and launches_with_indices: sames, diffs, launch_index_map = _generate_launch_diff( launches_with_indices