# Generate INT8 Code

In [22]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin
N_UNROLL = 16
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    print("            __m256 res"+str(sn)+" = _mm256_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8)+"));") 
for sn in range(N_SIMD): 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8 + 8 * N_SIMD)+"));")
for sn in range(N_SIMD): 
    print("                __m256 pos"+str(sn)+" = _mm256_i32gather_ps(X, indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");") 
for sn in range(N_SIMD):  
    print("                __m256 neg"+str(sn)+" = _mm256_i32gather_ps(X, indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm256_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
if(UNROLL_REMAIN):
    for sn in range(N_UNROLL): 
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++)")
        print("                    Ybase["+str(sn)+"] += X[row_index[k]];")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++)")
        print("                    Ybase["+str(sn)+"] -= X[row_index[k]];")
else:
        print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
        print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++)")
        print("                    Ybase[g] += X[row_index[k]];")
        print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++)")
        print("                    Ybase[g] -= X[row_index[k]];")
        print("            }")
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G16_AVX2_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 16; j++) {
            int* groupData = &metadata[j * 34]; 
            __m256 res0 = _mm256_setzero_ps();
            __m256 res1 = _mm256_setzero_ps();
            for (int k = groupData[0]; k < groupData[1]; k += 32) {
                __m256i indices_p0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 0));
                __m256i indices_p1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 8));
                __m256i indices_n0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 16));
                __m256i indices_n1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 24));
                __m256 pos0 = _mm256_i32gather_ps(X, indices_p0, 4);
                __m256 pos1 = _mm256_i32gather_ps(X, i

In [21]:
# Generate INT8 GEMM Code: Uniform and General, AVX2 and AVX-512
M_UNROLL = 64
N_UNROLL = 4
X_DATA_BYTES = 1

UNIFORM=True
AVX512=True
UNROLL_REMAIN = False

if(UNIFORM):
    TCSC_FORMAT = "Uniform"
    INPUT_META = "const int32_t NonZeroPerCol"
    K_START = "j * "+str(N_UNROLL)+" * NonZeroPerCol"
    K_END = "(j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol"
else:
    TCSC_FORMAT = "Merged_GroupMin"
    INPUT_META = "const int32_t* metadata"
    K_START = "groupData[0]"
    K_END = "groupData[1]"

if(AVX512):
    SIMD_SIZE = 64
    AVX_NAME = "AVX512"
    AVX = "__m512i"
    SET_R = "_mm512_setzero_si512"
    LOAD_W = "_mm512_load_si512(reinterpret_cast<const __m512i*>"
    LOAD_X = "_mm512_load_si512(reinterpret_cast<const __m512i*>"
    SUB_X = "_mm512_sub_epi8"
    ADD_X = "_mm512_add_epi8"
    STORE_Y = "_mm512_store_si512(reinterpret_cast<__m512i*>"
else:
    SIMD_SIZE = 32
    AVX_NAME = "AVX2"
    AVX = "__m256i"
    SET_R = "_mm256_setzero_si256"
    LOAD_W = "_mm256_load_si256(reinterpret_cast<const __m256i*>"
    LOAD_X = "_mm256_load_si256(reinterpret_cast<const __m256i*>"
    SUB_X = "_mm256_sub_epi8"
    ADD_X = "_mm256_add_epi8"
    STORE_Y = "_mm256_store_si256(reinterpret_cast<__m256i*>"
    
M_SIMD = int(M_UNROLL/SIMD_SIZE)

function_name = "void GEMM_CPU_INT8_colMajor_TCSC_"+TCSC_FORMAT+"_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_"+AVX_NAME+"_OpenMP"
function_name +="(const int8_t* X, "+INPUT_META+", const int16_t* row_index, int8_t* result, const int M_ROW, const int N_COL, const int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {") 
if(not UNIFORM):
    print("        const int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
for sn in range(N_UNROLL):
    for sm in range(M_SIMD):
        print("            "+AVX+" res"+str(sn)+str(sm)+" = "+SET_R+"();")    
print("            for (int k = "+K_START+"; k < "+K_END+"; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_UNROLL): 
    for sm in range(M_SIMD):  
        print("                "+AVX+" pos"+str(sn)+str(sm)+" = "+LOAD_X+"(X + row_index[k + "+str(sn)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")   
for sn in range(N_UNROLL): 
    for sm in range(M_SIMD):
        print("                "+AVX+" neg"+str(sn)+str(sm)+" = "+LOAD_X+"(X + row_index[k + "+str(sn+N_UNROLL)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")
for sn in range(N_UNROLL): 
    for sm in range(M_SIMD):
        idx = str(sn)+str(sm)
        print("                res"+idx+" = "+ADD_X+"(res"+idx+", "+SUB_X+"(pos"+idx+", neg"+idx+"));")
print("            }")

if(not UNIFORM):
    for sn in range(N_UNROLL):
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++) {")
        for sm in range(M_SIMD):
            print("                    res"+str(sn)+str(sm)+" = "+ADD_X+"(res"+str(sn)+str(sm)+", "+LOAD_X+"(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+")));")
        print("                }")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++) {")
        for sm in range(M_SIMD):
            print("                    res"+str(sn)+str(sm)+" = "+SUB_X+"(res"+str(sn)+str(sm)+", "+LOAD_X+"(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+")));")
        print("                }")

for sn in range(N_UNROLL): 
    for sm in range(M_SIMD):     
       print("            "+STORE_Y+"(result + (j * "+str(N_UNROLL)+" + "+str(sn)+") * M_ROW  + i + "+str(sm*SIMD_SIZE)+"), res"+str(sn)+str(sm)+");") 

print("        }")
print("    }")
print("}")

void GEMM_CPU_INT8_colMajor_TCSC_Uniform_64xG4_AVX512_OpenMP(const int8_t* X, const int32_t NonZeroPerCol, const int16_t* row_index, int8_t* result, const int M_ROW, const int N_COL, const int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 4; j++) {
        for (int i = 0; i < M_ROW; i +=64) {
            __m512i res00 = _mm512_setzero_si512();
            __m512i res10 = _mm512_setzero_si512();
            __m512i res20 = _mm512_setzero_si512();
            __m512i res30 = _mm512_setzero_si512();
            for (int k = j * 4 * NonZeroPerCol; k < (j + 1) * 4 * NonZeroPerCol; k += 8) {
                __m512i pos00 = _mm512_load_si512(reinterpret_cast<const __m512i*>(X + row_index[k + 0] * M_ROW + i + 0));
                __m512i pos10 = _mm512_load_si512(reinterpret_cast<const __m512i*>(X + row_index[k + 1] * M_ROW + i + 0));
                __m512i pos20 = _mm512_load_si512(reinterpret_cast<const __m512i*>(X + row_index[k + 2] * M_ROW + i + 0));
                __m

# Generate SiLU + Pointwise Mul

In [2]:
# SiLU
UNROLL_SIZE = 16
for un in range(UNROLL_SIZE):
    print("        T a"+str(un)+" = X[i + "+str(un)+"];")
for un in range(UNROLL_SIZE):
    print("        T b"+str(un)+" = XB[i + "+str(un)+"];")    
for un in range(UNROLL_SIZE):
    print("        T d"+str(un)+" = 1 + std::exp( - a"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        a"+str(un)+" = a"+str(un)+" * b"+str(un)+" /  d"+str(un)+";")
for un in range(UNROLL_SIZE):
    print("        X[i + "+str(un)+"] = a"+str(un)+";")

        T a0 = X[i + 0];
        T a1 = X[i + 1];
        T a2 = X[i + 2];
        T a3 = X[i + 3];
        T a4 = X[i + 4];
        T a5 = X[i + 5];
        T a6 = X[i + 6];
        T a7 = X[i + 7];
        T a8 = X[i + 8];
        T a9 = X[i + 9];
        T a10 = X[i + 10];
        T a11 = X[i + 11];
        T a12 = X[i + 12];
        T a13 = X[i + 13];
        T a14 = X[i + 14];
        T a15 = X[i + 15];
        T b0 = XB[i + 0];
        T b1 = XB[i + 1];
        T b2 = XB[i + 2];
        T b3 = XB[i + 3];
        T b4 = XB[i + 4];
        T b5 = XB[i + 5];
        T b6 = XB[i + 6];
        T b7 = XB[i + 7];
        T b8 = XB[i + 8];
        T b9 = XB[i + 9];
        T b10 = XB[i + 10];
        T b11 = XB[i + 11];
        T b12 = XB[i + 12];
        T b13 = XB[i + 13];
        T b14 = XB[i + 14];
        T b15 = XB[i + 15];
        T d0 = 1 + std::exp( - a0);
        T d1 = 1 + std::exp( - a1);
        T d2 = 1 + std::exp( - a2);
        T d3 = 1 + std::exp( - a3);
        T d4 = 1

In [5]:
UNROLL_SIZE = 16
for un in range(UNROLL_SIZE):
    print("        T a"+str(un)+" = X[i + j + "+str(un)+"];")
for un in range(UNROLL_SIZE):
    print("        T b"+str(un)+" = XB[i + j + "+str(un)+"];")    
for un in range(UNROLL_SIZE):
    print("        T d"+str(un)+" = std::exp( - a"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = d"+str(un)+" + 1;")
for un in range(UNROLL_SIZE):
    print("        a"+str(un)+" = a"+str(un)+" * b"+str(un)+" /  d"+str(un)+";")
for un in range(UNROLL_SIZE):
    print("        X[i + j + "+str(un)+"] = a"+str(un)+";")

        T a0 = X[i + j + 0];
        T a1 = X[i + j + 1];
        T a2 = X[i + j + 2];
        T a3 = X[i + j + 3];
        T a4 = X[i + j + 4];
        T a5 = X[i + j + 5];
        T a6 = X[i + j + 6];
        T a7 = X[i + j + 7];
        T a8 = X[i + j + 8];
        T a9 = X[i + j + 9];
        T a10 = X[i + j + 10];
        T a11 = X[i + j + 11];
        T a12 = X[i + j + 12];
        T a13 = X[i + j + 13];
        T a14 = X[i + j + 14];
        T a15 = X[i + j + 15];
        T b0 = XB[i + j + 0];
        T b1 = XB[i + j + 1];
        T b2 = XB[i + j + 2];
        T b3 = XB[i + j + 3];
        T b4 = XB[i + j + 4];
        T b5 = XB[i + j + 5];
        T b6 = XB[i + j + 6];
        T b7 = XB[i + j + 7];
        T b8 = XB[i + j + 8];
        T b9 = XB[i + j + 9];
        T b10 = XB[i + j + 10];
        T b11 = XB[i + j + 11];
        T b12 = XB[i + j + 12];
        T b13 = XB[i + j + 13];
        T b14 = XB[i + j + 14];
        T b15 = XB[i + j + 15];
        T d0 = std::exp( - a0);


In [1]:
SIMD_SIZE = 8
UNROLL_SIZE = 8
for un in range(UNROLL_SIZE):
    print("        __m256 a"+str(un)+" = _mm256_load_ps(&X[i + "+str(un * SIMD_SIZE)+"]);")
for un in range(UNROLL_SIZE):
    print("        __m256 b"+str(un)+" = _mm256_load_ps(&XB[i + "+str(un * SIMD_SIZE)+"]);")    
for un in range(UNROLL_SIZE):
    print("        __m256 d"+str(un)+" = _mm256_sub_ps(zeros, a"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        a"+str(un)+" = _mm256_mul_ps(a"+str(un)+", b"+str(un)+");")    
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm256_exp_ps( d"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm256_add_ps(d"+str(un)+", ones);")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm256_rcp_ps(d"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm256_mul_ps(a"+str(un)+", d"+str(un)+");")  
for un in range(UNROLL_SIZE):
    print("        _mm256_store_ps(&X[i + "+str(un * SIMD_SIZE)+"], d"+str(un)+");")

        __m256 a0 = _mm256_load_ps(&X[i + 0]);
        __m256 a1 = _mm256_load_ps(&X[i + 8]);
        __m256 a2 = _mm256_load_ps(&X[i + 16]);
        __m256 a3 = _mm256_load_ps(&X[i + 24]);
        __m256 a4 = _mm256_load_ps(&X[i + 32]);
        __m256 a5 = _mm256_load_ps(&X[i + 40]);
        __m256 a6 = _mm256_load_ps(&X[i + 48]);
        __m256 a7 = _mm256_load_ps(&X[i + 56]);
        __m256 b0 = _mm256_load_ps(&XB[i + 0]);
        __m256 b1 = _mm256_load_ps(&XB[i + 8]);
        __m256 b2 = _mm256_load_ps(&XB[i + 16]);
        __m256 b3 = _mm256_load_ps(&XB[i + 24]);
        __m256 b4 = _mm256_load_ps(&XB[i + 32]);
        __m256 b5 = _mm256_load_ps(&XB[i + 40]);
        __m256 b6 = _mm256_load_ps(&XB[i + 48]);
        __m256 b7 = _mm256_load_ps(&XB[i + 56]);
        __m256 d0 = _mm256_sub_ps(zeros, a0);
        __m256 d1 = _mm256_sub_ps(zeros, a1);
        __m256 d2 = _mm256_sub_ps(zeros, a2);
        __m256 d3 = _mm256_sub_ps(zeros, a3);
        __m256 d4 = _mm256_sub_ps(zeros, a4)

In [2]:
SIMD_SIZE = 16
UNROLL_SIZE = 8
for un in range(UNROLL_SIZE):
    print("        __m512 a"+str(un)+" = _mm512_load_ps(&X[i + "+str(un * SIMD_SIZE)+"]);")
for un in range(UNROLL_SIZE):
    print("        __m512 b"+str(un)+" = _mm512_load_ps(&XB[i + "+str(un * SIMD_SIZE)+"]);")    
for un in range(UNROLL_SIZE):
    print("        __m512 d"+str(un)+" = _mm512_sub_ps(zeros, a"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        a"+str(un)+" = _mm512_mul_ps(a"+str(un)+", b"+str(un)+");")    
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm512_exp_ps( d"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm512_add_ps(d"+str(un)+", ones);")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm512_rcp_ps(d"+str(un)+");")
for un in range(UNROLL_SIZE):
    print("        d"+str(un)+" = _mm512_mul_ps(a"+str(un)+", d"+str(un)+");")  
for un in range(UNROLL_SIZE):
    print("        _mm512_store_ps(&X[i + "+str(un * SIMD_SIZE)+"], d"+str(un)+");")

        __m512 a0 = _mm512_load_ps(&X[i + 0]);
        __m512 a1 = _mm512_load_ps(&X[i + 16]);
        __m512 a2 = _mm512_load_ps(&X[i + 32]);
        __m512 a3 = _mm512_load_ps(&X[i + 48]);
        __m512 a4 = _mm512_load_ps(&X[i + 64]);
        __m512 a5 = _mm512_load_ps(&X[i + 80]);
        __m512 a6 = _mm512_load_ps(&X[i + 96]);
        __m512 a7 = _mm512_load_ps(&X[i + 112]);
        __m512 b0 = _mm512_load_ps(&XB[i + 0]);
        __m512 b1 = _mm512_load_ps(&XB[i + 16]);
        __m512 b2 = _mm512_load_ps(&XB[i + 32]);
        __m512 b3 = _mm512_load_ps(&XB[i + 48]);
        __m512 b4 = _mm512_load_ps(&XB[i + 64]);
        __m512 b5 = _mm512_load_ps(&XB[i + 80]);
        __m512 b6 = _mm512_load_ps(&XB[i + 96]);
        __m512 b7 = _mm512_load_ps(&XB[i + 112]);
        __m512 d0 = _mm512_sub_ps(zeros, a0);
        __m512 d1 = _mm512_sub_ps(zeros, a1);
        __m512 d2 = _mm512_sub_ps(zeros, a2);
        __m512 d3 = _mm512_sub_ps(zeros, a3);
        __m512 d4 = _mm512_sub_ps(zeros,

# SIMD Code Generator

In [5]:
# Generate GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin
M_UNROLL = 2
N_UNROLL = 64
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
print("            float* Xbase = X + i * K;")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    for sm in range(M_UNROLL):
        print("            __m256 res"+str(sn)+str(sm)+" = _mm256_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8)+"));") 
for sn in range(N_SIMD): 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8 + 8*N_SIMD)+"));")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        print("                __m256 pos"+str(sn)+str(sm)+" = _mm256_i32gather_ps(Xbase + K * "+str(sm)+", indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        print("                __m256 neg"+str(sn)+str(sm)+" = _mm256_i32gather_ps(Xbase + K * "+str(sm)+", indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + i * N_COL + j * "+str(N_UNROLL)+";")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
       print("            _mm256_store_ps(Ybase + N_COL * "+str(sm)+" + "+str(SIMD_SIZE*sn)+", res"+str(sn)+str(sm)+");") 


if(UNROLL_REMAIN):
    for sn in range(N_UNROLL):
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++) {")
        for sm in range(M_UNROLL):
            print("                    Ybase[N_COL * "+str(sm)+" + "+str(sn)+"] += Xbase[K * "+str(sm)+" + row_index[k]];")
        print("                }")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++) {")
        for sm in range(M_UNROLL):
            print("                    Ybase[N_COL * "+str(sm)+" + "+str(sn)+"] -= Xbase[K * "+str(sm)+" + row_index[k]];")
        print("                }")
else:
    print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
    print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++) {")
    for sm in range(M_UNROLL):
        print("                    Ybase[N_COL * "+str(sm)+" + g] += Xbase[K * "+str(sm)+" + row_index[k]];")
    print("                }")
    print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++) {")
    for sm in range(M_UNROLL):
        print("                    Ybase[N_COL * "+str(sm)+" + g] -= Xbase[K * "+str(sm)+" + row_index[k]];")
    print("                }")
    print("            }")
print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_2xG64_AVX2_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 64; j++) {
        for (int i = 0; i < M_ROW; i +=2) {
            float* Xbase = X + i * K;
            int* groupData = &metadata[j * 130]; 
            __m256 res00 = _mm256_setzero_ps();
            __m256 res01 = _mm256_setzero_ps();
            __m256 res10 = _mm256_setzero_ps();
            __m256 res11 = _mm256_setzero_ps();
            __m256 res20 = _mm256_setzero_ps();
            __m256 res21 = _mm256_setzero_ps();
            __m256 res30 = _mm256_setzero_ps();
            __m256 res31 = _mm256_setzero_ps();
            __m256 res40 = _mm256_setzero_ps();
            __m256 res41 = _mm256_setzero_ps();
            __m256 res50 = _mm256_setzero_ps();
            __m256 res51 = _mm256_setzero_ps();
            __m256 res60 = _mm256_setzero_ps();
           

In [None]:
void GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_1xG8_AVX2_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int M_ROW, int N_COL, int K) {
    for (int i = 0; i < M_ROW; i ++) {
    #pragma omp parallel for
        for (int j = 0; j < N_COL / 8; j++) {
            /* Pointer to where a column starts and ends
            int align_start  = metadata[j * 10 + 0];
            int align_end    = metadata[j * 10 + 1];
            int +remain_end0 = metadata[j * 10 + 2];
            int -remain_end0 = metadata[j * 10 + 3];
            int +remain_end1 = metadata[j * 10 + 4];
            int -remain_end1 = metadata[j * 10 + 5];
            ...
            */
            // Group # = j, metadata per group = 18
            int* groupData = &metadata[j * 18];        
            __m256 res0 = _mm256_setzero_ps();
            float* Xbase = X + i * K;
            float* Ybase = result + i * N_COL + j * 8;
            for (int k = groupData[0]; k < groupData[1]; k += 16) {
                __m256i indices_p = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 0));
                __m256i indices_n = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 8));
                //__m128i indices_p = _mm128_load_si128(reinterpret_cast<const __m128i*>(row_index + k + 0));
                //__m128i indices_n = _mm128_load_si128(reinterpret_cast<const __m128i*>(row_index + k + 8));
                __m256 pos0 = _mm256_i32gather_ps(Xbase, indices_p, 4);
                __m256 neg0 = _mm256_i32gather_ps(Xbase, indices_n, 4);
                res0 = _mm256_add_ps(res0, _mm256_sub_ps(pos0, neg0));
            }
            _mm256_store_ps(Ybase, res0);

            #pragma unroll(8)
            for (int g = 0; g < 8; g++) {
                float pos00 = 0;
                for (int k = groupData[2 * g + 1]; k < groupData[2 * g+2]; k++) {
                    pos00 += Xbase[row_index[k]];
                }
                Ybase[g]+=pos00;
                float neg00 = 0;
                for (int k = groupData[2 * g+2]; k < groupData[2 * g+3]; k++) {
                    neg00 += Xbase[row_index[k]];
                }
                Ybase[g] -= neg00;
            }          
        }
    }
}


In [11]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin
N_UNROLL = 16
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    print("            __m256 res"+str(sn)+" = _mm256_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8)+"));") 
for sn in range(N_SIMD): 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8 + 8 * N_SIMD)+"));")
for sn in range(N_SIMD): 
    print("                __m256 pos"+str(sn)+" = _mm256_i32gather_ps(X, indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");") 
for sn in range(N_SIMD):  
    print("                __m256 neg"+str(sn)+" = _mm256_i32gather_ps(X, indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm256_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
if(UNROLL_REMAIN):
    for sn in range(N_UNROLL): 
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++)")
        print("                    Ybase["+str(sn)+"] += X[row_index[k]];")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++)")
        print("                    Ybase["+str(sn)+"] -= X[row_index[k]];")
else:
        print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
        print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++)")
        print("                    Ybase[g] += X[row_index[k]];")
        print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++)")
        print("                    Ybase[g] -= X[row_index[k]];")
        print("            }")
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G16_AVX2_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 16; j++) {
            int* groupData = &metadata[j * 34]; 
            __m256 res0 = _mm256_setzero_ps();
            __m256 res1 = _mm256_setzero_ps();
            for (int k = groupData[0]; k < groupData[1]; k += 32) {
                __m256i indices_p0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 0));
                __m256i indices_p1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 8));
                __m256i indices_n0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 16));
                __m256i indices_n1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 24));
                __m256 pos0 = _mm256_i32gather_ps(X, indices_p0, 4);
                __m256 pos1 = _mm256_i32gather_ps(X, i

In [33]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Uniform
N_UNROLL = 64
X_DATA_BYTES = 4
SIMD_SIZE = 8
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
for sn in range(N_SIMD):
    print("            __m256 res"+str(sn)+" = _mm256_setzero_ps();")
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8)+"));") 
for sn in range(N_SIMD): 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * 8 + 8 * N_SIMD)+"));")
for sn in range(N_SIMD): 
    print("                __m256 pos"+str(sn)+" = _mm256_i32gather_ps(X, indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");") 
