[Operators] Improving fp32 matrix multiplication on x86 CPUs #378

Merged: 141 commits, Dec 14, 2023
Changes from 138 commits
Commits
efe3e14
.
BolinSNLHM May 28, 2023
b19a212
Merge branch 'hidet-org:main' into main
BolinSNLHM May 29, 2023
d7e4043
.
BolinSNLHM Jun 21, 2023
e13af0a
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM Jul 17, 2023
a7bce75
added basic openMP primitives
BolinSNLHM Jul 17, 2023
bad483c
Merge branch 'main' into omp
BolinSNLHM Aug 8, 2023
d7f6469
added those primitives back
BolinSNLHM Aug 8, 2023
f211a48
let me pretend like it's all good for tonight
BolinSNLHM Aug 13, 2023
bbb5afc
...
BolinSNLHM Aug 13, 2023
569fb49
working on refactoring
BolinSNLHM Aug 17, 2023
b32ea73
ready to be tested on the eco server
BolinSNLHM Aug 20, 2023
dbbb2b6
fix stupid error
BolinSNLHM Aug 20, 2023
014f5c1
..
BolinSNLHM Aug 20, 2023
2d82325
fix more error
BolinSNLHM Aug 21, 2023
11c9e70
..
BolinSNLHM Aug 21, 2023
4586e89
fixing hidet script error
BolinSNLHM Aug 23, 2023
65c3b9d
...:
BolinSNLHM Aug 23, 2023
286c107
....
BolinSNLHM Aug 23, 2023
bfacaf8
...
BolinSNLHM Aug 23, 2023
8246466
..
BolinSNLHM Aug 23, 2023
7518042
..
BolinSNLHM Aug 23, 2023
f8a97b2
fixing strange error
BolinSNLHM Aug 23, 2023
1a87c27
more errors
BolinSNLHM Aug 23, 2023
3104473
more err
BolinSNLHM Aug 23, 2023
68bc03d
...
BolinSNLHM Aug 23, 2023
9059ca3
...
BolinSNLHM Aug 23, 2023
df5a177
global
BolinSNLHM Aug 23, 2023
27da1ba
global var
BolinSNLHM Aug 25, 2023
fca3694
.
BolinSNLHM Aug 25, 2023
14973b4
.
BolinSNLHM Aug 25, 2023
45ad16a
...
BolinSNLHM Aug 25, 2023
36b3c52
..:
BolinSNLHM Aug 25, 2023
c79fcca
cast
BolinSNLHM Aug 25, 2023
87cdd76
cast
BolinSNLHM Aug 25, 2023
8648ced
...
BolinSNLHM Aug 25, 2023
075cc64
.
BolinSNLHM Aug 25, 2023
7814d6d
now segfault not internal errors
BolinSNLHM Aug 25, 2023
ff058bf
stupid error
BolinSNLHM Aug 26, 2023
f9f3b81
err
BolinSNLHM Aug 26, 2023
0a7b2fe
...
BolinSNLHM Aug 26, 2023
99954e1
..
BolinSNLHM Aug 26, 2023
b884a95
..
BolinSNLHM Aug 27, 2023
12a139a
.
BolinSNLHM Aug 27, 2023
8cf009d
.
BolinSNLHM Aug 27, 2023
717069f
...
BolinSNLHM Aug 27, 2023
7b53554
.
BolinSNLHM Aug 27, 2023
f933711
small fix
BolinSNLHM Aug 27, 2023
42054a4
..
BolinSNLHM Aug 27, 2023
60599c2
..
BolinSNLHM Aug 27, 2023
2d65005
.
BolinSNLHM Aug 27, 2023
747508b
.
BolinSNLHM Aug 27, 2023
4e5c7da
.
BolinSNLHM Aug 27, 2023
23f2768
try single thread first
BolinSNLHM Aug 27, 2023
0ab4888
..
BolinSNLHM Aug 27, 2023
1631d77
dumb mistake again
BolinSNLHM Aug 27, 2023
62c075c
..
BolinSNLHM Aug 27, 2023
5d4a314
..
BolinSNLHM Aug 27, 2023
e30ab31
keep debugging
BolinSNLHM Aug 28, 2023
134a1d5
..
BolinSNLHM Aug 28, 2023
e1e2d29
..
BolinSNLHM Aug 28, 2023
7a7ff5e
.
BolinSNLHM Aug 28, 2023
29de46f
..
BolinSNLHM Aug 28, 2023
ca9e67d
...
BolinSNLHM Aug 28, 2023
43d4a60
..:
BolinSNLHM Aug 28, 2023
3d67673
.
BolinSNLHM Aug 28, 2023
3c9d792
..
BolinSNLHM Aug 29, 2023
6782047
.
BolinSNLHM Aug 29, 2023
9401c1e
..
BolinSNLHM Aug 29, 2023
e655035
..
BolinSNLHM Aug 29, 2023
4c7ed70
..
BolinSNLHM Aug 29, 2023
21978bb
..
BolinSNLHM Aug 29, 2023
c90991f
..
BolinSNLHM Aug 29, 2023
7c3ef0a
continue fixing
BolinSNLHM Aug 29, 2023
4acf6c0
..
BolinSNLHM Aug 29, 2023
c740a3a
.
BolinSNLHM Aug 29, 2023
8f0ee0e
...
BolinSNLHM Aug 29, 2023
01e84ec
...
BolinSNLHM Aug 29, 2023
90505e7
..
BolinSNLHM Aug 29, 2023
805959e
...
BolinSNLHM Aug 29, 2023
8bb52d3
..
BolinSNLHM Aug 29, 2023
94abfa7
..
BolinSNLHM Aug 29, 2023
a3f35dc
.
BolinSNLHM Aug 29, 2023
230e6d0
..
BolinSNLHM Aug 29, 2023
e3bf60a
..
BolinSNLHM Aug 29, 2023
e5e4466
.
BolinSNLHM Aug 29, 2023
601e6b2
.
BolinSNLHM Aug 29, 2023
2df7355
..
BolinSNLHM Aug 29, 2023
ee30078
bruh
BolinSNLHM Aug 29, 2023
cb54a7e
..
BolinSNLHM Aug 29, 2023
8e07dad
.
BolinSNLHM Aug 29, 2023
8723df6
.
BolinSNLHM Aug 29, 2023
0919d12
..
BolinSNLHM Aug 29, 2023
b2a6c15
..
BolinSNLHM Aug 29, 2023
43922bb
..
BolinSNLHM Aug 29, 2023
553dfc4
..
BolinSNLHM Aug 29, 2023
ae29fb3
...
BolinSNLHM Aug 29, 2023
0572ace
.
BolinSNLHM Aug 29, 2023
ce1f5fd
.
BolinSNLHM Aug 29, 2023
aaa500c
..
BolinSNLHM Aug 29, 2023
6445811
.
BolinSNLHM Aug 29, 2023
d3e1a1d
.
BolinSNLHM Aug 29, 2023
6589848
..
BolinSNLHM Aug 29, 2023
4bc93c8
.
BolinSNLHM Aug 29, 2023
563b121
.
BolinSNLHM Aug 29, 2023
17011a1
.
BolinSNLHM Aug 29, 2023
18f8b53
..
BolinSNLHM Aug 29, 2023
12e44c2
..
BolinSNLHM Aug 29, 2023
ceb22dd
..
BolinSNLHM Aug 29, 2023
68fbba8
..
BolinSNLHM Aug 29, 2023
0c3639f
.
BolinSNLHM Aug 30, 2023
76d55a1
..
BolinSNLHM Aug 30, 2023
9e289e4
..
BolinSNLHM Aug 30, 2023
165c3d5
..
BolinSNLHM Aug 30, 2023
e898772
..
BolinSNLHM Aug 30, 2023
4cb35cb
.
BolinSNLHM Aug 30, 2023
6ba8075
..
BolinSNLHM Aug 30, 2023
073266a
.
BolinSNLHM Aug 30, 2023
d736d96
..
BolinSNLHM Aug 30, 2023
83118f3
.
BolinSNLHM Aug 30, 2023
df1cc83
....
BolinSNLHM Aug 30, 2023
a85e56f
..
BolinSNLHM Aug 30, 2023
728ec9a
kept debugging the matrix mul kernel
BolinSNLHM Oct 6, 2023
dfdf084
bruh
BolinSNLHM Oct 27, 2023
d2e1ab4
fixed a dumb bug that got me stuck for way too much longer than neces…
BolinSNLHM Nov 9, 2023
0c0efe0
.
BolinSNLHM Nov 9, 2023
1bd2cfe
remove prints
BolinSNLHM Nov 9, 2023
6721ed2
.
BolinSNLHM Nov 9, 2023
442fbd2
..
BolinSNLHM Nov 9, 2023
b4e00e9
logic error fix in packing of A
BolinSNLHM Nov 9, 2023
ad9c453
seems like still bugs, but they disappear with print...
BolinSNLHM Nov 10, 2023
d34f031
fix bug caused by static local vairable
BolinSNLHM Nov 11, 2023
954da89
...
BolinSNLHM Nov 15, 2023
78d09c4
fix alignment
BolinSNLHM Nov 15, 2023
838a61e
cleanup
BolinSNLHM Nov 17, 2023
6f572a4
Merge branch 'fix-zero-init' into main
BolinSNLHM Nov 17, 2023
3fbb635
ready for PR
BolinSNLHM Nov 17, 2023
656bbd0
......
BolinSNLHM Nov 17, 2023
ebcc78f
avoid changing function attributes from outside
BolinSNLHM Nov 17, 2023
fa39456
Delete python/mat_new.py
BolinSNLHM Dec 12, 2023
b61722d
Update matmul_f32_x86.py
BolinSNLHM Dec 12, 2023
575acaf
Merge branch 'hidet-org:main' into main
BolinSNLHM Dec 13, 2023
4 changes: 1 addition & 3 deletions python/hidet/backend/codegen.py
@@ -496,9 +496,7 @@ def visit_ForStmt(self, stmt: ForStmt):
doc += NewLine() + '#pragma unroll'
elif stmt.attr.parallel:
if stmt.attr.parallel_threads:
doc += NewLine() + '#pragma omp parallel for schedule(dynamic) num_threads({})'.format(
stmt.attr.parallel_threads
)
doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads)
else:
doc += NewLine() + '#pragma omp parallel for'
doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') '
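The change above drops the schedule(dynamic) clause, so a parallel loop with an explicit thread count now lowers to a bare '#pragma omp parallel for num_threads(...)' in front of the generated for statement. A minimal, self-contained C sketch of the kind of loop this produces (the function name, loop body, and thread count are illustrative assumptions, not actual generated output; compile with -fopenmp):

#include <stdint.h>
#include <stdio.h>

/* Illustrative only: mirrors the pragma visit_ForStmt now emits when
   stmt.attr.parallel_threads is set (here, 8 threads). */
static void add_vectors(const float *a, const float *b, float *c, int32_t n) {
    #pragma omp parallel for num_threads(8)
    for (int32_t i = 0; i < n; i++) {
        c[i] = a[i] + b[i];
    }
}

int main(void) {
    float a[16], b[16], c[16];
    for (int i = 0; i < 16; i++) { a[i] = (float)i; b[i] = 2.0f * i; }
    add_vectors(a, b, c, 16);
    printf("%f\n", c[15]); /* expected: 45.000000 */
    return 0;
}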
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/matmul/__init__.py
@@ -13,6 +13,6 @@
from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask
from . import resolve

from .matmul_f32_x86 import matmul_x86

from .matmul_f32_x86 import MatmulF32Taskx86, Matmulx86Op
from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86
from .matmul_f32_x86 import matmul_x86
965 changes: 709 additions & 256 deletions python/hidet/graph/ops/matmul/matmul_f32_x86.py

Large diffs are not rendered by default.
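The new kernel is too large to render here. As background for the blocking-and-packing scheme the commit history refers to (e.g. "logic error fix in packing of A"), below is a rough, generic C sketch of a blocked fp32 GEMM that packs panels of A before the inner kernel. The block sizes, the scalar stand-in for the micro-kernel, and the simplified two-level blocking (over K and M only) are all illustrative assumptions; the actual matmul_f32_x86.py kernel drives the AVX primitives added elsewhere in this diff and uses its own tiling.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define MC 48
#define KC 32

/* Copy an mc x kc panel of A into a contiguous, cache-friendly buffer. */
static void pack_a(const float *A, float *Ap, int mc, int kc, int lda) {
    for (int i = 0; i < mc; i++)
        for (int p = 0; p < kc; p++)
            Ap[i * kc + p] = A[i * lda + p];
}

static void matmul_blocked(int M, int N, int K,
                           const float *A, const float *B, float *C) {
    float *Ap = malloc(sizeof(float) * MC * KC);
    for (int pc = 0; pc < K; pc += KC) {
        int kc = (K - pc < KC) ? K - pc : KC;
        for (int ic = 0; ic < M; ic += MC) {
            int mc = (M - ic < MC) ? M - ic : MC;
            pack_a(&A[ic * K + pc], Ap, mc, kc, K);
            /* Scalar stand-in for the micro-kernel (the PR's kernel is vectorized). */
            for (int i = 0; i < mc; i++)
                for (int j = 0; j < N; j++) {
                    float acc = C[(ic + i) * N + j];
                    for (int p = 0; p < kc; p++)
                        acc += Ap[i * kc + p] * B[(pc + p) * N + j];
                    C[(ic + i) * N + j] = acc;
                }
        }
    }
    free(Ap);
}

int main(void) {
    const int M = 70, N = 65, K = 50;
    float *A = malloc(sizeof(float) * M * K);
    float *B = malloc(sizeof(float) * K * N);
    float *C = calloc(M * N, sizeof(float));
    float *R = calloc(M * N, sizeof(float));
    for (int i = 0; i < M * K; i++) A[i] = (float)(i % 7) - 3.0f;
    for (int i = 0; i < K * N; i++) B[i] = (float)(i % 5) - 2.0f;

    matmul_blocked(M, N, K, A, B, C);
    for (int i = 0; i < M; i++)          /* naive reference */
        for (int j = 0; j < N; j++)
            for (int p = 0; p < K; p++)
                R[i * N + j] += A[i * K + p] * B[p * N + j];

    int mismatches = 0;
    for (int i = 0; i < M * N; i++)
        if (fabsf(C[i] - R[i]) > 1e-3f) mismatches++;
    printf("mismatches: %d\n", mismatches); /* expected: 0 */
    free(A); free(B); free(C); free(R);
    return 0;
}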

12 changes: 12 additions & 0 deletions python/hidet/ir/primitives/cpu/__init__.py
@@ -14,3 +14,15 @@
from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero
from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero
from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc
from .avx import avx_f32x8_store_aligned, avx_f32x8_load_aligned
from .avx import avx_f32x4_store_aligned, avx_f32x4_load_aligned
from .avx import (
avx_f32x8_unpackhi,
avx_f32x8_unpacklo,
avx_f32x8_shuffle,
avx_f32x8_cast_f32x4,
avx_f32x8_insert_f32x4,
avx_f32x8_permute2f32x4,
)

from .atomic import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor
42 changes: 42 additions & 0 deletions python/hidet/ir/primitives/cpu/atomic.py
@@ -0,0 +1,42 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from hidet.ir.expr import Expr
from hidet.ir.type import FuncType, VoidType, PointerType
from hidet.ir.primitives.func import register_primitive_function
from hidet.utils import initialize
from hidet.ir.primitives.func import call_primitive_func


@initialize()
def register_primitive_functions():
functions = [
('cpu_atomic_load_n', '__atomic_load_n', FuncType([PointerType(VoidType()), 'int32'], 'int32')),
('cpu_atomic_add_fetch', '__atomic_add_fetch', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
('cpu_atomic_fetch_xor', '__atomic_fetch_xor', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
]

for name, codegen_name, func_type in functions:
register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)


def cpu_atomic_load_n(ptr: Expr, order: Union[Expr, int]) -> Expr:
return call_primitive_func('cpu_atomic_load_n', [ptr, order])


def cpu_atomic_add_fetch(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
return call_primitive_func('cpu_atomic_add_fetch', [ptr, val, order])


def cpu_atomic_fetch_xor(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
return call_primitive_func('cpu_atomic_fetch_xor', [ptr, val, order])
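These three wrappers map one-to-one onto the GCC/Clang __atomic_* builtins named in the registration table above. A minimal C sketch of those builtins, to show the semantics the new primitives expose (the counter value and memory orders are chosen only for illustration):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    int32_t counter = 0;

    /* cpu_atomic_add_fetch -> __atomic_add_fetch: add, then return the new value */
    int32_t after = __atomic_add_fetch(&counter, 5, __ATOMIC_SEQ_CST);    /* counter = 5, after = 5 */

    /* cpu_atomic_fetch_xor -> __atomic_fetch_xor: return the old value, then xor */
    int32_t before = __atomic_fetch_xor(&counter, 0x3, __ATOMIC_SEQ_CST); /* before = 5, counter = 6 */

    /* cpu_atomic_load_n -> __atomic_load_n: atomic read */
    int32_t now = __atomic_load_n(&counter, __ATOMIC_SEQ_CST);            /* now = 6 */

    printf("%d %d %d\n", after, before, now); /* expected: 5 5 6 */
    return 0;
}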
66 changes: 66 additions & 0 deletions python/hidet/ir/primitives/cpu/avx.py
@@ -24,12 +24,24 @@ def register_primitive_functions():
('avx_x86_float32x4_broadcast', '_mm_broadcast_ss', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_fmadd', '_mm_fmadd_ps', FuncType(['float32x4', 'float32x4', 'float32x4'], 'float32x4')),
('avx_x86_float32x4_load', '_mm_loadu_ps', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_load_aligned', '_mm_load_ps', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_store', '_mm_storeu_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())),
(
'avx_x86_float32x4_store_aligned',
'_mm_store_ps',
FuncType([PointerType('float32'), 'float32x4'], VoidType()),
),
('avx_x86_float32x4_setzero', '_mm_setzero_ps', FuncType([], 'float32x4')),
('avx_x86_float32x8_broadcast', '_mm256_broadcast_ss', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_fmadd', '_mm256_fmadd_ps', FuncType(['float32x8', 'float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_load', '_mm256_loadu_ps', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_load_aligned', '_mm256_load_ps', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_store', '_mm256_storeu_ps', FuncType([PointerType('float32'), 'float32x8'], VoidType())),
(
'avx_x86_float32x8_store_aligned',
'_mm256_store_ps',
FuncType([PointerType('float32'), 'float32x8'], VoidType()),
),
('avx_x86_float32x8_setzero', '_mm256_setzero_ps', FuncType([], 'float32x8')),
('avx_x86_malloc', '_mm_malloc', FuncType(['uint64', 'uint64'], PointerType(VoidType()))),
('avx_x86_free', '_mm_free', FuncType([PointerType(VoidType())], VoidType())),
@@ -39,6 +51,20 @@ def register_primitive_functions():
'memcpy',
FuncType([PointerType(VoidType()), PointerType(VoidType()), 'uint64'], PointerType(VoidType())),
),
('avx_x86_float32x8_unpacklo', '_mm256_unpacklo_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_unpackhi', '_mm256_unpackhi_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_shuffle', '_mm256_shuffle_ps', FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8')),
('avx_x86_float32x8_cast_float32x4', '_mm256_castps256_ps128', FuncType(['float32x8'], 'float32x4')),
(
'avx_x86_float32x8_insert_float32x4',
'_mm256_insertf128_ps',
FuncType(['float32x8', 'float32x4', 'int32'], 'float32x8'),
),
(
'avx_x86_float32x8_permute2float32x4',
'_mm256_permute2f128_ps',
FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8'),
),
]
for name, codegen_name, func_type in functions:
register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)
@@ -92,13 +118,53 @@ def avx_f32x4_load(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_load', [addr])


def avx_f32x4_load_aligned(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_load_aligned', [addr])


def avx_f32x8_load(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_load', [addr])


def avx_f32x8_load_aligned(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_load_aligned', [addr])


def avx_f32x4_store(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_store', [addr, src])


def avx_f32x4_store_aligned(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_store_aligned', [addr, src])


def avx_f32x8_store(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_store', [addr, src])


def avx_f32x8_store_aligned(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_store_aligned', [addr, src])


def avx_f32x8_unpacklo(a: Expr, b: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_unpacklo', [a, b])


def avx_f32x8_unpackhi(a: Expr, b: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_unpackhi', [a, b])


def avx_f32x8_shuffle(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
return call_primitive_func('avx_x86_float32x8_shuffle', [a, b, imm])


def avx_f32x8_cast_f32x4(a: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_cast_float32x4', [a])


def avx_f32x8_insert_f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
return call_primitive_func('avx_x86_float32x8_insert_float32x4', [a, b, imm])


def avx_f32x8_permute2f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
return call_primitive_func('avx_x86_float32x8_permute2float32x4', [a, b, imm])
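The newly wrapped unpack/shuffle/cast/insert/permute intrinsics are the usual ingredients for interleaving lanes and doing in-register transposes when packing panels for a micro-kernel. A small stand-alone C sketch of two of the underlying Intel intrinsics (_mm256_unpacklo_ps and _mm256_permute2f128_ps, the targets of avx_f32x8_unpacklo and avx_f32x8_permute2f32x4), only to illustrate their semantics; it is not code from this PR. Compile with -mavx.

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    /* Two rows of 8 floats, 32-byte aligned so the aligned load/store
       intrinsics (_mm256_load_ps / _mm256_store_ps) are legal. */
    float a[8] __attribute__((aligned(32))) = {0, 1, 2, 3, 4, 5, 6, 7};
    float b[8] __attribute__((aligned(32))) = {10, 11, 12, 13, 14, 15, 16, 17};
    float out[8] __attribute__((aligned(32)));

    __m256 va = _mm256_load_ps(a);
    __m256 vb = _mm256_load_ps(b);

    /* Interleave the low halves of each 128-bit lane:
       lo = {a0, b0, a1, b1 | a4, b4, a5, b5} */
    __m256 lo = _mm256_unpacklo_ps(va, vb);

    /* Swap the two 128-bit lanes: imm bits [1:0] = 1 pick the first operand's
       high lane for the low half, bits [5:4] = 0 pick its low lane for the
       high half. */
    __m256 swapped = _mm256_permute2f128_ps(lo, lo, 0x01);

    _mm256_store_ps(out, swapped);
    for (int i = 0; i < 8; i++) {
        printf("%.0f ", out[i]); /* expected: 4 14 5 15 0 10 1 11 */
    }
    printf("\n");
    return 0;
}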
20 changes: 20 additions & 0 deletions python/hidet/lang/cpu.py
@@ -26,3 +26,23 @@
avx_f32x8_setzero,
)
from hidet.ir.primitives.cpu import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc

# from hidet.ir.primitives.cpu import openmp_get_thread_num, openmp_get_num_threads

from hidet.ir.primitives.cpu import (
avx_f32x8_store_aligned,
avx_f32x8_load_aligned,
avx_f32x4_store_aligned,
avx_f32x4_load_aligned,
)

from hidet.ir.primitives.cpu import (
avx_f32x8_unpackhi,
avx_f32x8_unpacklo,
avx_f32x8_shuffle,
avx_f32x8_cast_f32x4,
avx_f32x8_insert_f32x4,
avx_f32x8_permute2f32x4,
)

from hidet.ir.primitives.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor
148 changes: 148 additions & 0 deletions python/mat_new.py
@@ -0,0 +1,148 @@
import numpy as np
Review comment (Member) on this file: Remove this file.
import pytest

import hidet
from hidet.graph.ops import matmul_x86
from hidet.testing import check_binary
from hidet.option import debug_cache_tuning

import torch

import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def matmul_ansor(M, K, N, dtype):
A = te.placeholder((M, K), name='A', dtype=dtype)
B = te.placeholder((K, N), name='B', dtype=dtype)

k = te.reduce_axis((0, K), name='k')
rst = te.compute(
(M, N),
lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
name='matmul_ansor',
attrs={"layout_free_placeholders": [B],
# Enable automatic layout transform for B
}
)

return [A, B, rst]
hidet.option.cache_dir("./wtf")

target = tvm.target.Target("llvm -mcpu=core-avx2")
debug_cache_tuning(True)
hidet.option.search_space(0)

np.random.seed(42)
# for m, n, k in [(33, 65, 60), (32, 92, 128)]:
# for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]:
# for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]:
# for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]:
# for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]:
# for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768)]:
for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072), (2048, 2048, 2048)]:
# a = hidet.randn([m, k], device='cpu')
# b = hidet.randn([k, n], device='cpu')

# a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu')
# b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu')
# #
# # print(f"a_torch: {a_torch}")
# # print(f"b_torch: {b_torch}")
#
# a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu')
# b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu')
# print(f"a: {a}")
# print(f"b: {b}")

a = hidet.randn([m, k], device='cpu')
b = hidet.randn([k, n], device='cpu')
# a = hidet.ones([m, k], device='cpu')
# b = hidet.ones([k, n], device='cpu')
#

x1 = hidet.symbol_like(a)
x2 = hidet.symbol_like(b)
y = matmul_x86(x1, x2)
graph = hidet.trace_from(
y, inputs=[x1, x2]
)
opt_graph = hidet.graph.optimize(graph)
compiled_func = opt_graph.nodes[0].compiled_task
c = compiled_func(a, b)

actual = c.numpy()
desired = a.numpy() @ b.numpy()

fails = 0

for i in range(m):
for j in range(n):
if abs(actual[i, j] - desired[i, j]) < 1e-3:
# print(f"Actually passed for i={i}, j={j}")
continue
else:
print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}")
fails += 1

print(f"Total fails: {fails}")

# for i in range(m):
# for j in range(n):
# if actual[i, j] == 0.0:
# print(f"element is 0 for i={i}, j={j}")


np.testing.assert_allclose(
actual=actual,
desired=desired,
rtol=1e-3,
atol=1e-3
)

print("passed for m={}, n={}, k={}".format(m, n, k))

# hidet_latency = hidet.utils.benchmark_func(
# lambda: compiled_func(a, b), repeat=50
# )
# np_latency = hidet.utils.benchmark_func(
# lambda: a.numpy() @ b.numpy(), repeat=50
# )
#
# ansor_task = tvm.auto_scheduler.SearchTask(
# func=matmul_ansor, args=(m, k, n, "float32"), target=target
# )
# log_file = f"matmul_{m}x{k}x{n}.json"
# tune_option = auto_scheduler.TuningOptions(
# num_measure_trials=1000,
# measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
# verbose=2,
# )
#
# ansor_task.tune(tune_option)
# sch, args = ansor_task.apply_best(log_file)
# with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f:
# f.write(str(tvm.lower(sch, args, simple_mode=True)))
# ansor_func = tvm.build(sch, args, target)
# dev = tvm.cpu()
# a_tvm = tvm.nd.array(a.numpy(), device=dev)
# b_tvm = tvm.nd.array(b.numpy(), device=dev)
# c_tvm = tvm.nd.empty((m, n), device=dev)
#
# ansor_func(a_tvm, b_tvm, c_tvm)
#
# np.testing.assert_allclose(
# actual=c_tvm.numpy(),
# desired=a_tvm.numpy() @ b_tvm.numpy(),
# rtol=1e-3,
# atol=1e-3
# )
#
# ansor_latency = hidet.utils.benchmark_func(
# lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30
# )
#
# with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f:
# f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n")
# f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n")
# f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n")
2 changes: 1 addition & 1 deletion tests/operators/test_matmul.py
@@ -17,7 +17,7 @@
from hidet.testing import check_binary, check_binary_dynamic


@pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]])
def test_matmul_x86(a_shape, b_shape):
# TODO: Doesn't support broadcasting yet; need to add it later?