Skip to content

Commit

Permalink
Change dotxaxpyf microkernel signature.
Browse files Browse the repository at this point in the history
The function signature for dotxaxpyf has been changed to allow different `alpha` values for the dot and axpy sub-problems. This is needed to support skew-symmetric operations which differ in more than just conjugation of A and A^T.
  • Loading branch information
devinamatthews committed Apr 26, 2024
1 parent 3a52b71 commit b986782
Show file tree
Hide file tree
Showing 15 changed files with 258 additions and 206 deletions.
19 changes: 11 additions & 8 deletions config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ void bli_zdotxaxpyf_template_noopt
conj_t conjx,
dim_t m,
dim_t b_n,
dcomplex* restrict alpha,
dcomplex* restrict alphaw,
dcomplex* restrict alphax,
dcomplex* restrict a, inc_t inca, inc_t lda,
dcomplex* restrict w, inc_t incw,
dcomplex* restrict x, inc_t incx,
Expand All @@ -67,8 +68,8 @@ void bli_zdotxaxpyf_template_noopt
This kernel performs the following two gemv-like operations:
y := beta * y + alpha * conjat( A^T ) * conjw( w )
z := z + alpha * conja( A ) * conjx( x )
y := beta * y + alphaw * conjat( A^T ) * conjw( w )
z := z + alphax * conja( A ) * conjx( x )
where A is an m x b_n matrix, x and y are vectors of length b_n, w and z
are vectors of length m, and alphaw, alphax, and beta are scalars. The operation
Expand All @@ -84,7 +85,8 @@ void bli_zdotxaxpyf_template_noopt
- m: The number of rows in matrix A.
- b_n: The number of columns in matrix A. Must be equal to or less than
the fusing factor.
- alpha: The address of the scalar to be applied to A^T*w and A*x.
- alphaw: The address of the scalar to be applied to A^T*w.
- alphax: The address of the scalar to be applied to A*x.
- a: The address of matrix A.
- inca: The row stride of A. inca should be unit unless the
implementation makes special accommodation for non-unit values.
Expand Down Expand Up @@ -205,7 +207,8 @@ void bli_zdotxaxpyf_template_noopt
conjx,
m,
b_n,
alpha,
alphaw,
alphax,
a, inca, lda,
w, incw,
x, incx,
Expand Down Expand Up @@ -239,15 +242,15 @@ void bli_zdotxaxpyf_template_noopt
for ( j = 0; j < b_n; ++j )
{
bli_zcopys( *xp[ j ], alpha_x[ j ] );
bli_zscals( *alpha, alpha_x[ j ] );
bli_zscals( *alphax, alpha_x[ j ] );
}
}
else // if ( bli_is_conj( conjx ) )
{
for ( j = 0; j < b_n; ++j )
{
bli_zcopyjs( *xp[ j ], alpha_x[ j ] );
bli_zscals( *alpha, alpha_x[ j ] );
bli_zscals( *alphax, alpha_x[ j ] );
}
}

Expand Down Expand Up @@ -468,7 +471,7 @@ void bli_zdotxaxpyf_template_noopt
for ( j = 0; j < b_n; ++j )
{
bli_zscals( *beta, *yp[ j ] );
bli_zaxpys( *alpha, At_w[ j ], *yp[ j ] );
bli_zaxpys( *alphaw, At_w[ j ], *yp[ j ] );
}
}

18 changes: 14 additions & 4 deletions frame/1f/bli_l1f_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ void bli_dotaxpyv_check

