Skip to content

Commit

Permalink
[Operators] Improving fp32 matrix multiplication on x86 CPUs (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
BolinSNLHM committed Dec 14, 2023
1 parent f3fa023 commit 264beec
Show file tree
Hide file tree
Showing 8 changed files with 852 additions and 261 deletions.
4 changes: 1 addition & 3 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,7 @@ def visit_ForStmt(self, stmt: ForStmt):
doc += NewLine() + '#pragma unroll'
elif stmt.attr.parallel:
if stmt.attr.parallel_threads:
doc += NewLine() + '#pragma omp parallel for schedule(dynamic) num_threads({})'.format(
stmt.attr.parallel_threads
)
doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads)
else:
doc += NewLine() + '#pragma omp parallel for'
doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') '
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask
from . import resolve

from .matmul_f32_x86 import matmul_x86

from .matmul_f32_x86 import MatmulF32Taskx86, Matmulx86Op
from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86
from .matmul_f32_x86 import matmul_x86
963 changes: 708 additions & 255 deletions python/hidet/graph/ops/matmul/matmul_f32_x86.py

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions python/hidet/ir/primitives/cpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,15 @@
from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero
from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero
from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc
from .avx import avx_f32x8_store_aligned, avx_f32x8_load_aligned
from .avx import avx_f32x4_store_aligned, avx_f32x4_load_aligned
from .avx import (
avx_f32x8_unpackhi,
avx_f32x8_unpacklo,
avx_f32x8_shuffle,
avx_f32x8_cast_f32x4,
avx_f32x8_insert_f32x4,
avx_f32x8_permute2f32x4,
)

from .atomic import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor
42 changes: 42 additions & 0 deletions python/hidet/ir/primitives/cpu/atomic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from hidet.ir.expr import Expr
from hidet.ir.type import FuncType, VoidType, PointerType
from hidet.ir.primitives.func import register_primitive_function
from hidet.utils import initialize
from hidet.ir.primitives.func import call_primitive_func


