Skip to content

Commit

Permalink
add simd optimizations and fix for loongarch (#3649)
Browse files Browse the repository at this point in the history
* Add SIMD for loongarch

1. Add VAACalcSadBgd_lasx.
2. Refine WelsSampleSad8x8x2_lasx & WelsIDctT4Rec_lasx & WelsIDctFourT4Rec_lasx.

* Fix ninja build warning and remove unneeded header files

===>
../codec/encoder/core/loongarch/sample_lasx.c: In function ‘WelsIntra8x8Combined3Sad_lasx’:
../codec/encoder/core/loongarch/sample_lasx.c:62:14: warning: implicit declaration of function
‘WelsSampleSad8x8_lasx’; did you mean ‘WelsSampleSad8x8_c’? [-Wimplicit-function-declaration]
   iCurCost = WelsSampleSad8x8_lasx(pDstChroma, 8, pEncCb, iEncStride);
              ^~~~~~~~~~~~~~~~~~~~~
              WelsSampleSad8x8_c
<===

(1) WelsSampleSad8x8_lasx is called from a .c file, but its declaration is wrapped in #if defined(HAVE_LASX) in sad_common.h,
    so the HAVE_LASX define needs to be added when compiling the .c files.
(2) The lsx/lasx SIMD code lives in .c files, so there is no need to add -mlsx/-mlasx for the .cpp files.
  • Loading branch information
jinboson committed Apr 20, 2023
1 parent 6967c09 commit 9866066
Show file tree
Hide file tree
Showing 11 changed files with 458 additions and 72 deletions.
29 changes: 15 additions & 14 deletions codec/common/loongarch/satd_sad_lasx.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ int32_t WelsSampleSad8x8x2_lasx (uint8_t* pSample1, int32_t iStride1,
src1_4, src1_5, src1_6, src1_7;
__m256i src2_0, src2_1, src2_2, src2_3,
src2_4, src2_5, src2_6, src2_7;

DUP4_ARG2(__lasx_xvldx,
pSrc1, iStride0,
pSrc1, iStride1,
Expand All @@ -136,7 +135,6 @@ int32_t WelsSampleSad8x8x2_lasx (uint8_t* pSample1, int32_t iStride1,
pSrc2, iStride2_tmp6,
pSrc2, iStride2_tmp7,
src2_4, src2_5, src2_6, src2_7);

DUP4_ARG3(__lasx_xvpermi_q,
src1_0, src1_1, 0x20,
src1_2, src1_3, 0x20,
Expand All @@ -149,19 +147,22 @@ int32_t WelsSampleSad8x8x2_lasx (uint8_t* pSample1, int32_t iStride1,
src2_4, src2_5, 0x20,
src2_6, src2_7, 0x20,
src2_0, src2_2, src2_4, src2_6);

HORISUM(src1_0, src2_0, src1_0);
HORISUM(src1_2, src2_2, src1_2);
HORISUM(src1_4, src2_4, src1_4);
HORISUM(src1_6, src2_6, src1_6);

src1_0 = __lasx_xvadd_d(src1_0, src1_2);
src1_0 = __lasx_xvadd_d(src1_0, src1_4);
src1_0 = __lasx_xvadd_d(src1_0, src1_6);
src1_0 = __lasx_xvabsd_bu(src1_0, src2_0);
src1_2 = __lasx_xvabsd_bu(src1_2, src2_2);
src1_4 = __lasx_xvabsd_bu(src1_4, src2_4);
src1_6 = __lasx_xvabsd_bu(src1_6, src2_6);
src1_0 = __lasx_xvhaddw_hu_bu(src1_0, src1_0);
src1_2 = __lasx_xvhaddw_hu_bu(src1_2, src1_2);
src1_4 = __lasx_xvhaddw_hu_bu(src1_4, src1_4);
src1_6 = __lasx_xvhaddw_hu_bu(src1_6, src1_6);
src1_0 = __lasx_xvadd_h(src1_0, src1_2);
src1_0 = __lasx_xvadd_h(src1_0, src1_4);
src1_0 = __lasx_xvadd_h(src1_0, src1_6);
src1_0 = __lasx_xvhaddw_wu_hu(src1_0, src1_0);
src1_0 = __lasx_xvhaddw_du_wu(src1_0, src1_0);
src1_0 = __lasx_xvhaddw_qu_du(src1_0, src1_0);

return (__lasx_xvpickve2gr_d(src1_0, 0) +
__lasx_xvpickve2gr_d(src1_0, 2));
return (__lasx_xvpickve2gr_w(src1_0, 0) +
__lasx_xvpickve2gr_w(src1_0, 4));
}

int32_t WelsSampleSad8x8_lasx (uint8_t* pSample1, int32_t iStride1,
Expand Down
222 changes: 172 additions & 50 deletions codec/encoder/core/loongarch/dct_lasx.c
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ void WelsIDctT4Rec_lasx (uint8_t* pRec, int32_t iStride,
tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7,
dst0, dst1, dst2, dst3;

__m256i zero = __lasx_xvldi(0);
src0 = __lasx_xvld(pDct, 0);
src4 = __lasx_xvld(pPred, 0);
src5 = __lasx_xvldx(pPred, iPredStride);
Expand All @@ -250,81 +249,204 @@ void WelsIDctT4Rec_lasx (uint8_t* pRec, int32_t iStride,
src3 = __lasx_xvpickve_d(src0, 3);

LASX_TRANSPOSE4x4_H(src0, src1, src2, src3,
tmp0, tmp1, tmp2, tmp3);
DUP4_ARG1(__lasx_vext2xv_w_h,
tmp0, tmp1, tmp2, tmp3,
src0, src1, src2, src3);
src0, src1, src2, src3);
//horizon
tmp0 = __lasx_xvadd_w(src0, src2); //0+2 sumu
tmp1 = __lasx_xvsrai_w(src3, 1);
tmp1 = __lasx_xvadd_w(src1, tmp1); //1+3 sumd
tmp2 = __lasx_xvsub_w(src0, src2); //0-2 delu
tmp3 = __lasx_xvsrai_w(src1, 1);
tmp3 = __lasx_xvsub_w(tmp3, src3); //1-3 deld
tmp0 = __lasx_xvadd_h(src0, src2); //0+2 sumu
tmp1 = __lasx_xvsrai_h(src3, 1);
tmp1 = __lasx_xvadd_h(src1, tmp1); //1+3 sumd
tmp2 = __lasx_xvsub_h(src0, src2); //0-2 delu
tmp3 = __lasx_xvsrai_h(src1, 1);
tmp3 = __lasx_xvsub_h(tmp3, src3); //1-3 deld

src0 = __lasx_xvadd_w(tmp0 ,tmp1); //0 4 8 12
src1 = __lasx_xvadd_w(tmp2, tmp3); //1 5 9 13
src2 = __lasx_xvsub_w(tmp2, tmp3); //2 6 10 14
src3 = __lasx_xvsub_w(tmp0, tmp1); //3 7 11 15
DUP4_ARG2(__lasx_xvpickev_h,
zero, src0,
zero, src1,
zero, src2,
zero, src3,
src0, src1, src2, src3);
src0 = __lasx_xvadd_h(tmp0 ,tmp1); //0 4 8 12
src1 = __lasx_xvadd_h(tmp2, tmp3); //1 5 9 13
src2 = __lasx_xvsub_h(tmp2, tmp3); //2 6 10 14
src3 = __lasx_xvsub_h(tmp0, tmp1); //3 7 11 15
//vertical
LASX_TRANSPOSE4x4_H(src0, src1, src2, src3,
tmp0, tmp1, tmp2, tmp3);
DUP4_ARG1(__lasx_vext2xv_w_h,
tmp0, tmp1, tmp2, tmp3,
src0, src1, src2, src3);
tmp0 = __lasx_xvadd_w(src0, src2); //suml
tmp1 = __lasx_xvsrai_w(src3, 1);
tmp1 = __lasx_xvadd_w(src1, tmp1); //sumr
tmp2 = __lasx_xvsub_w(src0, src2); //dell
tmp3 = __lasx_xvsrai_w(src1, 1);
tmp3 = __lasx_xvsub_w(tmp3, src3); //delr
src0, src1, src2, src3);
tmp0 = __lasx_xvadd_h(src0, src2); //suml
tmp1 = __lasx_xvsrai_h(src3, 1);
tmp1 = __lasx_xvadd_h(src1, tmp1); //sumr
tmp2 = __lasx_xvsub_h(src0, src2); //dell
tmp3 = __lasx_xvsrai_h(src1, 1);
tmp3 = __lasx_xvsub_h(tmp3, src3); //delr

dst0 = __lasx_xvadd_w(tmp0, tmp1);
dst1 = __lasx_xvadd_w(tmp2, tmp3);
dst2 = __lasx_xvsub_w(tmp2, tmp3);
dst3 = __lasx_xvsub_w(tmp0, tmp1);
DUP4_ARG2(__lasx_xvsrari_w,
dst0 = __lasx_xvadd_h(tmp0, tmp1);
dst1 = __lasx_xvadd_h(tmp2, tmp3);
dst2 = __lasx_xvsub_h(tmp2, tmp3);
dst3 = __lasx_xvsub_h(tmp0, tmp1);
DUP4_ARG2(__lasx_xvsrari_h,
dst0, 6,
dst1, 6,
dst2, 6,
dst3, 6,
dst0, dst1, dst2, dst3);
DUP4_ARG1(__lasx_vext2xv_wu_bu,
DUP4_ARG1(__lasx_vext2xv_hu_bu,
src4, src5, src6, src7,
tmp4, tmp5, tmp6, tmp7);
DUP4_ARG2(__lasx_xvadd_w,
DUP4_ARG2(__lasx_xvsadd_h,
tmp4, dst0,
tmp5, dst1,
tmp6, dst2,
tmp7, dst3,
dst0, dst1, dst2, dst3);
DUP4_ARG1(__lasx_xvclip255_w,
DUP4_ARG1(__lasx_xvclip255_h,
dst0, dst1, dst2, dst3,
dst0, dst1, dst2, dst3);
DUP2_ARG2(__lasx_xvpickev_h,
DUP2_ARG2(__lasx_xvpickev_b,
dst1, dst0,
dst3, dst2,
dst0, dst2);
dst0 = __lasx_xvpickev_b(dst2, dst0);
__lasx_xvstelm_w(dst0, pRec, 0, 0);
__lasx_xvstelm_w(dst0, pRec + iStride, 0, 1);
__lasx_xvstelm_w(dst0, pRec + iDstStride_x2, 0, 2);
__lasx_xvstelm_w(dst0, pRec + iDstStride_x3, 0, 3);
__lasx_xvstelm_w(dst0, pRec + iStride, 0, 2);
__lasx_xvstelm_w(dst2, pRec + iDstStride_x2, 0, 0);
__lasx_xvstelm_w(dst2, pRec + iDstStride_x3, 0, 2);
}

/*
 * Inverse 4x4 transform + reconstruction for an 8x8 region (four 4x4 blocks),
 * LASX-vectorized.
 *
 * pRec         output reconstructed samples, 8 rows of 8 bytes, stride iStride
 * iStride      byte stride of pRec rows
 * pPred        prediction samples, 8 rows read with stride iPredStride
 * iPredStride  byte stride of pPred rows
 * pDct         transform coefficients; four 32-byte loads consume 64 int16
 *              values (assumed to be four 4x4 coefficient blocks — layout
 *              matches the previous four WelsIDctT4Rec_lasx sub-calls)
 *
 * NOTE(review): the scraped diff concatenated the old implementation (four
 * WelsIDctT4Rec_lasx calls) with the new full-width SIMD body, so the function
 * did all the work twice and the second pass overwrote the first. This is the
 * reconstructed new implementation only; the stale delegating calls are removed.
 */
void WelsIDctFourT4Rec_lasx (uint8_t* pRec, int32_t iStride,
                             uint8_t* pPred, int32_t iPredStride,
                             int16_t* pDct) {
  __m256i src0, src1, src2, src3, src4, src5, src6, src7;
  __m256i sumu, delu, sumd, deld, SumL, DelL, DelR, SumR;
  __m256i vec0, vec1, vec2, vec3, vec4, vec5, vec6, vec7;
  __m256i tmp0;
  // Load all 64 coefficients (4 x 32 bytes) and split each 256-bit load into
  // its low half (srcN even) and high half (srcN odd) via xvpermi_q 0x31.
  DUP4_ARG2(__lasx_xvld,
            pDct, 0,
            pDct, 32,
            pDct, 64,
            pDct, 96,
            src0, src2, src4, src6);
  DUP4_ARG3(__lasx_xvpermi_q,
            src0, src0, 0x31,
            src2, src2, 0x31,
            src4, src4, 0x31,
            src6, src6, 0x31,
            src1, src3, src5, src7);
  // Horizontal pass: transpose, then the H.264 inverse-transform butterfly
  // (sum/diff of rows 0,2 and rows 1,3 with the >>1 terms) on each half.
  LASX_TRANSPOSE8x8_H(src0, src1, src2, src3, src4, src5, src6, src7,
                      src0, src1, src2, src3, src4, src5, src6, src7);
  sumu = __lasx_xvadd_h(src0, src2);
  delu = __lasx_xvsub_h(src0, src2);
  tmp0 = __lasx_xvsrai_h(src3, 1);
  sumd = __lasx_xvadd_h(src1, tmp0);
  tmp0 = __lasx_xvsrai_h(src1, 1);
  deld = __lasx_xvsub_h(tmp0, src3);
  src0 = __lasx_xvadd_h(sumu, sumd);
  src1 = __lasx_xvadd_h(delu, deld);
  src2 = __lasx_xvsub_h(delu, deld);
  src3 = __lasx_xvsub_h(sumu, sumd);
  sumu = __lasx_xvadd_h(src4, src6);
  delu = __lasx_xvsub_h(src4, src6);
  tmp0 = __lasx_xvsrai_h(src7, 1);
  sumd = __lasx_xvadd_h(src5, tmp0);
  tmp0 = __lasx_xvsrai_h(src5, 1);
  deld = __lasx_xvsub_h(tmp0, src7);
  src4 = __lasx_xvadd_h(sumu, sumd);
  src5 = __lasx_xvadd_h(delu, deld);
  src6 = __lasx_xvsub_h(delu, deld);
  src7 = __lasx_xvsub_h(sumu, sumd);
  // Vertical pass: transpose back, pair rows into 256-bit registers, and run
  // the same butterfly; xvbsrl_v by 8 aligns the "right" terms with the left.
  LASX_TRANSPOSE8x8_H(src0, src1, src2, src3, src4, src5, src6, src7,
                      src0, src1, src2, src3, src4, src5, src6, src7);
  src0 = __lasx_xvpermi_q(src2, src0, 0x20);
  src1 = __lasx_xvpermi_q(src3, src1, 0x20);
  src4 = __lasx_xvpermi_q(src6, src4, 0x20);
  src5 = __lasx_xvpermi_q(src7, src5, 0x20);
  SumL = __lasx_xvadd_h(src0, src1);
  DelL = __lasx_xvsub_h(src0, src1);
  tmp0 = __lasx_xvsrai_h(src0, 1);
  DelR = __lasx_xvsub_h(tmp0, src1);
  tmp0 = __lasx_xvsrai_h(src1, 1);
  SumR = __lasx_xvadd_h(src0, tmp0);
  SumR = __lasx_xvbsrl_v(SumR, 8);
  DelR = __lasx_xvbsrl_v(DelR, 8);
  src0 = __lasx_xvadd_h(SumL, SumR);
  src1 = __lasx_xvadd_h(DelL, DelR);
  src2 = __lasx_xvsub_h(DelL, DelR);
  src3 = __lasx_xvsub_h(SumL, SumR);
  SumL = __lasx_xvadd_h(src4, src5);
  DelL = __lasx_xvsub_h(src4, src5);
  tmp0 = __lasx_xvsrai_h(src4, 1);
  DelR = __lasx_xvsub_h(tmp0, src5);
  tmp0 = __lasx_xvsrai_h(src5, 1);
  SumR = __lasx_xvadd_h(src4, tmp0);
  SumR = __lasx_xvbsrl_v(SumR, 8);
  DelR = __lasx_xvbsrl_v(DelR, 8);
  src4 = __lasx_xvadd_h(SumL, SumR);
  src5 = __lasx_xvadd_h(DelL, DelR);
  src6 = __lasx_xvsub_h(DelL, DelR);
  src7 = __lasx_xvsub_h(SumL, SumR);
  // Rounding shift: (x + 32) >> 6 via xvsrari_h.
  DUP4_ARG2(__lasx_xvsrari_h,
            src0, 6,
            src1, 6,
            src2, 6,
            src3, 6,
            src0, src1, src2, src3);
  DUP4_ARG2(__lasx_xvsrari_h,
            src4, 6,
            src5, 6,
            src6, 6,
            src7, 6,
            src4, src5, src6, src7);
  // Reorder 64-bit lanes (0xd8 = 0,2,1,3) so rows line up with the
  // prediction layout produced by the widening loads below.
  DUP4_ARG2(__lasx_xvpermi_d,
            src0, 0xd8,
            src1, 0xd8,
            src2, 0xd8,
            src3, 0xd8,
            src0, src1, src2, src3);
  DUP4_ARG2(__lasx_xvpermi_d,
            src4, 0xd8,
            src5, 0xd8,
            src6, 0xd8,
            src7, 0xd8,
            src4, src5, src6, src7);
  // Load 8 prediction rows and widen u8 -> u16.
  DUP4_ARG2(__lasx_xvldx,
            pPred, iPredStride * 0,
            pPred, iPredStride,
            pPred, iPredStride * 2,
            pPred, iPredStride * 3,
            vec0, vec1, vec2, vec3);
  pPred += iPredStride * 4;
  DUP4_ARG2(__lasx_xvldx,
            pPred, iPredStride * 0,
            pPred, iPredStride,
            pPred, iPredStride * 2,
            pPred, iPredStride * 3,
            vec4, vec5, vec6, vec7);
  DUP4_ARG1(__lasx_vext2xv_hu_bu,
            vec0, vec1, vec2, vec3,
            vec0, vec1, vec2, vec3);
  DUP4_ARG1(__lasx_vext2xv_hu_bu,
            vec4, vec5, vec6, vec7,
            vec4, vec5, vec6, vec7);
  // Reconstruct: saturating add of prediction + residual, clip to [0,255],
  // then pack halfwords to bytes.
  DUP4_ARG2(__lasx_xvsadd_h,
            src0, vec0,
            src1, vec1,
            src2, vec2,
            src3, vec3,
            src0, src1, src2, src3);
  DUP4_ARG2(__lasx_xvsadd_h,
            src4, vec4,
            src5, vec5,
            src6, vec6,
            src7, vec7,
            src4, src5, src6, src7);
  DUP4_ARG1(__lasx_xvclip255_h,
            src0, src1, src2, src3,
            src0, src1, src2, src3);
  DUP4_ARG1(__lasx_xvclip255_h,
            src4, src5, src6, src7,
            src4, src5, src6, src7);
  DUP4_ARG2(__lasx_xvpickev_b,
            src1, src0, src3, src2,
            src5, src4, src7, src6,
            src0, src2, src4, src6);
  // Store 8 rows of 8 bytes each.
  __lasx_xvstelm_d(src0, pRec, 0, 0);
  __lasx_xvstelm_d(src0, pRec + iStride, 0, 1);
  __lasx_xvstelm_d(src2, pRec + iStride * 2, 0, 0);
  __lasx_xvstelm_d(src2, pRec + iStride * 3, 0, 1);
  pRec += iStride * 4;
  __lasx_xvstelm_d(src4, pRec, 0, 0);
  __lasx_xvstelm_d(src4, pRec + iStride, 0, 1);
  __lasx_xvstelm_d(src6, pRec + iStride * 2, 0, 0);
  __lasx_xvstelm_d(src6, pRec + iStride * 3, 0, 1);
}
3 changes: 0 additions & 3 deletions codec/encoder/core/loongarch/sample_lasx.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@
*
**********************************************************************************
*/

#include <stdint.h>
#include "sad_common.h"
#include "loongson_intrinsics.h"

void WelsIChromaPredV_lasx (uint8_t* pPred, uint8_t* pRef, const int32_t kiStride);
void WelsIChromaPredH_lasx (uint8_t* pPred, uint8_t* pRef, const int32_t kiStride);
Expand Down
1 change: 1 addition & 0 deletions codec/processing/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ elif cpu_family == 'aarch64'
elif cpu_family in ['loongarch32', 'loongarch64']
asm_sources = [
'src/loongarch/vaa_lsx.c',
'src/loongarch/vaa_lasx.c',
]
cpp_sources += asm_sources
else
Expand Down
Loading

0 comments on commit 9866066

Please sign in to comment.