Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add skew-symmetric BLAS operations #805

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
19 changes: 11 additions & 8 deletions config/template/kernels/1f/bli_dotxaxpyf_template_noopt_var1.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ void bli_zdotxaxpyf_template_noopt
conj_t conjx,
dim_t m,
dim_t b_n,
dcomplex* restrict alpha,
dcomplex* restrict alphaw,
dcomplex* restrict alphax,
dcomplex* restrict a, inc_t inca, inc_t lda,
dcomplex* restrict w, inc_t incw,
dcomplex* restrict x, inc_t incx,
Expand All @@ -67,8 +68,8 @@ void bli_zdotxaxpyf_template_noopt

This kernel performs the following two gemv-like operations:

y := beta * y + alpha * conjat( A^T ) * conjw( w )
z := z + alpha * conja( A ) * conjx( x )
y := beta * y + alphaw * conjat( A^T ) * conjw( w )
z := z + alphax * conja( A ) * conjx( x )

where A is an m x b_n matrix, x and y are vector of length b_n, w and z
are vectors of length m, and alpha and beta are scalars. The operation
Expand All @@ -84,7 +85,8 @@ void bli_zdotxaxpyf_template_noopt
- m: The number of rows in matrix A.
- b_n: The number of columns in matrix A. Must be equal to or less than
the fusing factor.
- alpha: The address of the scalar to be applied to A^T*w and A*x.
- alphaw: The address of the scalar to be applied to A^T*w.
- alphax: The address of the scalar to be applied to A*x.
- a: The address of matrix A.
- inca: The row stride of A. inca should be unit unless the
implementation makes special accomodation for non-unit values.
Expand Down Expand Up @@ -205,7 +207,8 @@ void bli_zdotxaxpyf_template_noopt
conjx,
m,
b_n,
alpha,
alphaw,
alphax,
a, inca, lda,
w, incw,
x, incx,
Expand Down Expand Up @@ -239,15 +242,15 @@ void bli_zdotxaxpyf_template_noopt
for ( j = 0; j < b_n; ++j )
{
bli_zcopys( *xp[ j ], alpha_x[ j ] );
bli_zscals( *alpha, alpha_x[ j ] );
bli_zscals( *alphax, alpha_x[ j ] );
}
}
else // if ( bli_is_conj( conjx ) )
{
for ( j = 0; j < b_n; ++j )
{
bli_zcopyjs( *xp[ j ], alpha_x[ j ] );
bli_zscals( *alpha, alpha_x[ j ] );
bli_zscals( *alphax, alpha_x[ j ] );
}
}

Expand Down Expand Up @@ -468,7 +471,7 @@ void bli_zdotxaxpyf_template_noopt
for ( j = 0; j < b_n; ++j )
{
bli_zscals( *beta, *yp[ j ] );
bli_zaxpys( *alpha, At_w[ j ], *yp[ j ] );
bli_zaxpys( *alphaw, At_w[ j ], *yp[ j ] );
}
}

3 changes: 2 additions & 1 deletion configure
Original file line number Diff line number Diff line change
Expand Up @@ -4608,7 +4608,7 @@ print_usage_plugin()

--disable-examples, --enable-examples

Do not include (created by default) example code for plugin
Do not include (included by default) example code for plugin
registration, kernels, etc.

