diff --git a/tritonparse/reproducer/function_extractor.py b/tritonparse/reproducer/function_extractor.py new file mode 100644 index 0000000..bfc0c63 --- /dev/null +++ b/tritonparse/reproducer/function_extractor.py @@ -0,0 +1,220 @@ +""" +Function extractor for reproducer utility functions. + +This module extracts utility functions from utils.py and load_tensor.py +using AST parsing, and generates standalone code for reproducers. +""" + +import ast +from pathlib import Path + + +def extract_utility_functions() -> str: + """ + Extract all utility functions needed for the reproducer template. + + Uses AST parsing to extract functions and constants from source files + without importing them (avoiding potential side effects). + + Returns: + str: Complete Python code including imports and all utility functions. + """ + # Prepare file paths + base_dir = Path(__file__).parent + utils_path = base_dir / "utils.py" + load_tensor_path = base_dir.parent / "tools" / "load_tensor.py" + + # Parse source files + utils_tree, utils_lines = _parse_source_file(utils_path) + load_tensor_tree, load_tensor_lines = _parse_source_file(load_tensor_path) + + # Define what to extract (in dependency order) + utils_function_names = [ + "_get_triton_tensor_types", + "create_args_from_json_file", + "create_args_from_json", + "_apply_stride_and_offset", + "_create_base_tensor", + "_create_tensor", + "_create_arg_from_info", + ] + + load_tensor_function_names = [ + "load_tensor", + ] + + # Extract content + extracted_parts = [] + + # Add required imports + extracted_parts.append(_generate_imports()) + + # Extract constant + constant = _extract_assignment( + utils_tree, utils_lines, "TRITON_KERNELS_CUSTOM_TYPES" + ) + if constant: + extracted_parts.append(constant) + + # Extract load_tensor functions + extracted_parts.extend( + _extract_functions( + load_tensor_tree, load_tensor_lines, load_tensor_function_names + ) + ) + + # Extract utils functions + extracted_parts.extend( + _extract_functions(utils_tree, utils_lines, utils_function_names) + ) + + # Combine all parts + return "\n\n".join(extracted_parts) + + +def _parse_source_file(file_path: Path) -> tuple[ast.Module, list[str]]: + """ + Parse a Python source file and return its AST and source lines. + + Args: + file_path: Path to the Python source file + + Returns: + tuple: (AST tree, list of source code lines) + + Raises: + FileNotFoundError: If the source file doesn't exist + SyntaxError: If the source file has syntax errors + """ + try: + source_code = file_path.read_text(encoding="utf-8") + tree = ast.parse(source_code, filename=str(file_path)) + except FileNotFoundError as e: + raise FileNotFoundError(f"Source file not found: {file_path}") from e + except SyntaxError as e: + raise SyntaxError(f"Failed to parse {file_path}: {e}") from e + + lines = source_code.splitlines() + return tree, lines + + +def _extract_assignment( + tree: ast.Module, lines: list[str], var_name: str +) -> str | None: + """ + Extract a module-level assignment statement by variable name. + + Args: + tree: AST tree of the source file + lines: Source code lines + var_name: Name of the variable to extract + + Returns: + Complete assignment statement source code, or None if not found + + Example: + Extracts: + TRITON_KERNELS_CUSTOM_TYPES = ( + importlib.util.find_spec("triton_kernels") is not None + and importlib.util.find_spec("triton_kernels.tensor") is not None + ) + """ + # Search only at module level + for node in tree.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == var_name: + # Found it! Extract source code using line numbers + start_line = node.lineno - 1 # Convert to 0-based index + end_line = node.end_lineno # Already suitable for slicing + assignment_lines = lines[start_line:end_line] + return "\n".join(assignment_lines) + return None + + +def _extract_function(tree: ast.Module, lines: list[str], func_name: str) -> str | None: + """ + Extract a function definition by name, including decorators. + + Args: + tree: AST tree of the source file + lines: Source code lines + func_name: Name of the function to extract + + Returns: + Complete function source code including decorators, or None if not found + + Example: + Extracts: + @lru_cache(maxsize=1) + def _get_triton_tensor_types(): + '''Docstring''' + ... + """ + # Walk the entire tree (handles nested functions if needed) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + # If function has decorators, start from the first decorator + if node.decorator_list: + start_line = node.decorator_list[0].lineno - 1 + else: + start_line = node.lineno - 1 + + end_line = node.end_lineno + func_lines = lines[start_line:end_line] + return "\n".join(func_lines) + return None + + +def _extract_functions( + tree: ast.Module, lines: list[str], func_names: list[str] +) -> list[str]: + """ + Extract multiple functions from a source file. + + Args: + tree: AST tree of the source file + lines: Source code lines + func_names: List of function names to extract + + Returns: + List of function source codes in the same order as func_names + + Raises: + ValueError: If any function is not found + """ + extracted = [] + for func_name in func_names: + func_source = _extract_function(tree, lines, func_name) + if func_source is None: + raise ValueError( + f"Function '{func_name}' not found in source. " + f"Available functions might have been renamed or removed." + ) + extracted.append(func_source) + return extracted + + +def _generate_imports() -> str: + """ + Generate the import statements needed for the extracted functions. + + Returns: + str: Import statements as a single string + """ + imports = [ + "import gzip", + "import hashlib", + "import importlib", + "import importlib.util", + "import io", + "import json", + "import logging", + "import sys", + "from functools import lru_cache", + "from pathlib import Path", + "from typing import Union", + "", + "import torch", + ] + return "\n".join(imports) diff --git a/tritonparse/reproducer/placeholder_replacer.py b/tritonparse/reproducer/placeholder_replacer.py index ebdc1e1..734b078 100644 --- a/tritonparse/reproducer/placeholder_replacer.py +++ b/tritonparse/reproducer/placeholder_replacer.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Protocol +from tritonparse.reproducer.function_extractor import extract_utility_functions from tritonparse.reproducer.ingestion.ndjson import ContextBundle from tritonparse.reproducer.types import KernelImportMode from tritonparse.reproducer.utils import ( @@ -82,6 +83,9 @@ def __init__(self): ) self.register("# {{KERNEL_SYSPATH_PLACEHOLDER}}", self._replace_kernel_syspath) self.register("# {{KERNEL_IMPORT_PLACEHOLDER}}", self._replace_kernel_import) + self.register( + "# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", self._replace_utility_functions + ) self.register( "# {{KERNEL_INVOCATION_PLACEHOLDER}}", self._replace_kernel_invocation ) @@ -217,6 +221,13 @@ def _replace_kernel_import( else: raise ValueError(f"Unknown kernel_import mode: {kernel_import}") + def _replace_utility_functions( + self, code: str, context_bundle: ContextBundle, **kwargs + ) -> str: + """Replace the utility functions placeholder with extracted functions.""" + utility_code = extract_utility_functions() + return code.replace("# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", utility_code) + def _replace_kernel_invocation( self, code: str, context_bundle: ContextBundle, **kwargs ) -> str: diff --git a/tritonparse/reproducer/templates/example.py b/tritonparse/reproducer/templates/example.py index 15e6fcc..4a23b57 100644 --- a/tritonparse/reproducer/templates/example.py +++ b/tritonparse/reproducer/templates/example.py @@ -3,18 +3,6 @@ It contains a smallest testing example for a Triton kernel. """ -import gzip -import hashlib -import importlib -import importlib.util -import io -import json -import logging -import sys -from functools import lru_cache -from pathlib import Path -from typing import Union - import torch # {{IR_OVERRIDE_SETUP_PLACEHOLDER}} @@ -23,368 +11,13 @@ # {{KERNEL_IMPORT_PLACEHOLDER}} -TRITON_KERNELS_CUSTOM_TYPES = ( - importlib.util.find_spec("triton_kernels") is not None - and importlib.util.find_spec("triton_kernels.tensor") is not None -) - - -@lru_cache(maxsize=1) -def _get_triton_tensor_types(): - """ - Import and cache Triton custom tensor types. - - Returns: - tuple: (Tensor, Storage, StridedLayout) classes from triton_kernels.tensor. - - Raises: - ImportError: If the optional module 'triton_kernels.tensor' is not available. - """ - mod = importlib.import_module("triton_kernels.tensor") - return ( - mod.Tensor, - mod.Storage, - mod.StridedLayout, - ) - - -def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch.Tensor: - """ - Load a tensor from its file path and verify its integrity using the hash in the filename. - - Args: - tensor_file_path (str | Path): Direct path to the tensor file. Supports both: - - .bin.gz: gzip-compressed tensor (hash is of uncompressed data) - - .bin: uncompressed tensor (for backward compatibility) - device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu'). - If None, keeps the tensor on its original device. - - Returns: - torch.Tensor: The loaded tensor (moved to the specified device if provided) - - Raises: - FileNotFoundError: If the tensor file doesn't exist - RuntimeError: If the tensor cannot be loaded - ValueError: If the computed hash doesn't match the filename hash - """ - # Normalize cuda device to cuda:0 - if device is not None and isinstance(device, str) and device.startswith("cuda"): - device = "cuda:0" - - blob_path = Path(tensor_file_path) - - if not blob_path.exists(): - raise FileNotFoundError(f"Tensor blob not found: {blob_path}") - - # Detect compression by file extension - is_compressed = blob_path.name.endswith(".bin.gz") - - # Read file contents (decompress if needed) - try: - with open(blob_path, "rb") as f: - file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f - file_contents = file_obj.read() - except (OSError, gzip.BadGzipFile) as e: - if is_compressed: - raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}") - else: - raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}") - - # Extract expected hash from filename - # abc123.bin.gz -> abc123 or abc123.bin -> abc123 - expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin") - - # Compute hash of uncompressed data - computed_hash = hashlib.blake2b(file_contents).hexdigest() - - # Verify hash matches filename - if computed_hash != expected_hash: - raise ValueError( - f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'" - ) - - try: - # Load the tensor from memory buffer - tensor = torch.load(io.BytesIO(file_contents), map_location=device) - return tensor - except Exception as e: - raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") - - -def create_args_from_json_file(json_path): - with open(json_path, "r") as f: - data = json.load(f) - return create_args_from_json(data) - - -def create_args_from_json(data): - """ - Parse a reproducer JSON and build kernel grid and argument dictionary. - - Args: - json_path (str): Path to the JSON file describing the kernel launch. - - Returns: - tuple[list, dict]: Grid specification list and map of argument name to value. - """ - # Handle data format validation and extraction - if isinstance(data, list): - if len(data) != 1: - print( - f"Error: Expected single element list, got list with {len(data)} elements" - ) - sys.exit(1) - data = data[0] - elif not isinstance(data, dict): - print(f"Error: Expected list or dict, got {type(data)}") - sys.exit(1) - - grid = data.get("grid", []) - args_dict = {} - extracted_args = data.get("extracted_args", {}) - - for arg_name, arg_info in extracted_args.items(): - args_dict[arg_name] = _create_arg_from_info(arg_info) - - return grid, args_dict - - -def _apply_stride_and_offset(tensor, shape, stride, storage_offset): - """ - Apply custom stride and storage offset to a tensor if needed. - - Args: - tensor: The base contiguous tensor - shape: The desired shape - stride: The desired stride (or None for contiguous) - storage_offset: The desired storage offset - - Returns: - torch.Tensor: The strided tensor view or original tensor if contiguous - """ - if stride is None: - return tensor - - # Calculate expected contiguous stride - expected_contiguous_stride = [] - s = 1 - for dim_size in reversed(shape): - expected_contiguous_stride.insert(0, s) - s *= dim_size - - # If stride matches contiguous stride and no storage offset, return as-is - if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0: - return tensor - - # Calculate required storage size - if len(shape) > 0 and len(stride) > 0: - max_offset = storage_offset - for dim_stride, dim_size in zip(stride, shape): - if dim_size > 0: - max_offset += dim_stride * (dim_size - 1) - storage_size = max_offset + 1 - else: - storage_size = storage_offset + 1 - - # Create larger storage tensor and create strided view - storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device) - - # Create strided view - strided_view = storage_tensor.as_strided( - size=shape, stride=stride, storage_offset=storage_offset - ) - - # Copy data from the base tensor into the strided layout - strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape)) - - return strided_view - - -def _create_base_tensor(arg_info) -> torch.Tensor: - if arg_info.get("blob_path"): - return load_tensor(arg_info.get("blob_path"), arg_info.get("device")) - - # Extract basic tensor properties - dtype_str = arg_info.get("dtype") - try: - torch_dtype = getattr(torch, dtype_str.split(".")[-1]) - except AttributeError: - logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.") - torch_dtype = torch.float32 - - shape = arg_info.get("shape", []) - device = arg_info.get("device", "cpu") - # Normalize cuda device to cuda:0 - if isinstance(device, str) and device.startswith("cuda"): - device = "cuda:0" - - # Extract statistical information if available - mean = arg_info.get("mean") - std = arg_info.get("std") - min_val = arg_info.get("min") - max_val = arg_info.get("max") - has_stats = ( - mean is not None - and std is not None - and min_val is not None - and max_val is not None - ) - - if arg_info.get("tensor_capture_error", False): - logging.error( - f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead." - ) - - # Use a dummy tensor to check properties of the dtype - tensor_props = torch.empty(0, dtype=torch_dtype) - - # Case 1: Floating point types - if tensor_props.is_floating_point(): - if has_stats: - # Generate tensor with statistical properties matching original data - if std == 0 or min_val == max_val: - # Constant tensor - return torch.full(shape, mean, dtype=torch_dtype, device=device) - # Generate normal distribution with mean and std, then clamp to [min, max] - tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return tensor.to(torch_dtype) - else: - # Fallback to original random generation - if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - tmp = torch.rand(shape, dtype=torch.float32, device=device) - return tmp.to(torch_dtype) - else: - return torch.empty(shape, dtype=torch_dtype, device=device).random_() - - # Case 2: Integer types - elif torch_dtype in [ - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.bool, - ]: - if has_stats and torch_dtype != torch.bool: - # Generate tensor with statistical properties, then round for integers - if std == 0 or min_val == max_val: - # Constant tensor - return torch.full(shape, int(mean), dtype=torch_dtype, device=device) - tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return torch.round(tensor).to(torch_dtype) - else: - # Fallback to original random generation - return torch.empty(shape, dtype=torch_dtype, device=device).random_() - - # Case 3: Complex numbers need special handling - elif tensor_props.is_complex(): - # Complex types: fallback to original logic for now - # TODO: Could be improved to use statistical info if available - float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64 - real_part = torch.rand(shape, dtype=float_dtype, device=device) - imag_part = torch.rand(shape, dtype=float_dtype, device=device) - return torch.complex(real_part, imag_part) - - # Case 4: Handle other unsigned integers (like uint32) which fail with random_() - elif "uint" in str(torch_dtype): - if has_stats: - # Generate tensor with statistical properties for unsigned integers - if std == 0 or min_val == max_val: - return torch.full(shape, int(mean), dtype=torch_dtype, device=device) - tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return torch.round(tensor).to(torch_dtype) - else: - # Fallback to original random generation - return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device) - - # Case 5: If we don't know how to handle the type, raise an error - else: - raise NotImplementedError( - f"Random data generation not implemented for dtype: {torch_dtype}" - ) - - -def _create_tensor(arg_info) -> torch.Tensor: - tensor = _create_base_tensor(arg_info) - - # Apply stride and storage offset if needed - shape = arg_info.get("shape", []) - stride = arg_info.get("stride") - storage_offset = arg_info.get("storage_offset", 0) - return _apply_stride_and_offset(tensor, shape, stride, storage_offset) - - -def _create_arg_from_info(arg_info): - """ - Recursively construct a kernel argument from its JSON schema. - - Args: - arg_info (dict): JSON object describing a single argument, including - fields like 'type', 'value', 'dtype', 'shape', 'device', etc. - - Returns: - Any: The constructed Python object suitable for kernel invocation. - - Raises: - RuntimeError: When required optional dependencies are missing. - NotImplementedError: When a dtype or type is not supported yet. - """ - arg_type = arg_info.get("type") - - if arg_type == "NoneType": - return None - - if arg_type in ["int", "bool", "str", "float"]: - return arg_info.get("value") - - elif arg_type == "tensor": - return _create_tensor(arg_info) - - elif arg_type == "triton_kernels.tensor.Tensor": - if not TRITON_KERNELS_CUSTOM_TYPES: - raise RuntimeError( - "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Tensor." - ) - Tensor, Storage, StridedLayout = _get_triton_tensor_types() - storage = _create_arg_from_info(arg_info.get("storage")) - dtype_str = arg_info.get("dtype") - torch_dtype = getattr(torch, dtype_str.split(".")[-1]) - return Tensor( - storage=storage, - shape=arg_info.get("shape"), - shape_max=arg_info.get("shape_max"), - dtype=torch_dtype, - ) - - elif arg_type == "triton_kernels.tensor.Storage": - if not TRITON_KERNELS_CUSTOM_TYPES: - raise RuntimeError( - "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct Storage." - ) - Tensor, Storage, StridedLayout = _get_triton_tensor_types() - data = _create_arg_from_info(arg_info.get("data")) - layout = _create_arg_from_info(arg_info.get("layout")) - return Storage(data=data, layout=layout) - - elif arg_type == "StridedLayout": - if not TRITON_KERNELS_CUSTOM_TYPES: - raise RuntimeError( - "Optional dependency 'triton_kernels.tensor' is not installed; cannot construct StridedLayout." - ) - Tensor, Storage, StridedLayout = _get_triton_tensor_types() - return StridedLayout(shape=arg_info.get("initial_shape")) - else: - print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.") - return None +# {{UTILITY_FUNCTIONS_PLACEHOLDER}} if __name__ == "__main__": - script_dir = Path(__file__).resolve().parent + script_dir = Path(__file__).resolve().parent # noqa: F821 json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}" - grid, args_dict = create_args_from_json_file(str(json_file)) + grid, args_dict = create_args_from_json_file(str(json_file)) # noqa: F821 print("Generated kernel arguments dictionary:") for name, arg in args_dict.items(): diff --git a/tritonparse/reproducer/utils.py b/tritonparse/reproducer/utils.py index 2d9232e..5657db7 100644 --- a/tritonparse/reproducer/utils.py +++ b/tritonparse/reproducer/utils.py @@ -1,6 +1,7 @@ import importlib import importlib.util import json +import logging import sys from datetime import datetime from functools import lru_cache @@ -27,9 +28,9 @@ def _get_triton_tensor_types(): ) -def create_args_from_json(json_path): +def create_args_from_json_file(json_path): """ - Parse a reproducer JSON and build kernel grid and argument dictionary. + Load and parse a reproducer JSON file. Args: json_path (str): Path to the JSON file describing the kernel launch. @@ -39,6 +40,19 @@ def create_args_from_json(json_path): """ with open(json_path, "r") as f: data = json.load(f) + return create_args_from_json(data) + + +def create_args_from_json(data): + """ + Parse a reproducer JSON and build kernel grid and argument dictionary. + + Args: + data (dict | list): JSON data describing the kernel launch. + + Returns: + tuple[list, dict]: Grid specification list and map of argument name to value. + """ # Handle data format validation and extraction if isinstance(data, list): if len(data) != 1: @@ -61,6 +75,192 @@ def create_args_from_json(json_path): return grid, args_dict +def _apply_stride_and_offset(tensor, shape, stride, storage_offset): + """ + Apply custom stride and storage offset to a tensor if needed. + + Args: + tensor: The base contiguous tensor + shape: The desired shape + stride: The desired stride (or None for contiguous) + storage_offset: The desired storage offset + + Returns: + torch.Tensor: The strided tensor view or original tensor if contiguous + """ + if stride is None: + return tensor + + # Calculate expected contiguous stride + expected_contiguous_stride = [] + s = 1 + for dim_size in reversed(shape): + expected_contiguous_stride.insert(0, s) + s *= dim_size + + # If stride matches contiguous stride and no storage offset, return as-is + if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0: + return tensor + + # Calculate required storage size + if len(shape) > 0 and len(stride) > 0: + max_offset = storage_offset + for dim_stride, dim_size in zip(stride, shape): + if dim_size > 0: + max_offset += dim_stride * (dim_size - 1) + storage_size = max_offset + 1 + else: + storage_size = storage_offset + 1 + + # Create larger storage tensor and create strided view + storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device) + + # Create strided view + strided_view = storage_tensor.as_strided( + size=shape, stride=stride, storage_offset=storage_offset + ) + + # Copy data from the base tensor into the strided layout + strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape)) + + return strided_view + + +def _create_base_tensor(arg_info) -> torch.Tensor: + """ + Create a base tensor without stride/offset modifications. + + Args: + arg_info (dict): Argument information including dtype, shape, device, etc. + + Returns: + torch.Tensor: The created base tensor + """ + if arg_info.get("blob_path"): + return load_tensor(arg_info.get("blob_path"), arg_info.get("device")) + + # Extract basic tensor properties + dtype_str = arg_info.get("dtype") + try: + torch_dtype = getattr(torch, dtype_str.split(".")[-1]) + except AttributeError: + logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.") + torch_dtype = torch.float32 + + shape = arg_info.get("shape", []) + device = arg_info.get("device", "cpu") + # Normalize cuda device to cuda:0 + if isinstance(device, str) and device.startswith("cuda"): + device = "cuda:0" + + # Extract statistical information if available + mean = arg_info.get("mean") + std = arg_info.get("std") + min_val = arg_info.get("min") + max_val = arg_info.get("max") + has_stats = ( + mean is not None + and std is not None + and min_val is not None + and max_val is not None + ) + + if arg_info.get("tensor_capture_error", False): + logging.error( + f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead." + ) + + # Use a dummy tensor to check properties of the dtype + tensor_props = torch.empty(0, dtype=torch_dtype) + + # Case 1: Floating point types + if tensor_props.is_floating_point(): + if has_stats: + # Generate tensor with statistical properties matching original data + if std == 0 or min_val == max_val: + # Constant tensor + return torch.full(shape, mean, dtype=torch_dtype, device=device) + # Generate normal distribution with mean and std, then clamp to [min, max] + tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean + tensor = torch.clamp(tensor, min=min_val, max=max_val) + return tensor.to(torch_dtype) + else: + # Fallback to original random generation + if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + tmp = torch.rand(shape, dtype=torch.float32, device=device) + return tmp.to(torch_dtype) + else: + return torch.empty(shape, dtype=torch_dtype, device=device).random_() + + # Case 2: Integer types + elif torch_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.bool, + ]: + if has_stats and torch_dtype != torch.bool: + # Generate tensor with statistical properties, then round for integers + if std == 0 or min_val == max_val: + # Constant tensor + return torch.full(shape, int(mean), dtype=torch_dtype, device=device) + tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean + tensor = torch.clamp(tensor, min=min_val, max=max_val) + return torch.round(tensor).to(torch_dtype) + else: + # Fallback to original random generation + return torch.empty(shape, dtype=torch_dtype, device=device).random_() + + # Case 3: Complex numbers need special handling + elif tensor_props.is_complex(): + # Complex types: fallback to original logic for now + # TODO: Could be improved to use statistical info if available + float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64 + real_part = torch.rand(shape, dtype=float_dtype, device=device) + imag_part = torch.rand(shape, dtype=float_dtype, device=device) + return torch.complex(real_part, imag_part) + + # Case 4: Handle other unsigned integers (like uint32) which fail with random_() + elif "uint" in str(torch_dtype): + if has_stats: + # Generate tensor with statistical properties for unsigned integers + if std == 0 or min_val == max_val: + return torch.full(shape, int(mean), dtype=torch_dtype, device=device) + tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean + tensor = torch.clamp(tensor, min=min_val, max=max_val) + return torch.round(tensor).to(torch_dtype) + else: + # Fallback to original random generation + return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device) + + # Case 5: If we don't know how to handle the type, raise an error + else: + raise NotImplementedError( + f"Random data generation not implemented for dtype: {torch_dtype}" + ) + + +def _create_tensor(arg_info) -> torch.Tensor: + """ + Create a tensor with stride and storage offset if needed. + + Args: + arg_info (dict): Argument information including dtype, shape, stride, etc. + + Returns: + torch.Tensor: The created tensor with applied stride/offset + """ + tensor = _create_base_tensor(arg_info) + + # Apply stride and storage offset if needed + shape = arg_info.get("shape", []) + stride = arg_info.get("stride") + storage_offset = arg_info.get("storage_offset", 0) + return _apply_stride_and_offset(tensor, shape, stride, storage_offset) + + def _create_arg_from_info(arg_info): """ Recursively construct a kernel argument from its JSON schema. @@ -78,120 +278,14 @@ def _create_arg_from_info(arg_info): """ arg_type = arg_info.get("type") - if arg_type in ["int", "bool"]: + if arg_type == "NoneType": + return None + + if arg_type in ["int", "bool", "str", "float"]: return arg_info.get("value") elif arg_type == "tensor": - if arg_info.get("blob_path"): - return load_tensor(arg_info.get("blob_path"), arg_info.get("device")) - - # Extract basic tensor properties - dtype_str = arg_info.get("dtype") - try: - torch_dtype = getattr(torch, dtype_str.split(".")[-1]) - except AttributeError: - logger.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.") - torch_dtype = torch.float32 - - shape = arg_info.get("shape", []) - device = arg_info.get("device", "cpu") - - # Extract statistical information if available - mean = arg_info.get("mean") - std = arg_info.get("std") - min_val = arg_info.get("min") - max_val = arg_info.get("max") - has_stats = ( - mean is not None - and std is not None - and min_val is not None - and max_val is not None - ) - - # Use a dummy tensor to check properties of the dtype - tensor_props = torch.empty(0, dtype=torch_dtype) - - # Case 1: Floating point types - if tensor_props.is_floating_point(): - if has_stats: - # Generate tensor with statistical properties matching original data - if std == 0 or min_val == max_val: - # Constant tensor - return torch.full(shape, mean, dtype=torch_dtype, device=device) - # Generate normal distribution with mean and std, then clamp to [min, max] - tensor = ( - torch.randn(shape, dtype=torch.float32, device=device) * std + mean - ) - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return tensor.to(torch_dtype) - else: - # Fallback to original random generation - if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - tmp = torch.rand(shape, dtype=torch.float32, device=device) - return tmp.to(torch_dtype) - else: - return torch.empty( - shape, dtype=torch_dtype, device=device - ).random_() - - # Case 2: Integer types - elif torch_dtype in [ - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.bool, - ]: - if has_stats and torch_dtype != torch.bool: - # Generate tensor with statistical properties, then round for integers - if std == 0 or min_val == max_val: - # Constant tensor - return torch.full( - shape, int(mean), dtype=torch_dtype, device=device - ) - tensor = ( - torch.randn(shape, dtype=torch.float32, device=device) * std + mean - ) - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return torch.round(tensor).to(torch_dtype) - else: - # Fallback to original random generation - return torch.empty(shape, dtype=torch_dtype, device=device).random_() - - # Case 3: Complex numbers need special handling - elif tensor_props.is_complex(): - # Complex types: fallback to original logic for now - # TODO: Could be improved to use statistical info if available - float_dtype = ( - torch.float32 if torch_dtype == torch.complex64 else torch.float64 - ) - real_part = torch.rand(shape, dtype=float_dtype, device=device) - imag_part = torch.rand(shape, dtype=float_dtype, device=device) - return torch.complex(real_part, imag_part) - - # Case 4: Handle other unsigned integers (like uint32) which fail with random_() - elif "uint" in str(torch_dtype): - if has_stats: - # Generate tensor with statistical properties for unsigned integers - if std == 0 or min_val == max_val: - return torch.full( - shape, int(mean), dtype=torch_dtype, device=device - ) - tensor = ( - torch.randn(shape, dtype=torch.float32, device=device) * std + mean - ) - tensor = torch.clamp(tensor, min=min_val, max=max_val) - return torch.round(tensor).to(torch_dtype) - else: - # Fallback to original random generation - return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device) - - # Case 5: If we don't know how to handle the type, raise an error - else: - raise NotImplementedError( - f"Random data generation not implemented for dtype: {torch_dtype}" - ) + return _create_tensor(arg_info) elif arg_type == "triton_kernels.tensor.Tensor": if not TRITON_KERNELS_CUSTOM_TYPES: