Skip to content

Commit

Permalink
cpu: gemm: remove a templated function
Browse files (browse the repository at this point in the history)
  • Loading branch information
dzarukin committed Mar 22, 2024
1 parent 406a079 commit 6f5621a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 23 deletions.
28 changes: 13 additions & 15 deletions src/common/gemm.cpp
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2021-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -116,10 +116,9 @@ dnnl_status_t dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dim_t M,
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
status_t status = dnnl_success;
MAYBE_VERBOSE(status, "u8", "s8", "s32",
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_u8s8s32,
cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
&lda, &ao, &beta, C, &ldc, co));
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_u8s8s32, cpu::gemm_s8x8s32,
&transb, &transa, c2f_offsetC(&offsetc), &N, &M, &K, &alpha,
B, &ldb, &bo, A, &lda, &ao, &beta, C, &ldc, co));
return status;
#else
return dnnl::impl::status::unimplemented;
Expand All @@ -133,10 +132,9 @@ dnnl_status_t dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dim_t M,
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
status_t status = dnnl_success;
MAYBE_VERBOSE(status, "s8", "s8", "s32",
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_s8s8s32,
cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
&lda, &ao, &beta, C, &ldc, co));
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_s8s8s32, cpu::gemm_s8x8s32,
&transb, &transa, c2f_offsetC(&offsetc), &N, &M, &K, &alpha,
B, &ldb, &bo, A, &lda, &ao, &beta, C, &ldc, co));
return status;
#else
return dnnl::impl::status::unimplemented;
Expand Down Expand Up @@ -184,9 +182,9 @@ dnnl_status_t dnnl_threadpool_interop_gemm_u8s8s32(char transa, char transb,
status_t status = dnnl_success;
MAYBE_VERBOSE(status, "u8", "s8", "s32",
MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_u8s8s32,
cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
&lda, &ao, &beta, C, &ldc, co));
cpu::gemm_s8x8s32, &transb, &transa, c2f_offsetC(&offsetc),
&N, &M, &K, &alpha, B, &ldb, &bo, A, &lda, &ao, &beta, C,
&ldc, co));
threadpool_utils::deactivate_threadpool();
return status;
}
Expand All @@ -200,9 +198,9 @@ dnnl_status_t dnnl_threadpool_interop_gemm_s8s8s32(char transa, char transb,
status_t status = dnnl_success;
MAYBE_VERBOSE(status, "s8", "s8", "s32",
MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_s8s8s32,
cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
&lda, &ao, &beta, C, &ldc, co));
cpu::gemm_s8x8s32, &transb, &transa, c2f_offsetC(&offsetc),
&N, &M, &K, &alpha, B, &ldb, &bo, A, &lda, &ao, &beta, C,
&ldc, co));
threadpool_utils::deactivate_threadpool();
return status;
}
Expand Down
4 changes: 1 addition & 3 deletions src/cpu/gemm/gemm.cpp
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
* Copyright 2022 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -185,7 +185,6 @@ dnnl_status_t try_cblas_gemm_s8u8s32(const char *transa, const char *transb,
#endif
}

template <>
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
Expand Down Expand Up @@ -227,7 +226,6 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
B, LDB, bo, beta, C, LDC, co);
}

template <>
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
Expand Down
15 changes: 10 additions & 5 deletions src/cpu/gemm/gemm.hpp
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -57,11 +57,16 @@ dnnl_status_t extended_sgemm(const char *transa, const char *transb,
const float *beta, float *C, const dim_t *ldc,
const float *bias = nullptr, bool force_jit_gemm = false);

template <typename b_dt>
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
const float *alpha, const int8_t *A, const dim_t *lda, const int8_t *ao,
const b_dt *B, const dim_t *ldb, const b_dt *bo, const float *beta,
const char *offsetc, const dim_t *m, const dim_t *n, const dim_t *k,
const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *ao,
const uint8_t *b, const dim_t *ldb, const uint8_t *bo,
const float *beta, int32_t *c, const dim_t *ldc, const int32_t *co);

dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
const char *offsetc, const dim_t *m, const dim_t *n, const dim_t *k,
const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *ao,
const int8_t *b, const dim_t *ldb, const int8_t *bo, const float *beta,
int32_t *c, const dim_t *ldc, const int32_t *co);

dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
Expand Down

0 comments on commit 6f5621a

Please sign in to comment.