--disable-templates, --enable-templates
Expand Down Expand Up @@ -5384,3 +5384,4 @@ case ${0##*/} in
plugin_main "$@"
;;
esac

1 change: 1 addition & 0 deletions frame/0/bli_l0_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ GENFRONT( sqrtsc )
GENFRONT( sqrtrsc )
GENFRONT( subsc )
GENFRONT( invertsc )
GENFRONT( negsc )


#undef GENFRONT
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ GENTPROT( sqrtsc )
GENTPROT( sqrtrsc )
GENTPROT( subsc )
GENTPROT( invertsc )
GENTPROT( negsc )


#undef GENTPROT
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_fpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ GENFRONT( divsc )
GENFRONT( mulsc )
GENFRONT( subsc )
GENFRONT( invertsc )
GENFRONT( negsc )
GENFRONT( sqrtsc )
GENFRONT( sqrtrsc )
GENFRONT( unzipsc )
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_fpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ GENPROT( divsc )
GENPROT( mulsc )
GENPROT( subsc )
GENPROT( invertsc )
GENPROT( negsc )
GENPROT( sqrtsc )
GENPROT( sqrtrsc )
GENPROT( unzipsc )
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_ft.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ INSERT_GENTDEF( addsc )
INSERT_GENTDEF( divsc )
INSERT_GENTDEF( subsc )
INSERT_GENTDEF( invertsc )
INSERT_GENTDEF( negsc )

// mulsc

Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_oapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ GENFRONT( divsc )
GENFRONT( mulsc )
GENFRONT( subsc )
GENFRONT( invertsc )
GENFRONT( negsc )


#undef GENFRONT
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_oapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ GENPROT( sqrtsc )
GENPROT( sqrtrsc )
GENPROT( subsc )
GENPROT( invertsc )
GENPROT( negsc )


#undef GENPROT
Expand Down
23 changes: 23 additions & 0 deletions frame/0/bli_l0_tapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,29 @@ void PASTEMAC(ch,opname) \
INSERT_GENTFUNC_BASIC( invertsc, inverts )


#undef GENTFUNCR
#define GENTFUNCR( ctype, ctype_r, ch, chr, opname, kername ) \
\
void PASTEMAC(ch,opname) \
( \
conj_t conjchi, \
const ctype* chi, \
ctype* psi \
) \
{ \
bli_init_once(); \
\
ctype chi_conj; \
ctype_r chi_conj_r, chi_conj_i; \
\
PASTEMAC(ch,copycjs)( conjchi, *chi, chi_conj ); \
PASTEMAC(ch,gets)( chi_conj, chi_conj_r, chi_conj_i ); \
PASTEMAC(ch,sets)( -chi_conj_r, -chi_conj_i, *psi ); \
}

INSERT_GENTFUNCR_BASIC( negsc, inverts )


#undef GENTFUNC
#define GENTFUNC( ctype, ch, opname, kername ) \
\
Expand Down
1 change: 1 addition & 0 deletions frame/0/bli_l0_tapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ INSERT_GENTPROT_BASIC( divsc )
INSERT_GENTPROT_BASIC( mulsc )
INSERT_GENTPROT_BASIC( subsc )
INSERT_GENTPROT_BASIC( invertsc )
INSERT_GENTPROT_BASIC( negsc )


#undef GENTPROTR
Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ void PASTEMAC(opname,_check) \
GENFRONT( invscald )
GENFRONT( scald )
GENFRONT( setd )
GENFRONT( setrd )
GENFRONT( setid )
GENFRONT( shiftd )

Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void PASTEMAC(opname,_check) \
GENTPROT( invscald )
GENTPROT( scald )
GENTPROT( setd )
GENTPROT( setrd )
GENTPROT( setid )
GENTPROT( shiftd )

Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_fpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ GENFRONT( invertd )
GENFRONT( invscald )
GENFRONT( scald )
GENFRONT( setd )
GENFRONT( setrd )
GENFRONT( setid )
GENFRONT( shiftd )
GENFRONT( xpbyd )
Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_fpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ GENPROT( invertd )
GENPROT( invscald )
GENPROT( scald )
GENPROT( setd )
GENPROT( setrd )
GENPROT( setid )
GENPROT( shiftd )
GENPROT( xpbyd )
Expand Down
3 changes: 2 additions & 1 deletion frame/1d/bli_l1d_ft.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ INSERT_GENTDEF( invscald )
INSERT_GENTDEF( scald )
INSERT_GENTDEF( setd )

// setid
// setrd, setid

#undef GENTDEFR
#define GENTDEFR( ctype, ctype_r, ch, chr, opname, tsuf ) \
Expand All @@ -130,6 +130,7 @@ typedef void (*PASTECH(ch,opname,EX_SUF,tsuf)) \
BLIS_TAPI_EX_PARAMS \
);

INSERT_GENTDEFR( setrd )
INSERT_GENTDEFR( setid )

// shiftd
Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_oapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ void PASTEMAC(opname,EX_SUF) \
); \
}

GENFRONT( setrd )
GENFRONT( setid )


Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_oapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \
GENTPROT( invscald )
GENTPROT( scald )
GENTPROT( setd )
GENTPROT( setrd )
GENTPROT( setid )
GENTPROT( shiftd )

