Skip to content

Commit

Permalink
Update CPU backend params (#153)
Browse files Browse the repository at this point in the history
Without the change the test does not pass, fails to compile kernel call
  • Loading branch information
parsifal-47 committed Jul 26, 2024
1 parent ac135d7 commit d2bc9f2
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
11 changes: 9 additions & 2 deletions backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def _generate_launcher(constants, signature, kernel_name):
args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiOOOO" + args_format
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)
kernel_arg_decls += ', ' if kernel_arg_decls else ''

kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)
kernel_parameters += ', ' if kernel_parameters else ''

return f"""
#include <assert.h>
#include <stdbool.h>
Expand All @@ -70,7 +77,7 @@ def _generate_launcher(constants, signature, kernel_name):
extern "C" {{
// Pointer type (=Memref) becomes int64_t + MemRef struct
// FIXME: understand what this int64_t is used for.
void {kernel_name}({', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)},
void {kernel_name}({kernel_arg_decls}
int, int, int, int, int, int);
}}
Expand All @@ -82,7 +89,7 @@ def _generate_launcher(constants, signature, kernel_name):
for(int z = 0; z < gridZ; z++) {{
// Use some random type "char" here.
{' '.join(f'StridedMemRefType<char, 0> ptr_arg{i} = {{static_cast<char *>(arg{i}), static_cast<char *>(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")}
{kernel_name}({', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)},
{kernel_name}({kernel_parameters}
gridX, gridY, gridZ, x, y, z);
}}
}}
Expand Down
51 changes: 51 additions & 0 deletions python/examples/test_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations
import torch
import triton
from triton.backends.triton_shared.driver import CPUDriver
import triton.language as tl
import pytest

triton.runtime.driver.set_active(CPUDriver())

def annotated_function(return_type=None, **arg_types):
"""A decorator to add annotations to a function."""

def decorator(func):
func.__annotations__ = {**arg_types, 'return': return_type}
return func

return decorator


# Test integer annotations
@pytest.mark.parametrize(("signed", "width"), [
(signed, width) for signed in [False, True]\
for width in [8, 16, 32, 64]
] + [(False, 1)]
)
def test_int_annotation(signed, width):

@triton.jit
@annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}")
def _kernel(X, v):
tl.store(X, v)

h = _kernel[(1, )](torch.empty(1, device="cpu"), 3)
pfx = 'si' if signed else 'ui'
assert f'%arg1: i{width}' in h.asm["ttir"]
assert f'arith.{pfx}tofp' in h.asm["ttir"]


# Test that unknown annotations do not emit an error
def test_unknown_annotation():

@triton.jit
def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
pass

x = torch.empty(1, device="cpu")
_kernel[(1, )](x, x.shape[0], 32)
try:
_kernel[(1, )](x.shape[0], x.shape[0], 32)
except AttributeError:
pass

0 comments on commit d2bc9f2

Please sign in to comment.