diff --git a/tritonparse/reproducer/templates/example.py b/tritonparse/reproducer/templates/example.py index ffdab21..15e6fcc 100644 --- a/tritonparse/reproducer/templates/example.py +++ b/tritonparse/reproducer/templates/example.py @@ -67,6 +67,10 @@ def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch RuntimeError: If the tensor cannot be loaded ValueError: If the computed hash doesn't match the filename hash """ + # Normalize cuda device to cuda:0 + if device is not None and isinstance(device, str) and device.startswith("cuda"): + device = "cuda:0" + blob_path = Path(tensor_file_path) if not blob_path.exists(): @@ -210,6 +214,9 @@ def _create_base_tensor(arg_info) -> torch.Tensor: shape = arg_info.get("shape", []) device = arg_info.get("device", "cpu") + # Normalize cuda device to cuda:0 + if isinstance(device, str) and device.startswith("cuda"): + device = "cuda:0" # Extract statistical information if available mean = arg_info.get("mean")