for sn in range(N_SIMD):  
    print("                __m256 neg"+str(sn)+" = _mm256_i32gather_ps(X, indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm256_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G64_AVX2_OpenMP(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 64; j++) {
            __m256 res0 = _mm256_setzero_ps();
            __m256 res1 = _mm256_setzero_ps();
            __m256 res2 = _mm256_setzero_ps();
            __m256 res3 = _mm256_setzero_ps();
            __m256 res4 = _mm256_setzero_ps();
            __m256 res5 = _mm256_setzero_ps();
            __m256 res6 = _mm256_setzero_ps();
            __m256 res7 = _mm256_setzero_ps();
            for (int k = j * 64 * NonZeroPerCol; k < (j + 1) * 64 * NonZeroPerCol; k += 128) {
                __m256i indices_p0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 0));
                __m256i indices_p1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + 8));
                __m256i indices_p2 = _mm256_load_si256(reinterpret_cast<const __m256i*>

In [23]:
# Generate GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin
M_UNROLL = 64
N_UNROLL = 4
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = True
M_SIMD = int(M_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t* metadata, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("        int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
for sn in range(N_UNROLL):
    for sm in range(M_SIMD):
        print("            __m256 res"+str(sn)+str(sm)+" = _mm256_setzero_ps();")    
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        print("                __m256 pos"+str(sn)+str(sm)+" = _mm256_load_ps(X + row_index[k + "+str(sn*2+0)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")   
        print("                __m256 neg"+str(sn)+str(sm)+" = _mm256_load_ps(X + row_index[k + "+str(sn*2+1)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")

for sn in range(N_UNROLL):
    print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++) {")
    for sm in range(M_SIMD):
        print("                    res"+str(sn)+str(sm)+" = _mm256_add_ps(res"+str(sn)+str(sm)+", _mm256_load_ps(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")
    print("                }")
    print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++) {")
    for sm in range(M_SIMD):
        print("                    res"+str(sn)+str(sm)+" = _mm256_sub_ps(res"+str(sn)+str(sm)+", _mm256_load_ps(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")
    print("                }")
    
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
       print("            _mm256_store_ps(result + (j * "+str(N_UNROLL)+" + "+str(sn)+") * M_ROW  + i + "+str(sm*SIMD_SIZE)+", res"+str(sn)+str(sm)+");") 

print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin_64xG4_AVX2_OpenMP(float* X, int32_t* metadata, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 4; j++) {
        int* groupData = &metadata[j * 10]; 
        for (int i = 0; i < M_ROW; i +=64) {
            __m256 res00 = _mm256_setzero_ps();
            __m256 res01 = _mm256_setzero_ps();
            __m256 res02 = _mm256_setzero_ps();
            __m256 res03 = _mm256_setzero_ps();
            __m256 res04 = _mm256_setzero_ps();
            __m256 res05 = _mm256_setzero_ps();
            __m256 res06 = _mm256_setzero_ps();
            __m256 res07 = _mm256_setzero_ps();
            __m256 res10 = _mm256_setzero_ps();
            __m256 res11 = _mm256_setzero_ps();
            __m256 res12 = _mm256_setzero_ps();
            __m256 res13 = _mm256_setzero_ps();
            __m256 res14 = _mm256_setzero_ps();
            __m256 res15 = _mm256_setzero_ps();
      

In [None]:
void GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin_8xG4_AVX2_OpenMP(float* X, int32_t* metadata, int16_t* row_index, float* result, int M_ROW, int N_COL, int K) {
#pragma omp parallel for
    for (int j = 0; j < N_COL/4; j ++) {
        /* Pointer to where a column starts and ends
        int align_start  = metadata[j * 10 + 0];
        int align_end    = metadata[j * 10 + 1];
        int +remain_end0 = metadata[j * 10 + 2];
        int -remain_end0 = metadata[j * 10 + 3];
        int +remain_end1 = metadata[j * 10 + 4];
        int -remain_end1 = metadata[j * 10 + 5];
        ...
        */
        // Group # = j, metadata per group = 10
        int * groupData = &metadata[j * 10];

        for (int i = 0; i < M_ROW; i += 8) {
            __m256 res0 = _mm256_set1_ps(0.0f);
            __m256 res1 = _mm256_set1_ps(0.0f);
            __m256 res2 = _mm256_set1_ps(0.0f);
            __m256 res3 = _mm256_set1_ps(0.0f);

            for (int k = groupData[0]; k < groupData[1]; k += 8) {
                __m256 pos0 = _mm256_load_ps(&X[row_index[k + 0] * M_ROW + i]);
                __m256 neg0 = _mm256_load_ps(&X[row_index[k + 1] * M_ROW + i]);
                __m256 pos1 = _mm256_load_ps(&X[row_index[k + 2] * M_ROW + i]);
                __m256 neg1 = _mm256_load_ps(&X[row_index[k + 3] * M_ROW + i]);
                __m256 pos2 = _mm256_load_ps(&X[row_index[k + 4] * M_ROW + i]);
                __m256 neg2 = _mm256_load_ps(&X[row_index[k + 5] * M_ROW + i]);
                __m256 pos3 = _mm256_load_ps(&X[row_index[k + 6] * M_ROW + i]);
                __m256 neg3 = _mm256_load_ps(&X[row_index[k + 7] * M_ROW + i]);
                res0 = _mm256_add_ps(res0, _mm256_sub_ps(pos0, neg0));
                res1 = _mm256_add_ps(res1, _mm256_sub_ps(pos1, neg1));
                res2 = _mm256_add_ps(res2, _mm256_sub_ps(pos2, neg2));
                res3 = _mm256_add_ps(res3, _mm256_sub_ps(pos3, neg3));
            }

            __m256 pos0 = _mm256_set1_ps(0.0f);
            for (int k = groupData[1]; k < groupData[2]; k++) {
                pos0 = _mm256_add_ps(pos0, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res0 = _mm256_add_ps(res0, pos0);
            __m256 neg0 = _mm256_set1_ps(0.0f);
            for (int k = groupData[2]; k < groupData[3]; k++) {
                neg0 = _mm256_add_ps(neg0, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res0 = _mm256_sub_ps(res0, neg0);

            __m256 pos1 = _mm256_set1_ps(0.0f);
            for (int k = groupData[3]; k < groupData[4]; k++) {
                pos1 = _mm256_add_ps(pos1, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res1 = _mm256_add_ps(res1, pos1);
            __m256 neg1 = _mm256_set1_ps(0.0f);
            for (int k = groupData[4]; k < groupData[5]; k++) {
                neg1 = _mm256_add_ps(neg1, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res1 = _mm256_sub_ps(res1, neg1);

            __m256 pos2 = _mm256_set1_ps(0.0f);
            for (int k = groupData[5]; k < groupData[6]; k++) {
                pos2 = _mm256_add_ps(pos2, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res2 = _mm256_add_ps(res2, pos2);
            __m256 neg2 = _mm256_set1_ps(0.0f);
            for (int k = groupData[6]; k < groupData[7]; k++) {
                neg2 = _mm256_add_ps(neg2, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res2 = _mm256_sub_ps(res2, neg2);

            __m256 pos3 = _mm256_set1_ps(0.0f);
            for (int k = groupData[7]; k < groupData[8]; k++) {
                pos3 = _mm256_add_ps(pos3, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res3 = _mm256_add_ps(res3, pos3);
            __m256 neg3 = _mm256_set1_ps(0.0f);
            for (int k = groupData[8]; k < groupData[9]; k++) {
                neg3 = _mm256_add_ps(neg3, _mm256_load_ps(&X[row_index[k] * M_ROW + i]));
            }
            res3 = _mm256_sub_ps(res3, neg3);

            _mm256_store_ps(&result[(j * 4 + 0) * M_ROW + i], res0);
            _mm256_store_ps(&result[(j * 4 + 1) * M_ROW + i], res1);
            _mm256_store_ps(&result[(j * 4 + 2) * M_ROW + i], res2);
            _mm256_store_ps(&result[(j * 4 + 3) * M_ROW + i], res3);
        }
    }
}

In [17]:
# Generate Uniform TCSC colMajor
M_UNROLL = 32
N_UNROLL = 32
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = True
M_SIMD = int(M_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_colMajor_TCSC_Uniform_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {") 
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
for sn in range(N_UNROLL):
    for sm in range(M_SIMD):
        print("            __m256 res"+str(sn)+str(sm)+" = _mm256_setzero_ps();")    
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        print("                __m256 pos"+str(sn)+str(sm)+" = _mm256_load_ps(X + row_index[k + "+str(sn*2+0)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")   
        print("                __m256 neg"+str(sn)+str(sm)+" = _mm256_load_ps(X + row_index[k + "+str(sn*2+1)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
       print("            _mm256_store_ps(result + (j * "+str(N_UNROLL)+" + "+str(sn)+") * M_ROW  + i + "+str(sm*SIMD_SIZE)+", res"+str(sn)+str(sm)+");") 

print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_colMajor_TCSC_Uniform_32xG32_AVX2_OpenMP(float* X, int32_t NonZeroPerCol, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 32; j++) {
        for (int i = 0; i < M_ROW; i +=32) {
            __m256 res00 = _mm256_setzero_ps();
            __m256 res01 = _mm256_setzero_ps();
            __m256 res02 = _mm256_setzero_ps();
            __m256 res03 = _mm256_setzero_ps();
            __m256 res10 = _mm256_setzero_ps();
            __m256 res11 = _mm256_setzero_ps();
            __m256 res12 = _mm256_setzero_ps();
            __m256 res13 = _mm256_setzero_ps();
            __m256 res20 = _mm256_setzero_ps();
            __m256 res21 = _mm256_setzero_ps();
            __m256 res22 = _mm256_setzero_ps();
            __m256 res23 = _mm256_setzero_ps();
            __m256 res30 = _mm256_setzero_ps();
            __m256 res31 = _mm256_setzero_ps();
            __m256 res32 = _mm256_setzero_ps();
     

In [None]:
void GEMM_CPU_FP32_colMajor_TCSC_Uniform_8xG4_AVX2_OpenMP(float* X, int32_t NonZeroPerCol, int16_t* row_index, float* result, int M_ROW, int N_COL, int K) {
#pragma omp parallel for
    for (int j = 0; j < N_COL/4; j ++) {
        /* Pointer to where a column starts and ends
        int align_start  = metadata[j + 0];
        int align_end    = metadata[j + 1];
        */
        // Group # = j, metadata per group = 2
        for (int i = 0; i < M_ROW; i += 8) {
            __m256 res0 = _mm256_set1_ps(0.0f);
            __m256 res1 = _mm256_set1_ps(0.0f);
            __m256 res2 = _mm256_set1_ps(0.0f);
            __m256 res3 = _mm256_set1_ps(0.0f);
            for (int k = j * 4 * NonZeroPerCol; k < (j * 4 + 4)*NonZeroPerCol; k += 8) {
                __m256 pos0 = _mm256_load_ps(&X[row_index[k + 0] * M_ROW + i]);
                __m256 neg0 = _mm256_load_ps(&X[row_index[k + 1] * M_ROW + i]);
                __m256 pos1 = _mm256_load_ps(&X[row_index[k + 2] * M_ROW + i]);
                __m256 neg1 = _mm256_load_ps(&X[row_index[k + 3] * M_ROW + i]);
                __m256 pos2 = _mm256_load_ps(&X[row_index[k + 4] * M_ROW + i]);
                __m256 neg2 = _mm256_load_ps(&X[row_index[k + 5] * M_ROW + i]);
                __m256 pos3 = _mm256_load_ps(&X[row_index[k + 6] * M_ROW + i]);
                __m256 neg3 = _mm256_load_ps(&X[row_index[k + 7] * M_ROW + i]);
                res0 = _mm256_add_ps(res0, _mm256_sub_ps(pos0, neg0));
                res1 = _mm256_add_ps(res1, _mm256_sub_ps(pos1, neg1));
                res2 = _mm256_add_ps(res2, _mm256_sub_ps(pos2, neg2));
                res3 = _mm256_add_ps(res3, _mm256_sub_ps(pos3, neg3));
            }
            _mm256_store_ps(&result[(j * 4 + 0) * M_ROW + i], res0);
            _mm256_store_ps(&result[(j * 4 + 1) * M_ROW + i], res1);
            _mm256_store_ps(&result[(j * 4 + 2) * M_ROW + i], res2);
            _mm256_store_ps(&result[(j * 4 + 3) * M_ROW + i], res3);
        }
    }
}


# AVX-512 Generator

In [20]:
# Generate GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin - AVX-512
M_UNROLL = 32
N_UNROLL = 32
X_DATA_BYTES = 4
SIMD_SIZE = 16
UNROLL_REMAIN = True
M_SIMD = int(M_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t* metadata, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("        int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
for sn in range(N_UNROLL):
    for sm in range(M_SIMD):
        print("            __m512 res"+str(sn)+str(sm)+" = _mm512_setzero_ps();")    
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        print("                __m512 pos"+str(sn)+str(sm)+" = _mm512_load_ps(X + row_index[k + "+str(sn*2+0)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")   
        print("                __m512 neg"+str(sn)+str(sm)+" = _mm512_load_ps(X + row_index[k + "+str(sn*2+1)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")

for sn in range(N_UNROLL):
    print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++) {")
    for sm in range(M_SIMD):
        print("                    res"+str(sn)+str(sm)+" = _mm512_add_ps(res"+str(sn)+str(sm)+", _mm512_load_ps(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")
    print("                }")
    print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++) {")
    for sm in range(M_SIMD):
        print("                    res"+str(sn)+str(sm)+" = _mm512_sub_ps(res"+str(sn)+str(sm)+", _mm512_load_ps(X + row_index[k] * M_ROW + i + "+str(sm*SIMD_SIZE)+"));")
    print("                }")
    
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
       print("            _mm512_store_ps(result + (j * "+str(N_UNROLL)+" + "+str(sn)+") * M_ROW  + i + "+str(sm*SIMD_SIZE)+", res"+str(sn)+str(sm)+");") 

print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_colMajor_TCSC_Merged_GroupMin_32xG32_AVX512_OpenMP(float* X, int32_t* metadata, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 32; j++) {
        int* groupData = &metadata[j * 66]; 
        for (int i = 0; i < M_ROW; i +=32) {
            __m512 res00 = _mm512_setzero_ps();
            __m512 res01 = _mm512_setzero_ps();
            __m512 res10 = _mm512_setzero_ps();
            __m512 res11 = _mm512_setzero_ps();
            __m512 res20 = _mm512_setzero_ps();
            __m512 res21 = _mm512_setzero_ps();
            __m512 res30 = _mm512_setzero_ps();
            __m512 res31 = _mm512_setzero_ps();
            __m512 res40 = _mm512_setzero_ps();
            __m512 res41 = _mm512_setzero_ps();
            __m512 res50 = _mm512_setzero_ps();
            __m512 res51 = _mm512_setzero_ps();
            __m512 res60 = _mm512_setzero_ps();
            __m512 res61 = _mm512_setzero_ps();
  

In [24]:
# Generate Uniform TCSC colMajor - AVX-512
M_UNROLL = 64
N_UNROLL = 4
X_DATA_BYTES = 4
SIMD_SIZE = 16
UNROLL_REMAIN = True
M_SIMD = int(M_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_colMajor_TCSC_Uniform_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {") 
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
for sn in range(N_UNROLL):
    for sm in range(M_SIMD):
        print("            __m512 res"+str(sn)+str(sm)+" = _mm512_setzero_ps();")    
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        print("                __m512 pos"+str(sn)+str(sm)+" = _mm512_load_ps(X + row_index[k + "+str(sn*2+0)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")   
        print("                __m512 neg"+str(sn)+str(sm)+" = _mm512_load_ps(X + row_index[k + "+str(sn*2+1)+"] * M_ROW + i + "+str(sm*SIMD_SIZE)+");")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
for sm in range(M_SIMD):
    for sn in range(N_UNROLL):    
       print("            _mm512_store_ps(result + (j * "+str(N_UNROLL)+" + "+str(sn)+") * M_ROW  + i + "+str(sm*SIMD_SIZE)+", res"+str(sn)+str(sm)+");") 

print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_colMajor_TCSC_Uniform_64xG4_AVX512_OpenMP(float* X, int32_t NonZeroPerCol, int16_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 4; j++) {
        for (int i = 0; i < M_ROW; i +=64) {
            __m512 res00 = _mm512_setzero_ps();
            __m512 res01 = _mm512_setzero_ps();
            __m512 res02 = _mm512_setzero_ps();
            __m512 res03 = _mm512_setzero_ps();
            __m512 res10 = _mm512_setzero_ps();
            __m512 res11 = _mm512_setzero_ps();
            __m512 res12 = _mm512_setzero_ps();
            __m512 res13 = _mm512_setzero_ps();
            __m512 res20 = _mm512_setzero_ps();
            __m512 res21 = _mm512_setzero_ps();
            __m512 res22 = _mm512_setzero_ps();
            __m512 res23 = _mm512_setzero_ps();
            __m512 res30 = _mm512_setzero_ps();
            __m512 res31 = _mm512_setzero_ps();
            __m512 res32 = _mm512_setzero_ps();
     

In [37]:
# Generate GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin - AVX-512
M_UNROLL = 2
N_UNROLL = 32
X_DATA_BYTES = 4
SIMD_SIZE = 16
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_"+str(M_UNROLL)+"xG"+str(N_UNROLL)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int M_ROW, int N_COL, int K){"
print(function_name)
print("    #pragma omp parallel for")
print("    for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("        for (int i = 0; i < M_ROW; i +="+str(M_UNROLL)+") {")
print("            float* Xbase = X + i * K;")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    for sm in range(M_UNROLL):
        print("            __m512 res"+str(sn)+str(sm)+" = _mm512_setzero_ps();")
    

print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m512i indices_p"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE)+"));") 
for sn in range(N_SIMD): 
    print("                __m512i indices_n"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE + SIMD_SIZE*N_SIMD)+"));")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        print("                __m512 pos"+str(sn)+str(sm)+" = _mm512_i32gather_ps(Xbase + K * "+str(sm)+", indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        print("                __m512 neg"+str(sn)+str(sm)+" = _mm512_i32gather_ps(Xbase + K * "+str(sm)+", indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
        idx = str(sn)+str(sm)
        print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + i * N_COL + j * "+str(N_UNROLL)+";")
for sm in range(M_UNROLL):
    for sn in range(N_SIMD):    
       print("            _mm512_store_ps(Ybase + N_COL * "+str(sm)+" + "+str(SIMD_SIZE*sn)+", res"+str(sn)+str(sm)+");") 


if(UNROLL_REMAIN):
    for sn in range(N_UNROLL):
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++) {")
        for sm in range(M_UNROLL):
            print("                    Ybase[N_COL * "+str(sm)+" + "+str(sn)+"] += Xbase[K * "+str(sm)+" + row_index[k]];")
        print("                }")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++) {")
        for sm in range(M_UNROLL):
            print("                    Ybase[N_COL * "+str(sm)+" + "+str(sn)+"] -= Xbase[K * "+str(sm)+" + row_index[k]];")
        print("                }")
else:
    print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
    print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++) {")
    for sm in range(M_UNROLL):
        print("                    Ybase[N_COL * "+str(sm)+" + g] += Xbase[K * "+str(sm)+" + row_index[k]];")
    print("                }")
    print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++) {")
    for sm in range(M_UNROLL):
        print("                    Ybase[N_COL * "+str(sm)+" + g] -= Xbase[K * "+str(sm)+" + row_index[k]];")
    print("                }")
    print("            }")
print("        }")
print("    }")
print("}")

void GEMM_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_2xG32_AVX512_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int M_ROW, int N_COL, int K){
    #pragma omp parallel for
    for (int j = 0; j < N_COL / 32; j++) {
        for (int i = 0; i < M_ROW; i +=2) {
            float* Xbase = X + i * K;
            int* groupData = &metadata[j * 66]; 
            __m512 res00 = _mm512_setzero_ps();
            __m512 res01 = _mm512_setzero_ps();
            __m512 res10 = _mm512_setzero_ps();
            __m512 res11 = _mm512_setzero_ps();
            for (int k = groupData[0]; k < groupData[1]; k += 64) {
                __m512i indices_p0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 0));
                __m512i indices_p1 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 16));
                __m512i indices_n0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 32));
                __m512i indices_n1 = _mm512_l

In [42]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin - AVX-512
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 16
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G"+str(N_UNROLL)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    print("            __m512 res"+str(sn)+" = _mm512_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m512i indices_p"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE)+"));") 
for sn in range(N_SIMD): 
    print("                __m512i indices_n"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE + SIMD_SIZE * N_SIMD)+"));")
for sn in range(N_SIMD): 
    print("                __m512 pos"+str(sn)+" = _mm512_i32gather_ps(indices_p"+str(sn)+", X, "+str(X_DATA_BYTES)+");") 
for sn in range(N_SIMD):  
    print("                __m512 neg"+str(sn)+" = _mm512_i32gather_ps(indices_n"+str(sn)+", X, "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm512_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
if(UNROLL_REMAIN):
    for sn in range(N_UNROLL): 
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++)")
        print("                    Ybase["+str(sn)+"] += X[row_index[k]];")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++)")
        print("                    Ybase["+str(sn)+"] -= X[row_index[k]];")
else:
        print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
        print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++)")
        print("                    Ybase[g] += X[row_index[k]];")
        print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++)")
        print("                    Ybase[g] -= X[row_index[k]];")
        print("            }")
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G128_AVX512_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            int* groupData = &metadata[j * 258]; 
            __m512 res0 = _mm512_setzero_ps();
            __m512 res1 = _mm512_setzero_ps();
            __m512 res2 = _mm512_setzero_ps();
            __m512 res3 = _mm512_setzero_ps();
            __m512 res4 = _mm512_setzero_ps();
            __m512 res5 = _mm512_setzero_ps();
            __m512 res6 = _mm512_setzero_ps();
            __m512 res7 = _mm512_setzero_ps();
            for (int k = groupData[0]; k < groupData[1]; k += 256) {
                __m512i indices_p0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 0));
                __m512i indices_p1 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 16));
                __m512i indices_p2 = _mm512_load_si512

In [48]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Uniform - AVX-512
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 16
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G"+str(N_UNROLL)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
for sn in range(N_SIMD):
    print("            __m512 res"+str(sn)+" = _mm512_setzero_ps();")
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m512i indices_p"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE)+"));") 
for sn in range(N_SIMD): 
    print("                __m512i indices_n"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE + SIMD_SIZE * N_SIMD)+"));")
for sn in range(N_SIMD): 
    print("                __m512 pos"+str(sn)+" = _mm512_i32gather_ps(indices_p"+str(sn)+", X, "+str(X_DATA_BYTES)+");") 
for sn in range(N_SIMD):  
    print("                __m512 neg"+str(sn)+" = _mm512_i32gather_ps(indices_n"+str(sn)+", X, "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm512_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G128_AVX512_OpenMP(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            __m512 res0 = _mm512_setzero_ps();
            __m512 res1 = _mm512_setzero_ps();
            __m512 res2 = _mm512_setzero_ps();
            __m512 res3 = _mm512_setzero_ps();
            __m512 res4 = _mm512_setzero_ps();
            __m512 res5 = _mm512_setzero_ps();
            __m512 res6 = _mm512_setzero_ps();
            __m512 res7 = _mm512_setzero_ps();
            for (int k = j * 128 * NonZeroPerCol; k < (j + 1) * 128 * NonZeroPerCol; k += 256) {
                __m512i indices_p0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 0));
                __m512i indices_p1 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 16));
                __m512i indices_p2 = _mm512_load_si512(reinterpret_cast<const __

# GEMV Generator Contineous with SIMD_SIZE

In [8]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 8
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G"+str(N_UNROLL)+"_CS"+str(SIMD_SIZE)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    print("            __m256 res"+str(sn)+" = _mm256_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * SIMD_SIZE * 2)+"));") 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * SIMD_SIZE * 2 + SIMD_SIZE)+"));")
for sn in range(N_SIMD): 
    print("                __m256 pos"+str(sn)+" = _mm256_i32gather_ps(X, indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");") 
    print("                __m256 neg"+str(sn)+" = _mm256_i32gather_ps(X, indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm256_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
if(UNROLL_REMAIN):
    for sn in range(N_UNROLL): 
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++)")
        print("                    Ybase["+str(sn)+"] += X[row_index[k]];")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++)")
        print("                    Ybase["+str(sn)+"] -= X[row_index[k]];")
else:
        #print("            #pragma omp simd")
        print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
        #print("                #pragma omp simd")
        print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++)")
        print("                    Ybase[g] += X[row_index[k]];")
        #print("                #pragma omp simd")
        print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++)")
        print("                    Ybase[g] -= X[row_index[k]];")
        print("            }")
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G128_CS8_AVX2_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            int* groupData = &metadata[j * 258]; 
            __m256 res0 = _mm256_setzero_ps();
            __m256 res1 = _mm256_setzero_ps();
            __m256 res2 = _mm256_setzero_ps();
            __m256 res3 = _mm256_setzero_ps();
            __m256 res4 = _mm256_setzero_ps();
            __m256 res5 = _mm256_setzero_ps();
            __m256 res6 = _mm256_setzero_ps();
            __m256 res7 = _mm256_setzero_ps();
            __m256 res8 = _mm256_setzero_ps();
            __m256 res9 = _mm256_setzero_ps();
            __m256 res10 = _mm256_setzero_ps();
            __m256 res11 = _mm256_setzero_ps();
            __m256 res12 = _mm256_setzero_ps();
            __m256 res13 = _mm256_setzero_ps();
            __m256 res14 = _mm256_setzero_ps();
      

In [9]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Uniform
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 8
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G"+str(N_UNROLL)+"_CS"+str(SIMD_SIZE)+"_AVX2_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
for sn in range(N_SIMD):
    print("            __m256 res"+str(sn)+" = _mm256_setzero_ps();")
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m256i indices_p"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * SIMD_SIZE * 2)+"));") 
    print("                __m256i indices_n"+str(sn)+" = _mm256_load_si256(reinterpret_cast<const __m256i*>(row_index + k + "+str(sn * SIMD_SIZE * 2 + SIMD_SIZE)+"));")
for sn in range(N_SIMD): 
    print("                __m256 pos"+str(sn)+" = _mm256_i32gather_ps(X, indices_p"+str(sn)+", "+str(X_DATA_BYTES)+");") 
    print("                __m256 neg"+str(sn)+" = _mm256_i32gather_ps(X, indices_n"+str(sn)+", "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm256_add_ps(res"+idx+", _mm256_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm256_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G128_CS8_AVX2_OpenMP(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            __m256 res0 = _mm256_setzero_ps();
            __m256 res1 = _mm256_setzero_ps();
            __m256 res2 = _mm256_setzero_ps();
            __m256 res3 = _mm256_setzero_ps();
            __m256 res4 = _mm256_setzero_ps();
            __m256 res5 = _mm256_setzero_ps();
            __m256 res6 = _mm256_setzero_ps();
            __m256 res7 = _mm256_setzero_ps();
            __m256 res8 = _mm256_setzero_ps();
            __m256 res9 = _mm256_setzero_ps();
            __m256 res10 = _mm256_setzero_ps();
            __m256 res11 = _mm256_setzero_ps();
            __m256 res12 = _mm256_setzero_ps();
            __m256 res13 = _mm256_setzero_ps();
            __m256 res14 = _mm256_setzero_ps();
            __m256 res15 = _mm256_setzero_ps();
            

In [13]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin - AVX-512
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 16
UNROLL_REMAIN = False
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G"+str(N_UNROLL)+"_CS"+str(SIMD_SIZE)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
print("            int* groupData = &metadata[j * "+str(2 + 2 * N_UNROLL)+"]; ")
for sn in range(N_SIMD):
    print("            __m512 res"+str(sn)+" = _mm512_setzero_ps();")
print("            for (int k = groupData[0]; k < groupData[1]; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m512i indices_p"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE * 2)+"));") 
    print("                __m512i indices_n"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE * 2 + SIMD_SIZE)+"));")
for sn in range(N_SIMD): 
    print("                __m512 pos"+str(sn)+" = _mm512_i32gather_ps(indices_p"+str(sn)+", X, "+str(X_DATA_BYTES)+");") 
    print("                __m512 neg"+str(sn)+" = _mm512_i32gather_ps(indices_n"+str(sn)+", X, "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm512_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
if(UNROLL_REMAIN):
    for sn in range(N_UNROLL): 
        print("                for (int k = groupData["+str(2*sn+1)+"]; k < groupData["+str(2 * sn+2)+"]; k++)")
        print("                    Ybase["+str(sn)+"] += X[row_index[k]];")
        print("                for (int k = groupData["+str(2*sn+2)+"]; k < groupData["+str(2 * sn+3)+"]; k++)")
        print("                    Ybase["+str(sn)+"] -= X[row_index[k]];")
else:
        print("            for (int g = 0; g < "+str(N_UNROLL)+"; g++) {")
        #print("                #pragma omp simd")
        print("                for (int k = groupData[2 * g + 1]; k < groupData[2 * g + 2]; k++)")
        print("                    Ybase[g] += X[row_index[k]];")
        #print("                #pragma omp simd")
        print("                for (int k = groupData[2 * g + 2]; k < groupData[2 * g + 3]; k++)")
        print("                    Ybase[g] -= X[row_index[k]];")
        print("            }")
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Merged_GroupMin_G128_CS16_AVX512_OpenMP(float* X, int32_t* metadata, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            int* groupData = &metadata[j * 258]; 
            __m512 res0 = _mm512_setzero_ps();
            __m512 res1 = _mm512_setzero_ps();
            __m512 res2 = _mm512_setzero_ps();
            __m512 res3 = _mm512_setzero_ps();
            __m512 res4 = _mm512_setzero_ps();
            __m512 res5 = _mm512_setzero_ps();
            __m512 res6 = _mm512_setzero_ps();
            __m512 res7 = _mm512_setzero_ps();
            for (int k = groupData[0]; k < groupData[1]; k += 256) {
                __m512i indices_p0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 0));
                __m512i indices_n0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 16));
                __m512i indices_p1 = _mm512_load_

In [16]:
# Generate GEMV_CPU_FP32_rowMajor_TCSC_Uniform - AVX-512
N_UNROLL = 128
X_DATA_BYTES = 4
SIMD_SIZE = 16
N_SIMD = int(N_UNROLL/SIMD_SIZE)
function_name = "void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G"+str(N_UNROLL)+"_CS"+str(SIMD_SIZE)+"_AVX512_OpenMP"
function_name +="(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){"
print(function_name)
print("        #pragma omp parallel for")
print("        for (int j = 0; j < N_COL / "+str(N_UNROLL)+"; j++) {")
for sn in range(N_SIMD):
    print("            __m512 res"+str(sn)+" = _mm512_setzero_ps();")
print("            for (int k = j * "+str(N_UNROLL)+" * NonZeroPerCol; k < (j + 1) * "+str(N_UNROLL)+" * NonZeroPerCol; k += "+str(2 * N_UNROLL)+") {")
for sn in range(N_SIMD):
    print("                __m512i indices_p"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE * 2)+"));") 
    print("                __m512i indices_n"+str(sn)+" = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + "+str(sn * SIMD_SIZE * 2+ SIMD_SIZE)+"));")
for sn in range(N_SIMD): 
    print("                __m512 pos"+str(sn)+" = _mm512_i32gather_ps(indices_p"+str(sn)+", X, "+str(X_DATA_BYTES)+");")  
    print("                __m512 neg"+str(sn)+" = _mm512_i32gather_ps(indices_n"+str(sn)+", X, "+str(X_DATA_BYTES)+");")
for sn in range(N_SIMD):    
    idx = str(sn)
    print("                res"+idx+" = _mm512_add_ps(res"+idx+", _mm512_sub_ps(pos"+idx+", neg"+idx+"));")
print("            }")
print("            float* Ybase = result + j * "+str(N_UNROLL)+";")
for sn in range(N_SIMD):    
     print("            _mm512_store_ps(Ybase + "+str(SIMD_SIZE*sn)+", res"+str(sn)+");") 
print("        }")
print("}")

void GEMV_CPU_FP32_rowMajor_TCSC_Uniform_G128_CS16_AVX512_OpenMP(float* X, int32_t NonZeroPerCol, int32_t* row_index, float* result, int N_COL, int K){
        #pragma omp parallel for
        for (int j = 0; j < N_COL / 128; j++) {
            __m512 res0 = _mm512_setzero_ps();
            __m512 res1 = _mm512_setzero_ps();
            __m512 res2 = _mm512_setzero_ps();
            __m512 res3 = _mm512_setzero_ps();
            __m512 res4 = _mm512_setzero_ps();
            __m512 res5 = _mm512_setzero_ps();
            __m512 res6 = _mm512_setzero_ps();
            __m512 res7 = _mm512_setzero_ps();
            for (int k = j * 128 * NonZeroPerCol; k < (j + 1) * 128 * NonZeroPerCol; k += 256) {
                __m512i indices_p0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 0));
                __m512i indices_n0 = _mm512_load_si512(reinterpret_cast<const __m512i*>(row_index + k + 16));
                __m512i indices_p1 = _mm512_load_si512(reinterpret_cast<con