[Operators] Improving fp32 matrix multiplication on x86 CPUs #378
Merged
Commits (141 total; this view shows changes from 138)
All commits are authored by BolinSNLHM:

- efe3e14: .
- b19a212: Merge branch 'hidet-org:main' into main
- d7e4043: .
- e13af0a: Merge branch 'main' of github.com:BolinSNLHM/hidet into main
- a7bce75: added basic openMP primitives
- bad483c: Merge branch 'main' into omp
- d7f6469: added those primitives back
- f211a48: let me pretend like it's all good for tonight
- bbb5afc: ...
- 569fb49: working on refactoring
- b32ea73: ready to be tested on the eco server
- dbbb2b6: fix stupid error
- 014f5c1: ..
- 2d82325: fix more error
- 11c9e70: ..
- 4586e89: fixing hidet script error
- 65c3b9d: ...:
- 286c107: ....
- bfacaf8: ...
- 8246466: ..
- 7518042: ..
- f8a97b2: fixing strange error
- 1a87c27: more errors
- 3104473: more err
- 68bc03d: ...
- 9059ca3: ...
- df5a177: global
- 27da1ba: global var
- fca3694: .
- 14973b4: .
- 45ad16a: ...
- 36b3c52: ..:
- c79fcca: cast
- 87cdd76: cast
- 8648ced: ...
- 075cc64: .
- 7814d6d: now segfault not internal errors
- ff058bf: stupid error
- f9f3b81: err
- 0a7b2fe: ...
- 99954e1: ..
- b884a95: ..
- 12a139a: .
- 8cf009d: .
- 717069f: ...
- 7b53554: .
- f933711: small fix
- 42054a4: ..
- 60599c2: ..
- 2d65005: .
- 747508b: .
- 4e5c7da: .
- 23f2768: try single thread first
- 0ab4888: ..
- 1631d77: dumb mistake again
- 62c075c: ..
- 5d4a314: ..
- e30ab31: keep debugging
- 134a1d5: ..
- e1e2d29: ..
- 7a7ff5e: .
- 29de46f: ..
- ca9e67d: ...
- 43d4a60: ..:
- 3d67673: .
- 3c9d792: ..
- 6782047: .
- 9401c1e: ..
- e655035: ..
- 4c7ed70: ..
- 21978bb: ..
- c90991f: ..
- 7c3ef0a: continue fixing
- 4acf6c0: ..
- c740a3a: .
- 8f0ee0e: ...
- 01e84ec: ...
- 90505e7: ..
- 805959e: ...
- 8bb52d3: ..
- 94abfa7: ..
- a3f35dc: .
- 230e6d0: ..
- e3bf60a: ..
- e5e4466: .
- 601e6b2: .
- 2df7355: ..
- ee30078: bruh
- cb54a7e: ..
- 8e07dad: .
- 8723df6: .
- 0919d12: ..
- b2a6c15: ..
- 43922bb: ..
- 553dfc4: ..
- ae29fb3: ...
- 0572ace: .
- ce1f5fd: .
- aaa500c: ..
- 6445811: .
- d3e1a1d: .
- 6589848: ..
- 4bc93c8: .
- 563b121: .
- 17011a1: .
- 18f8b53: ..
- 12e44c2: ..
- ceb22dd: ..
- 68fbba8: ..
- 0c3639f: .
- 76d55a1: ..
- 9e289e4: ..
- 165c3d5: ..
- e898772: ..
- 4cb35cb: .
- 6ba8075: ..
- 073266a: .
- d736d96: ..
- 83118f3: .
- df1cc83: ....
- a85e56f: ..
- 728ec9a: kept debugging the matrix mul kernel
- dfdf084: bruh
- d2e1ab4: fixed a dumb bug that got me stuck for way too much longer than neces…
- 0c0efe0: .
- 1bd2cfe: remove prints
- 6721ed2: .
- 442fbd2: ..
- b4e00e9: logic error fix in packing of A
- ad9c453: seems like still bugs, but they disappear with print...
- d34f031: fix bug caused by static local vairable
- 954da89: ...
- 78d09c4: fix alignment
- 838a61e: cleanup
- 6f572a4: Merge branch 'fix-zero-init' into main
- 3fbb635: ready for PR
- 656bbd0: ......
- ebcc78f: avoid changing function attributes from outside
- fa39456: Delete python/mat_new.py
- b61722d: Update matmul_f32_x86.py
- 575acaf: Merge branch 'hidet-org:main' into main
New file (42 lines added): CPU atomic primitive registrations
@@ -0,0 +1,42 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from hidet.ir.expr import Expr
from hidet.ir.type import FuncType, VoidType, PointerType
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
from hidet.utils import initialize


@initialize()
def register_primitive_functions():
    functions = [
        ('cpu_atomic_load_n', '__atomic_load_n', FuncType([PointerType(VoidType()), 'int32'], 'int32')),
        ('cpu_atomic_add_fetch', '__atomic_add_fetch', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
        ('cpu_atomic_fetch_xor', '__atomic_fetch_xor', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
    ]
    for name, codegen_name, func_type in functions:
        register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)


def cpu_atomic_load_n(ptr: Expr, order: Union[Expr, int]) -> Expr:
    return call_primitive_func('cpu_atomic_load_n', [ptr, order])


def cpu_atomic_add_fetch(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
    return call_primitive_func('cpu_atomic_add_fetch', [ptr, val, order])


def cpu_atomic_fetch_xor(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
    return call_primitive_func('cpu_atomic_fetch_xor', [ptr, val, order])
New file (148 lines added): standalone test and benchmark script for matmul_x86
@@ -0,0 +1,148 @@
import numpy as np
import pytest

import hidet
from hidet.graph.ops import matmul_x86
from hidet.testing import check_binary
from hidet.option import debug_cache_tuning

import torch

import tvm
from tvm import te, auto_scheduler


@auto_scheduler.register_workload
def matmul_ansor(M, K, N, dtype):
    A = te.placeholder((M, K), name='A', dtype=dtype)
    B = te.placeholder((K, N), name='B', dtype=dtype)

    k = te.reduce_axis((0, K), name='k')
    rst = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name='matmul_ansor',
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for B
    )

    return [A, B, rst]


hidet.option.cache_dir("./wtf")

target = tvm.target.Target("llvm -mcpu=core-avx2")
debug_cache_tuning(True)
hidet.option.search_space(0)

np.random.seed(42)
# Shape sets tried during earlier debugging rounds:
# for m, n, k in [(33, 65, 60), (32, 92, 128)]:
# for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]:
# for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]:
# for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]:
# for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]:
# for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768)]:
for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072), (2048, 2048, 2048)]:
    # Debug alternatives for the inputs:
    # a_torch = torch.arange(0, m * k).reshape(m, k).float().to('cpu')
    # b_torch = torch.arange(0, k * n).reshape(k, n).float().to('cpu')
    # a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu')
    # b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu')
    # a = hidet.ones([m, k], device='cpu')
    # b = hidet.ones([k, n], device='cpu')
    a = hidet.randn([m, k], device='cpu')
    b = hidet.randn([k, n], device='cpu')

    x1 = hidet.symbol_like(a)
    x2 = hidet.symbol_like(b)
    y = matmul_x86(x1, x2)
    graph = hidet.trace_from(y, inputs=[x1, x2])
    opt_graph = hidet.graph.optimize(graph)
    compiled_func = opt_graph.nodes[0].compiled_task
    c = compiled_func(a, b)

    actual = c.numpy()
    desired = a.numpy() @ b.numpy()

    fails = 0
    for i in range(m):
        for j in range(n):
            if abs(actual[i, j] - desired[i, j]) < 1e-3:
                continue
            else:
                print(f"Failed for i={i}, j={j}: actual = {actual[i, j]}, desired = {desired[i, j]}")
                fails += 1

    print(f"Total fails: {fails}")

    # for i in range(m):
    #     for j in range(n):
    #         if actual[i, j] == 0.0:
    #             print(f"element is 0 for i={i}, j={j}")

    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-3, atol=1e-3)

    print("passed for m={}, n={}, k={}".format(m, n, k))

    # Benchmarking against NumPy and a tuned Ansor schedule (currently disabled):
    # hidet_latency = hidet.utils.benchmark_func(lambda: compiled_func(a, b), repeat=50)
    # np_latency = hidet.utils.benchmark_func(lambda: a.numpy() @ b.numpy(), repeat=50)
    #
    # ansor_task = tvm.auto_scheduler.SearchTask(func=matmul_ansor, args=(m, k, n, "float32"), target=target)
    # log_file = f"matmul_{m}x{k}x{n}.json"
    # tune_option = auto_scheduler.TuningOptions(
    #     num_measure_trials=1000,
    #     measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    #     verbose=2,
    # )
    # ansor_task.tune(tune_option)
    # sch, args = ansor_task.apply_best(log_file)
    # with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f:
    #     f.write(str(tvm.lower(sch, args, simple_mode=True)))
    # ansor_func = tvm.build(sch, args, target)
    # dev = tvm.cpu()
    # a_tvm = tvm.nd.array(a.numpy(), device=dev)
    # b_tvm = tvm.nd.array(b.numpy(), device=dev)
    # c_tvm = tvm.nd.empty((m, n), device=dev)
    # ansor_func(a_tvm, b_tvm, c_tvm)
    # np.testing.assert_allclose(actual=c_tvm.numpy(), desired=a_tvm.numpy() @ b_tvm.numpy(), rtol=1e-3, atol=1e-3)
    # ansor_latency = hidet.utils.benchmark_func(lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30)
    # with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f:
    #     f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n")
    #     f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency:.2f} ms\n")
    #     f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency:.2f} ms\n")
Review comment: Remove this file.
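For background on the kernel under test: judging from the commit log ("logic error fix in packing of A", "fix alignment"), the new `matmul_f32_x86` follows the standard packed, cache-blocked GEMM design used by x86 fp32 kernels. The sketch below is a hedged NumPy illustration of that blocking scheme, not the PR's implementation; the block sizes `MC/NC/KC` and the micro-tile `MR/NR` are illustrative placeholders rather than the PR's tuned values.

```python
import numpy as np

# Illustrative blocking parameters (NOT the PR's tuned values).
MC, NC, KC = 64, 256, 128   # cache-level block sizes
MR, NR = 4, 8               # register-level micro-tile

def matmul_blocked(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Cache-blocked fp32 GEMM sketch in the style of BLIS-like x86 kernels."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for jc in range(0, N, NC):              # column panels of B / C
        for pc in range(0, K, KC):          # K-dimension blocks
            # In a real kernel this panel of B is packed into a contiguous buffer.
            Bp = B[pc:pc + KC, jc:jc + NC]
            for ic in range(0, M, MC):      # row panels of A / C
                # Likewise this panel of A is packed (the commit log's "packing of A").
                Ap = A[ic:ic + MC, pc:pc + KC]
                # Micro-kernel: update one MR x NR tile of C per innermost call.
                for ir in range(0, Ap.shape[0], MR):
                    for jr in range(0, Bp.shape[1], NR):
                        C[ic + ir:ic + ir + MR, jc + jr:jc + jr + NR] += (
                            Ap[ir:ir + MR] @ Bp[:, jr:jr + NR]
                        )
    return C

# Quick self-check against NumPy on one of the script's shapes.
A = np.random.rand(333, 555).astype(np.float32)
B = np.random.rand(555, 444).astype(np.float32)
np.testing.assert_allclose(matmul_blocked(A, B), A @ B, rtol=1e-3, atol=1e-3)
```

The point of this loop ordering is that the innermost MR x NR update works on data that stays resident in registers and L1, while the packed panels of A and B are streamed contiguously from the cache levels their block sizes target.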