Expand Down
73 changes: 73 additions & 0 deletions frame/1d/bli_l1d_tapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,79 @@ void PASTEMAC(ch,opname,EX_SUF) \
dim_t n_elem; \
dim_t offx; \
inc_t incx; \
\
if ( bli_zero_dim2( m, n ) ) return; \
\
if ( bli_is_outside_diag( diagoffx, BLIS_NO_TRANSPOSE, m, n ) ) return; \
\
/* Determine the distance to the diagonals, the number of diagonal
elements, and the diagonal increments. */ \
bli_set_dims_incs_1d \
( \
diagoffx, \
m, n, rs_x, cs_x, \
&offx, &n_elem, &incx \
); \
\
/* Alternate implementation. (Substitute for remainder of function). */ \
/* for ( i = 0; i < n_elem; ++i ) \
{ \
ctype* chi11 = x1 + (i )*incx; \
\
PASTEMAC(ch,setrs)( *alpha, *chi11 ); \
} */ \
\
/* Acquire the address of the real component of the first element,
and scale the increment for use in the real domain if the data type is complex. */ \
x1 = ( ctype_r* )( x + offx ); \
if ( bli_is_complex( dt ) ) \
incx = 2*incx; \
\
/* Obtain a valid context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Query the context for the operation's kernel address. */ \
PASTECH(kername,_ker_ft) f = bli_cntx_get_ukr_dt( dt_r, kerid, cntx ); \
\
/* Invoke the kernel with the appropriate parameters. */ \
f \
( \
BLIS_NO_CONJUGATE, \
n_elem, \
( ctype_r* )alpha, \
x1, incx, \
( cntx_t* )cntx \
); \
}

INSERT_GENTFUNCR_BASIC( setrd, setv, BLIS_SETV_KER )


#undef GENTFUNCR
#define GENTFUNCR( ctype, ctype_r, ch, chr, opname, kername, kerid ) \
\
void PASTEMAC(ch,opname,EX_SUF) \
( \
doff_t diagoffx, \
dim_t m, \
dim_t n, \
const ctype_r* alpha, \
ctype* x, inc_t rs_x, inc_t cs_x \
BLIS_TAPI_EX_PARAMS \
) \
{ \
\
bli_init_once(); \
\
BLIS_TAPI_EX_DECLS \
\
const num_t dt = PASTEMAC(ch,type); \
const num_t dt_r = PASTEMAC(chr,type); \
\
ctype_r* x1; \
dim_t n_elem; \
dim_t offx; \
inc_t incx; \
\
/* If the datatype is real, the entire operation is a no-op. */ \
if ( bli_is_real( dt ) ) return; \
Expand Down
1 change: 1 addition & 0 deletions frame/1d/bli_l1d_tapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ BLIS_EXPORT_BLIS void PASTEMAC(ch,opname,EX_SUF) \
BLIS_TAPI_EX_PARAMS \
);

INSERT_GENTPROTR_BASIC( setrd )
INSERT_GENTPROTR_BASIC( setid )


Expand Down
18 changes: 14 additions & 4 deletions frame/1f/bli_l1f_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ void bli_dotaxpyv_check

void bli_dotxaxpyf_check
(
const obj_t* alpha,
const obj_t* alphaw,
const obj_t* alphax,
const obj_t* at,
const obj_t* a,
const obj_t* w,
Expand All @@ -302,7 +303,10 @@ void bli_dotxaxpyf_check

// Check object datatypes.

e_val = bli_check_noninteger_object( alpha );
e_val = bli_check_noninteger_object( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_noninteger_object( alphax );
bli_check_error_code( e_val );

e_val = bli_check_floating_object( at );
Expand Down Expand Up @@ -345,7 +349,10 @@ void bli_dotxaxpyf_check

// Check object dimensions.

e_val = bli_check_scalar_object( alpha );
e_val = bli_check_scalar_object( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_scalar_object( alphax );
bli_check_error_code( e_val );

e_val = bli_check_matrix_object( at );
Expand Down Expand Up @@ -397,7 +404,10 @@ void bli_dotxaxpyf_check

// Check object buffers (for non-NULLness).

e_val = bli_check_object_buffer( alpha );
e_val = bli_check_object_buffer( alphaw );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( alphax );
bli_check_error_code( e_val );

e_val = bli_check_object_buffer( at );
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ GENTPROT( dotaxpyv )
\
void PASTEMAC(opname,_check) \
( \
const obj_t* alpha, \
const obj_t* alphaw, \
const obj_t* alphax, \
const obj_t* at, \
const obj_t* a, \
const obj_t* w, \
Expand Down
3 changes: 2 additions & 1 deletion frame/1f/bli_l1f_ft.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ typedef void (*PASTECH(ch,opname,EX_SUF,tsuf)) \
conj_t conjx, \
dim_t m, \
dim_t b_n, \
const ctype* alpha, \
const ctype* alphaw, \
const ctype* alphax, \
const ctype* a, inc_t inca, inc_t lda, \
const ctype* w, inc_t incw, \
const ctype* x, inc_t incx, \
Expand Down
Loading