Skip to content

Commit

Permalink
MAINT: use more conservative integer types for umath linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbrc committed May 21, 2015
1 parent a79d9d3 commit ad4aa25
Showing 1 changed file with 54 additions and 36 deletions.
90 changes: 54 additions & 36 deletions numpy/linalg/umath_linalg.c.src
Expand Up @@ -1128,6 +1128,7 @@ static void
npy_uint8 *tmp_buff = NULL;
size_t matrix_size;
size_t pivot_size;
size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
Expand All @@ -1138,8 +1139,9 @@ static void
*/
INIT_OUTER_LOOP_3
m = (fortran_int) dimensions[0];
matrix_size = m*m*sizeof(@typ@);
pivot_size = m*sizeof(fortran_int);
safe_m = m;
matrix_size = safe_m * safe_m * sizeof(@typ@);
pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);

if (tmp_buff)
Expand Down Expand Up @@ -1172,6 +1174,7 @@ static void
npy_uint8 *tmp_buff;
size_t matrix_size;
size_t pivot_size;
size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
Expand All @@ -1182,8 +1185,9 @@ static void
*/
INIT_OUTER_LOOP_2
m = (fortran_int) dimensions[0];
matrix_size = m*m*sizeof(@typ@);
pivot_size = m*sizeof(fortran_int);
safe_m = m;
matrix_size = safe_m * safe_m * sizeof(@typ@);
pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);

if (tmp_buff)
Expand Down Expand Up @@ -1252,14 +1256,15 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
fortran_int liwork = -1;
fortran_int info;
npy_uint8 *a, *w, *work, *iwork;
size_t alloc_size = N*(N+1)*sizeof(@typ@);
size_t safe_N = N;
size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@);

mem_buff = malloc(alloc_size);

if (!mem_buff)
goto error;
a = mem_buff;
w = mem_buff + N*N*sizeof(@typ@);
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@ftyp@*)w,
&query_work_size, &lwork,
Expand Down Expand Up @@ -1344,12 +1349,14 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
fortran_int liwork = -1;
npy_uint8 *a, *w, *work, *rwork, *iwork;
fortran_int info;
size_t safe_N = N;

mem_buff = malloc(N*N*sizeof(@typ@)+N*sizeof(@basetyp@));
mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) +
safe_N * sizeof(@basetyp@));
if (!mem_buff)
goto error;
a = mem_buff;
w = mem_buff+N*N*sizeof(@typ@);
w = mem_buff + safe_N * safe_N * sizeof(@typ@);

LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@fbasetyp@*)w,
Expand Down Expand Up @@ -1581,14 +1588,16 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a, *b, *ipiv;
mem_buff = malloc(N*N*sizeof(@ftyp@) +
N*NRHS*sizeof(@ftyp@) +
N*sizeof(fortran_int));
size_t safe_N = N;
size_t safe_NRHS = NRHS;
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) +
safe_N * safe_NRHS*sizeof(@ftyp@) +
safe_N * sizeof(fortran_int));
if (!mem_buff)
goto error;
a = mem_buff;
b = a + N*N*sizeof(@ftyp@);
ipiv = b + N*NRHS*sizeof(@ftyp@);
b = a + safe_N * safe_N * sizeof(@ftyp@);
ipiv = b + safe_N * safe_NRHS * sizeof(@ftyp@);

params->A = a;
params->B = b;
Expand Down Expand Up @@ -1759,8 +1768,9 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a;
size_t safe_N = N;

mem_buff = malloc(N*N*sizeof(@ftyp@));
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@));
if (!mem_buff)
goto error;

Expand Down Expand Up @@ -1924,11 +1934,12 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
npy_uint8 *mem_buff=NULL;
npy_uint8 *mem_buff2=NULL;
npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr;
size_t a_size = n*n*sizeof(@typ@);
size_t wr_size = n*sizeof(@typ@);
size_t wi_size = n*sizeof(@typ@);
size_t vlr_size = jobvl=='V' ? n*n*sizeof(@typ@) : 0;
size_t vrr_size = jobvr=='V' ? n*n*sizeof(@typ@) : 0;
size_t safe_n = n;
size_t a_size = safe_n * safe_n * sizeof(@typ@);
size_t wr_size = safe_n * sizeof(@typ@);
size_t wi_size = safe_n * sizeof(@typ@);
size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
size_t w_size = wr_size*2;
size_t vl_size = vlr_size*2;
size_t vr_size = vrr_size*2;
Expand Down Expand Up @@ -2120,11 +2131,12 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *w, *vl, *vr, *work, *rwork;
size_t a_size = n*n*sizeof(@ftyp@);
size_t w_size = n*sizeof(@ftyp@);
size_t vl_size = jobvl=='V'? n*n*sizeof(@ftyp@) : 0;
size_t vr_size = jobvr=='V'? n*n*sizeof(@ftyp@) : 0;
size_t rwork_size = 2*n*sizeof(@realtyp@);
size_t safe_n = n;
size_t a_size = safe_n * safe_n * sizeof(@ftyp@);
size_t w_size = safe_n * sizeof(@ftyp@);
size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
size_t rwork_size = 2 * safe_n * sizeof(@realtyp@);
size_t work_count = 0;
@typ@ work_size_query;
fortran_int do_size_query = -1;
Expand Down Expand Up @@ -2446,20 +2458,23 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *s, *u, *vt, *work, *iwork;
size_t a_size = (size_t)m*(size_t)n*sizeof(@ftyp@);
size_t safe_m = m;
size_t safe_n = n;
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
fortran_int min_m_n = m<n?m:n;
size_t s_size = ((size_t)min_m_n)*sizeof(@ftyp@);
size_t safe_min_m_n = min_m_n;
size_t s_size = safe_min_m_n * sizeof(@ftyp@);
fortran_int u_row_count, vt_column_count;
size_t u_size, vt_size;
fortran_int work_count;
size_t work_size;
size_t iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);

if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;

u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
u_size = ((size_t)u_row_count) * safe_m * sizeof(@ftyp@);
vt_size = safe_n * ((size_t)vt_column_count) * sizeof(@ftyp@);

mem_buff = malloc(a_size + s_size + u_size + vt_size + iwork_size);

Expand Down Expand Up @@ -2558,20 +2573,23 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork;
size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size;
fortran_int u_row_count, vt_column_count, work_count;
size_t safe_m = m;
size_t safe_n = n;
fortran_int min_m_n = m<n?m:n;
size_t safe_min_m_n = min_m_n;

if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;

a_size = ((size_t)m)*((size_t)n)*sizeof(@ftyp@);
s_size = ((size_t)min_m_n)*sizeof(@frealtyp@);
u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
a_size = safe_m * safe_n * sizeof(@ftyp@);
s_size = safe_min_m_n * sizeof(@frealtyp@);
u_size = ((size_t)u_row_count) * safe_m * sizeof(@ftyp@);
vt_size = safe_n * ((size_t)vt_column_count) * sizeof(@ftyp@);
rwork_size = 'N'==jobz?
7*((size_t)min_m_n) :
(5*(size_t)min_m_n*(size_t)min_m_n + 5*(size_t)min_m_n);
(7 * safe_min_m_n) :
(5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n);
rwork_size *= sizeof(@ftyp@);
iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
iwork_size = 8 * safe_min_m_n* sizeof(fortran_int);

mem_buff = malloc(a_size +
s_size +
Expand Down

0 comments on commit ad4aa25

Please sign in to comment.