@initialize()
def register_primitive_functions():
    """Register the GCC ``__atomic_*`` builtins as hidet CPU primitive functions.

    Each entry maps a hidet-level name to the C-level codegen name together
    with its function type; all three builtins take a ``void*`` pointer and
    int32 value/memory-order arguments and return int32.
    """
    entries = [
        ('cpu_atomic_load_n', '__atomic_load_n', FuncType([PointerType(VoidType()), 'int32'], 'int32')),
        ('cpu_atomic_add_fetch', '__atomic_add_fetch', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
        ('cpu_atomic_fetch_xor', '__atomic_fetch_xor', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')),
    ]
    for hidet_name, c_name, func_type in entries:
        register_primitive_function(name=hidet_name, func_or_type=func_type, codegen_name=c_name)


def cpu_atomic_load_n(ptr: Expr, order: Union[Expr, int]) -> Expr:
    """Atomically load the value at *ptr* with memory order *order* (wraps ``__atomic_load_n``)."""
    args = [ptr, order]
    return call_primitive_func('cpu_atomic_load_n', args)


def cpu_atomic_add_fetch(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
    """Atomically add *val* to ``*ptr`` and return the new value (wraps ``__atomic_add_fetch``)."""
    args = [ptr, val, order]
    return call_primitive_func('cpu_atomic_add_fetch', args)


def cpu_atomic_fetch_xor(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr:
    """Atomically XOR *val* into ``*ptr`` and return the previous value (wraps ``__atomic_fetch_xor``)."""
    args = [ptr, val, order]
    return call_primitive_func('cpu_atomic_fetch_xor', args)
66 changes: 66 additions & 0 deletions python/hidet/ir/primitives/cpu/avx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,24 @@ def register_primitive_functions():
('avx_x86_float32x4_broadcast', '_mm_broadcast_ss', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_fmadd', '_mm_fmadd_ps', FuncType(['float32x4', 'float32x4', 'float32x4'], 'float32x4')),
('avx_x86_float32x4_load', '_mm_loadu_ps', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_load_aligned', '_mm_load_ps', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_store', '_mm_storeu_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())),
(
'avx_x86_float32x4_store_aligned',
'_mm_store_ps',
FuncType([PointerType('float32'), 'float32x4'], VoidType()),
),
('avx_x86_float32x4_setzero', '_mm_setzero_ps', FuncType([], 'float32x4')),
('avx_x86_float32x8_broadcast', '_mm256_broadcast_ss', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_fmadd', '_mm256_fmadd_ps', FuncType(['float32x8', 'float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_load', '_mm256_loadu_ps', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_load_aligned', '_mm256_load_ps', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_store', '_mm256_storeu_ps', FuncType([PointerType('float32'), 'float32x8'], VoidType())),
(
'avx_x86_float32x8_store_aligned',
'_mm256_store_ps',
FuncType([PointerType('float32'), 'float32x8'], VoidType()),
),
('avx_x86_float32x8_setzero', '_mm256_setzero_ps', FuncType([], 'float32x8')),
('avx_x86_malloc', '_mm_malloc', FuncType(['uint64', 'uint64'], PointerType(VoidType()))),
('avx_x86_free', '_mm_free', FuncType([PointerType(VoidType())], VoidType())),
Expand All @@ -39,6 +51,20 @@ def register_primitive_functions():
'memcpy',
FuncType([PointerType(VoidType()), PointerType(VoidType()), 'uint64'], PointerType(VoidType())),
),
('avx_x86_float32x8_unpacklo', '_mm256_unpacklo_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_unpackhi', '_mm256_unpackhi_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_shuffle', '_mm256_shuffle_ps', FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8')),
('avx_x86_float32x8_cast_float32x4', '_mm256_castps256_ps128', FuncType(['float32x8'], 'float32x4')),
(
'avx_x86_float32x8_insert_float32x4',
'_mm256_insertf128_ps',
FuncType(['float32x8', 'float32x4', 'int32'], 'float32x8'),
),
(
'avx_x86_float32x8_permute2float32x4',
'_mm256_permute2f128_ps',
FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8'),
),
]
for name, codegen_name, func_type in functions:
register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)
Expand Down Expand Up @@ -92,13 +118,53 @@ def avx_f32x4_load(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_load', [addr])


def avx_f32x4_load_aligned(addr: Expr) -> Call:
    """Load four float32 values from 16-byte-aligned *addr* (wraps ``_mm_load_ps``)."""
    func_name = 'avx_x86_float32x4_load_aligned'
    return call_primitive_func(func_name, [addr])


def avx_f32x8_load(addr: Expr) -> Call:
    """Load eight float32 values from unaligned *addr* (wraps ``_mm256_loadu_ps``)."""
    func_name = 'avx_x86_float32x8_load'
    return call_primitive_func(func_name, [addr])


def avx_f32x8_load_aligned(addr: Expr) -> Call:
    """Load eight float32 values from 32-byte-aligned *addr* (wraps ``_mm256_load_ps``)."""
    func_name = 'avx_x86_float32x8_load_aligned'
    return call_primitive_func(func_name, [addr])


def avx_f32x4_store(addr: Expr, src: Expr) -> Call:
    """Store the f32x4 vector *src* to unaligned *addr* (wraps ``_mm_storeu_ps``)."""
    args = [addr, src]
    return call_primitive_func('avx_x86_float32x4_store', args)


def avx_f32x4_store_aligned(addr: Expr, src: Expr) -> Call:
    """Store the f32x4 vector *src* to 16-byte-aligned *addr* (wraps ``_mm_store_ps``)."""
    args = [addr, src]
    return call_primitive_func('avx_x86_float32x4_store_aligned', args)


def avx_f32x8_store(addr: Expr, src: Expr) -> Call:
    """Store the f32x8 vector *src* to unaligned *addr* (wraps ``_mm256_storeu_ps``)."""
    args = [addr, src]
    return call_primitive_func('avx_x86_float32x8_store', args)


def avx_f32x8_store_aligned(addr: Expr, src: Expr) -> Call:
    """Store the f32x8 vector *src* to 32-byte-aligned *addr* (wraps ``_mm256_store_ps``)."""
    args = [addr, src]
    return call_primitive_func('avx_x86_float32x8_store_aligned', args)


def avx_f32x8_unpacklo(a: Expr, b: Expr) -> Call:
    """Interleave the low-lane elements of f32x8 vectors *a* and *b* (wraps ``_mm256_unpacklo_ps``)."""
    args = [a, b]
    return call_primitive_func('avx_x86_float32x8_unpacklo', args)


def avx_f32x8_unpackhi(a: Expr, b: Expr) -> Call:
    """Interleave the high-lane elements of f32x8 vectors *a* and *b* (wraps ``_mm256_unpackhi_ps``)."""
    args = [a, b]
    return call_primitive_func('avx_x86_float32x8_unpackhi', args)


def avx_f32x8_shuffle(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
    """Shuffle f32x8 vectors *a* and *b* per the control word *imm* (wraps ``_mm256_shuffle_ps``)."""
    args = [a, b, imm]
    return call_primitive_func('avx_x86_float32x8_shuffle', args)


def avx_f32x8_cast_f32x4(a: Expr) -> Call:
    """Reinterpret f32x8 vector *a* as an f32x4 vector (wraps ``_mm256_castps256_ps128``)."""
    func_name = 'avx_x86_float32x8_cast_float32x4'
    return call_primitive_func(func_name, [a])


def avx_f32x8_insert_f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
    """Insert f32x4 vector *b* into f32x8 vector *a* at the lane selected by *imm* (wraps ``_mm256_insertf128_ps``)."""
    args = [a, b, imm]
    return call_primitive_func('avx_x86_float32x8_insert_float32x4', args)


def avx_f32x8_permute2f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call:
    """Permute 128-bit lanes of f32x8 vectors *a* and *b* per the control word *imm* (wraps ``_mm256_permute2f128_ps``)."""
    args = [a, b, imm]
    return call_primitive_func('avx_x86_float32x8_permute2float32x4', args)
20 changes: 20 additions & 0 deletions python/hidet/lang/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,23 @@
avx_f32x8_setzero,
)
from hidet.ir.primitives.cpu import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc

# from hidet.ir.primitives.cpu import openmp_get_thread_num, openmp_get_num_threads

from hidet.ir.primitives.cpu import (
avx_f32x8_store_aligned,
avx_f32x8_load_aligned,
avx_f32x4_store_aligned,
avx_f32x4_load_aligned,
)

from hidet.ir.primitives.cpu import (
avx_f32x8_unpackhi,
avx_f32x8_unpacklo,
avx_f32x8_shuffle,
avx_f32x8_cast_f32x4,
avx_f32x8_insert_f32x4,
avx_f32x8_permute2f32x4,
)

from hidet.ir.primitives.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor
2 changes: 1 addition & 1 deletion tests/operators/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from hidet.testing import check_binary, check_binary_dynamic


@pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]])
def test_matmul_x86(a_shape, b_shape):
# TODO: Doesn't support broadcasting yet; need to add it later?
Expand Down

0 comments on commit 264beec

Please sign in to comment.