diff --git a/3rdparty/tvm b/3rdparty/tvm index 8d73b0b34..b63fba41c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8d73b0b3475d0046b01ae0d03f38ae59f53b7e69 +Subproject commit b63fba41c22d6a1edb0ef4832be272fac68c8968 diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index d0587ebef..0dfe07633 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -372,13 +372,18 @@ def assert_tl_matmul_block_all_dynamic_correctness( ) mod, params = TL.lower(program) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) mod(A, B, C) - print(mod.mod.imported_modules[0].get_source()) def ref_program(A, B): import torch @@ -419,6 +424,8 @@ def test_assert_tl_matmul_block_all_dynamic(): "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(36, 115, 103, False, False, "float16", "float16", + "float16", 64, 64, 32) if __name__ == "__main__":