diff --git a/.github/workflows/aipu-build-and-test.yml b/.github/workflows/aipu-build-and-test.yml
index e23028250..9fd20ee20 100644
--- a/.github/workflows/aipu-build-and-test.yml
+++ b/.github/workflows/aipu-build-and-test.yml
@@ -60,3 +60,5 @@ jobs:
           source ~/env_setup.sh
           python3.10 third_party/aipu/python/test/test_01_vector_add.py
           python3.10 third_party/aipu/python/test/test_02_fused_softmax.py
+          python3.10 third_party/aipu/python/test/test_libdevice_fmod.py
+          python3.10 third_party/aipu/python/test/test_libdevice_pow.py
diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py
index 00c86e5a4..dd4f98db3 100644
--- a/python/setup_tools/utils/__init__.py
+++ b/python/setup_tools/utils/__init__.py
@@ -20,7 +20,7 @@ class FlagTreeBackend:
     FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
                     tag="00f51c2e48a943922f86f03d58e29f514def646d"),
     FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git",
-                    tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"),
+                    tag="318ed13e396d4d0ed84773975c8507c6e3f0275d"),
     FlagTreeBackend(
         name="ascend",
         url="https://gitee.com/ascend/triton-ascend.git",
diff --git a/third_party/aipu/backend/codegen.py b/third_party/aipu/backend/codegen.py
index afac4ed08..7436d2c01 100644
--- a/third_party/aipu/backend/codegen.py
+++ b/third_party/aipu/backend/codegen.py
@@ -245,6 +245,9 @@ def dispatch(self, op, stage):
         elif op_name == "arith.select":
             self.gen_select(op)
         # Math Dialect
+        elif op_name == "mathext.fmod":
+            fmod = lambda x, y: T.call_extern(_get_type(op.result), "fmod", x, y)
+            self.gen_binary(op, fmod)
         elif op_name == "math.powf":
             self.gen_binary(op, S.pow)
         elif op_name == "math.tanh":
diff --git a/third_party/aipu/language/aipu/libdevice.py b/third_party/aipu/language/aipu/libdevice.py
index 0cf642ed7..4f71cba5d 100644
--- a/third_party/aipu/language/aipu/libdevice.py
+++ b/third_party/aipu/language/aipu/libdevice.py
@@ -807,6 +807,7 @@ def binary_op(arg0, arg1, _builder=None):
     return binary_op
 
 
+fmod = create_binary_op_wrapper("fmod", ["fp32", "fp16"])
 pow = create_binary_op_wrapper("powf", ["fp32", "fp16"])
 tanh = create_unary_op_wrapper("tanh", ["fp32", "fp16"])
 erf = create_unary_op_wrapper("erf", ["fp32", "fp16"])
diff --git a/third_party/aipu/python/test/test_libdevice_fmod.py b/third_party/aipu/python/test/test_libdevice_fmod.py
new file mode 100644
index 000000000..d60689b27
--- /dev/null
+++ b/third_party/aipu/python/test/test_libdevice_fmod.py
@@ -0,0 +1,31 @@
+import torch
+import triton
+import triton.language as tl
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+@triton.jit()
+def test_kernel(
+    x_ptr,
+    y_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    z = tl.extra.aipu.libdevice.fmod(x, y)
+    tl.store(y_ptr + offsets, z, mask=mask)
+
+
+torch.manual_seed(0)
+size = 98432
+x = torch.rand(size, dtype=torch.float32, device=DEVICE)
+output_triton = torch.rand(size, device=DEVICE)
+n_elements = x.numel()
+grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+test_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
diff --git a/third_party/aipu/python/test/test_libdevice_pow.py b/third_party/aipu/python/test/test_libdevice_pow.py
new file mode 100644
index 000000000..0d0ea9e9a
--- /dev/null
+++ b/third_party/aipu/python/test/test_libdevice_pow.py
@@ -0,0 +1,31 @@
+import torch
+import triton
+import triton.language as tl
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+@triton.jit()
+def test_kernel(
+    x_ptr,
+    y_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    z = tl.extra.aipu.libdevice.pow(x, y)
+    tl.store(y_ptr + offsets, z, mask=mask)
+
+
+torch.manual_seed(0)
+size = 98432
+x = torch.rand(size, dtype=torch.float32, device=DEVICE)
+output_triton = torch.rand(size, device=DEVICE)
+n_elements = x.numel()
+grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+test_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
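Both new tests are smoke tests: they launch the kernel but never compare its output against a reference. A minimal sketch of such a check, assuming the x, output_triton, n_elements, grid, and test_kernel names from test_libdevice_fmod.py above (swap torch.fmod for torch.pow in the pow test); this is not part of the diff, just one possible follow-up:

    # Hypothetical extension of test_libdevice_fmod.py: validate against torch.
    # The kernel stores its result into y_ptr in place, so the divisor tensor
    # must be cloned before the launch to serve as the reference input.
    y_saved = output_triton.clone()
    test_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
    output_torch = torch.fmod(x, y_saved)
    torch.testing.assert_close(output_triton, output_torch)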