From b680f650e074c36cad8648c98afed52dd011d430 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 21 Oct 2024 18:31:33 +0800 Subject: [PATCH 1/2] [Issue 192] Tail split support for dynamic matmul --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 8d73b0b34..b63fba41c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8d73b0b3475d0046b01ae0d03f38ae59f53b7e69 +Subproject commit b63fba41c22d6a1edb0ef4832be272fac68c8968 From eac121babec413db101727ce7f41f0e8b750cc86 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 21 Oct 2024 18:40:15 +0800 Subject: [PATCH 2/2] [Issue 192] Tail split support for dynamic matmul & add test case --- .../tilelang/test_tilelang_dyanmic_symbolic.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index d0587ebef..0dfe07633 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -372,13 +372,18 @@ def assert_tl_matmul_block_all_dynamic_correctness( ) mod, params = TL.lower(program) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) mod(A, B, C) - print(mod.mod.imported_modules[0].get_source()) def ref_program(A, B): import torch @@ -419,6 +424,8 @@ def test_assert_tl_matmul_block_all_dynamic(): "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(36, 115, 103, False, False, "float16", "float16", + "float16", 64, 64, 32) if __name__ == "__main__":