void bli_dotxaxpyf_check
(
const obj_t* alpha,
const obj_t* alphaw,
const obj_t* alphax,
const obj_t* at,
const obj_t* a,
const obj_t* w,
Expand All @@ -302,7 +303,10 @@ void bli_dotxaxpyf_check

// Check object datatypes.

e_val = bli_check_noninteger_object( alpha );
e_val = bli_check_noninteger_object( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_noninteger_object( alphax );
bli_check_error_code( e_val );

e_val = bli_check_floating_object( at );
Expand Down Expand Up @@ -345,7 +349,10 @@ void bli_dotxaxpyf_check

// Check object dimensions.

e_val = bli_check_scalar_object( alpha );
e_val = bli_check_scalar_object( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_scalar_object( alphax );
bli_check_error_code( e_val );

e_val = bli_check_matrix_object( at );
Expand Down Expand Up @@ -397,7 +404,10 @@ void bli_dotxaxpyf_check

// Check object buffers (for non-NULLness).

e_val = bli_check_object_buffer( alpha );
e_val = bli_check_object_buffer( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( alphax );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( at );
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ GENTPROT( dotaxpyv )
\
void PASTEMAC(opname,_check) \
( \
const obj_t* alpha, \
const obj_t* alphaw, \
const obj_t* alphax, \
const obj_t* at, \
const obj_t* a, \
const obj_t* w, \
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_ft.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ typedef void (*PASTECH(ch,opname,EX_SUF,tsuf)) \
conj_t conjx, \
dim_t m, \
dim_t b_n, \
const ctype* alpha, \
const ctype* alphaw, \
const ctype* alphax, \
const ctype* a, inc_t inca, inc_t lda, \
const ctype* w, inc_t incw, \
const ctype* x, inc_t incx, \
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_ker_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
conj_t conjx, \
dim_t m, \
dim_t b_n, \
const void* alpha, \
const void* alphaw, \
const void* alphax, \
const void* a, inc_t inca, inc_t lda, \
const void* w, inc_t incw, \
const void* x, inc_t incx, \
Expand Down
23 changes: 15 additions & 8 deletions frame/1f/bli_l1f_oapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ GENFRONT( dotaxpyv )
\
void PASTEMAC(opname,EX_SUF) \
( \
const obj_t* alpha, \
const obj_t* alphaw, \
const obj_t* alphax, \
const obj_t* at, \
const obj_t* a, \
const obj_t* w, \
Expand Down Expand Up @@ -288,23 +289,28 @@ void PASTEMAC(opname,EX_SUF) \
void* buf_z = bli_obj_buffer_at_off( z ); \
inc_t inc_z = bli_obj_vector_inc( z ); \
\
void* buf_alpha; \
void* buf_alphaw; \
void* buf_alphax; \
void* buf_beta; \
\
obj_t alpha_local; \
obj_t alphaw_local; \
obj_t alphax_local; \
obj_t beta_local; \
\
if ( bli_error_checking_is_enabled() ) \
PASTEMAC(opname,_check)( alpha, at, a, w, x, beta, y, z ); \
PASTEMAC(opname,_check)( alphaw, alphax, at, a, w, x, beta, y, z ); \
\
/* Create local copy-casts of scalars (and apply internal conjugation
as needed). */ \
bli_obj_scalar_init_detached_copy_of( dt, BLIS_NO_CONJUGATE, \
alpha, &alpha_local ); \
alphaw, &alphaw_local ); \
bli_obj_scalar_init_detached_copy_of( dt, BLIS_NO_CONJUGATE, \
alphax, &alphax_local ); \
bli_obj_scalar_init_detached_copy_of( dt, BLIS_NO_CONJUGATE, \
beta, &beta_local ); \
buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \
buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \
buf_alphaw = bli_obj_buffer_for_1x1( dt, &alphaw_local ); \
buf_alphax = bli_obj_buffer_for_1x1( dt, &alphax_local ); \
buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \
\
/* Support cases where matrix A requires a transposition. */ \
if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \
Expand All @@ -322,7 +328,8 @@ void PASTEMAC(opname,EX_SUF) \
conjx, \
m, \
b_n, \
buf_alpha, \
buf_alphaw, \
buf_alphax, \
buf_a, rs_a, cs_a, \
buf_w, inc_w, \
buf_x, inc_x, \
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_oapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ GENTPROT( dotaxpyv )
\
BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \
( \
const obj_t* alpha, \
const obj_t* alphaw, \
const obj_t* alphax, \
const obj_t* at, \
const obj_t* a, \
const obj_t* w, \
Expand Down
6 changes: 4 additions & 2 deletions frame/1f/bli_l1f_tapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ void PASTEMAC(ch,opname,EX_SUF) \
conj_t conjx, \
dim_t m, \
dim_t b_n, \
const ctype* alpha, \
const ctype* alphaw, \
const ctype* alphax, \
const ctype* a, inc_t inca, inc_t lda, \
const ctype* w, inc_t incw, \
const ctype* x, inc_t incx, \
Expand Down Expand Up @@ -214,7 +215,8 @@ void PASTEMAC(ch,opname,EX_SUF) \
conjx, \
m, \
b_n, \
( ctype* )alpha, \
( ctype* )alphaw, \
( ctype* )alphax, \
( ctype* )a, inca, lda, \
( ctype* )w, incw, \
( ctype* )x, incx, \
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_tapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ BLIS_EXPORT_BLIS void PASTEMAC(ch,opname,EX_SUF) \
conj_t conjx, \
dim_t m, \
dim_t b_n, \
const ctype* alpha, \
const ctype* alphaw, \
const ctype* alphax, \
const ctype* a, inc_t inca, inc_t lda, \
const ctype* w, inc_t incw, \
const ctype* x, inc_t incx, \
Expand Down
1 change: 1 addition & 0 deletions frame/2/hemv/bli_hemv_unf_var1.c
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ void PASTEMAC(ch,varname) \
n_behind, \
f, \
alpha, \
alpha, \
A10, cs_at, rs_at, \
x0, incx, \
x1, incx, \
Expand Down
1 change: 1 addition & 0 deletions frame/2/hemv/bli_hemv_unf_var3.c
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ void PASTEMAC(ch,varname) \
n_ahead, \
f, \
alpha, \
alpha, \
A21, rs_at, cs_at, \
x2, incx, \
x1, incx, \
Expand Down
31 changes: 17 additions & 14 deletions kernels/penryn/1f/bli_dotxaxpyf_penryn_int.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void bli_ddotxaxpyf_penryn_int
conj_t conjx,
dim_t m,
dim_t b_n,
const void* alpha,
const void* alphaw,
const void* alphax,
const void* a, inc_t inca, inc_t lda,
const void* w, inc_t incw,
const void* x, inc_t incx,
Expand All @@ -61,13 +62,14 @@ void bli_ddotxaxpyf_penryn_int
const cntx_t* cntx
)
{
const double* restrict alpha_cast = alpha;
const double* restrict a_cast = a;
const double* restrict w_cast = w;
const double* restrict x_cast = x;
const double* restrict beta_cast = beta;
double* restrict y_cast = y;
double* restrict z_cast = z;
const double* restrict alphaw_cast = alphaw;
const double* restrict alphax_cast = alphax;
const double* restrict a_cast = a;
const double* restrict w_cast = w;
const double* restrict x_cast = x;
const double* restrict beta_cast = beta;
double* restrict y_cast = y;
double* restrict z_cast = z;

const dim_t n_elem_per_reg = 2;
const dim_t n_iter_unroll = 2;
Expand Down Expand Up @@ -152,7 +154,8 @@ void bli_ddotxaxpyf_penryn_int
conjx,
m,
b_n,
alpha_cast,
alphaw_cast,
alphax_cast,
a_cast, inca, lda,
w_cast, incw,
x_cast, incx,
Expand Down Expand Up @@ -182,10 +185,10 @@ void bli_ddotxaxpyf_penryn_int
chi2 = *(x_cast + 2*incx);
chi3 = *(x_cast + 3*incx);

PASTEMAC(d,d,scals)( *alpha_cast, chi0 );
PASTEMAC(d,d,scals)( *alpha_cast, chi1 );
PASTEMAC(d,d,scals)( *alpha_cast, chi2 );
PASTEMAC(d,d,scals)( *alpha_cast, chi3 );
PASTEMAC(d,d,scals)( *alphax_cast, chi0 );
PASTEMAC(d,d,scals)( *alphax_cast, chi1 );
PASTEMAC(d,d,scals)( *alphax_cast, chi2 );
PASTEMAC(d,d,scals)( *alphax_cast, chi3 );

PASTEMAC(d,set0s)( rho0 );
PASTEMAC(d,set0s)( rho1 );
Expand Down Expand Up @@ -341,7 +344,7 @@ void bli_ddotxaxpyf_penryn_int
rho1v.d[1] = rho3;

betav.v = _mm_loaddup_pd( ( double* ) beta_cast );
alphav.v = _mm_loaddup_pd( ( double* ) alpha_cast );
alphav.v = _mm_loaddup_pd( ( double* ) alphaw_cast );

psi0v.v = _mm_load_pd( ( double* )(y_cast + 0*n_elem_per_reg ) );
psi1v.v = _mm_load_pd( ( double* )(y_cast + 1*n_elem_per_reg ) );
Expand Down

0 comments on commit b986782

Please sign in to comment.