Skip to content

Commit

Permalink
Refactoring and clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
alazzaro committed Sep 28, 2018
1 parent a760ed7 commit 334ef70
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 203 deletions.
10 changes: 5 additions & 5 deletions src/mm/dbcsr_mm.F
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,9 @@ SUBROUTINE dbcsr_multiply_generic(transa, transb, &
! -----
!
use_rect_algo = .FALSE.
IF (numnodes .GT. 1 .AND. nprows .EQ. numnodes .AND. &
npcols .EQ. 1 .AND. transa .EQ. dbcsr_transpose .AND. &
(.NOT. product_reindex) .AND. &
IF (nprows .EQ. numnodes .AND. npcols .EQ. 1 .AND. &
transa .EQ. dbcsr_transpose .AND. &
(.NOT. product_reindex) .AND. &
(.NOT. keep_sparsity) .AND. &
(.NOT. dbcsr_cfg%use_mpi_rma) .AND. &
dbcsr_nfullrows_total(matrix_a) .GT. dbcsr_nfullrows_total(matrix_c) .AND. &
Expand Down Expand Up @@ -844,8 +844,8 @@ SUBROUTINE dbcsr_multiply_generic(transa, transb, &
! Make an empty product matrix for the rectangular algorithm.
! The row-block distribution is replicated over all processors.
IF (use_rect_algo) THEN
ALLOCATE (dist_rows(matrix_a%nblkcols_total), &
dist_cols(matrix_b%nblkcols_total))
ALLOCATE (dist_rows(matrix_left%nblkrows_total), &
dist_cols(matrix_right%nblkcols_total))
dist_rows(:) = mynode
dist_cols(:) = 0
CALL dbcsr_distribution_new(local_distribution, mp_obj, &
Expand Down
5 changes: 0 additions & 5 deletions src/mm/dbcsr_mm_cannon.F
Original file line number Diff line number Diff line change
Expand Up @@ -1715,11 +1715,6 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, &
CALL dbcsr_mm_multrec_phaseout(multrec(ithread)%p)
!$OMP BARRIER
CALL timeset(routineN//"_multrec", handle2)

! IF (local_mult) &
! print *,mynode,"AAAA",left_buffer_calc%mats(1, v_ki_left)%data_area%d%r_dp(1:4), &
! "BBBB", right_buffer_calc%mats(v_ki_right, 1)%data_area%d%r_dp(1:4)

CALL dbcsr_mm_multrec_multiply(multrec(ithread)%p, &
left=left_buffer_calc%mats(1, v_ki_left), &
right=right_buffer_calc%mats(v_ki_right, 1), &
Expand Down
4 changes: 2 additions & 2 deletions src/mm/dbcsr_mm_dist_operations.F
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ SUBROUTINE dbcsr_make_dists_dense(product_dist, left_rdist, right_rdist, &
IF (.NOT. dbcsr_distribution_has_threads(product_dist)) &
DBCSR_ABORT("Product distribution must have threads.")
tdist => array_data(dbcsr_distribution_thread_dist(product_dist))
old_m_dist = left_rdist%i%main%d%row_dist_block
old_n_dist = right_rdist%i%main%d%col_dist_block
old_m_dist = product_dist%d%row_dist_block
old_n_dist = product_dist%d%col_dist_block
old_k_vdist = right_rdist%i%vrow_dist
m_nbins = dbcsr_mp_nprows(product_dist%d%mp_env)
n_nbins = dbcsr_mp_npcols(product_dist%d%mp_env)
Expand Down
132 changes: 105 additions & 27 deletions src/mm/dbcsr_mm_reshape.F
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,54 @@ MODULE dbcsr_mm_reshape

#:include "../data/dbcsr.fypp"

use dbcsr_operations, only: dbcsr_copy, dbcsr_get_info,dbcsr_batched_add_anytype_begin,&
dbcsr_batched_add_anytype,&
dbcsr_batched_add_anytype_end

USE dbcsr_array_types, ONLY: array_i1d_obj
use dbcsr_operations, only: dbcsr_get_info, &
dbcsr_add_iter_s, &
dbcsr_add_iter_d, &
dbcsr_add_iter_c, &
dbcsr_add_iter_z
USE dbcsr_iterator_operations, ONLY: dbcsr_iterator_blocks_left,&
dbcsr_iterator_next_block,&
dbcsr_iterator_start,&
dbcsr_iterator_stop

USE dbcsr_block_access, ONLY: dbcsr_get_block_p, &
dbcsr_reserve_blocks, dbcsr_put_block

USE dbcsr_types, only: dbcsr_iterator,&
dbcsr_type,&
dbcsr_data_obj,&
dbcsr_scalar_type,&
dbcsr_mp_obj,&
${uselist(dkind1)}$
use dbcsr_work_operations, only: dbcsr_create, dbcsr_finalize
use dbcsr_work_operations, only: dbcsr_create, dbcsr_finalize, &
dbcsr_work_create
use dbcsr_dist_operations, only: dbcsr_get_stored_coordinates
USE dbcsr_methods, ONLY: dbcsr_get_data_type, dbcsr_release, &
dbcsr_distribution


USE dbcsr_kinds, ONLY: default_string_length
USE dbcsr_kinds, ONLY: ${uselist(kind1)}$
USE dbcsr_mpiwrap, ONLY: mp_alltoall,&
mp_environ,&
mp_irecv,&
mp_isend,&
mp_waitall


dbcsr_distribution,&
dbcsr_get_num_blocks,&
dbcsr_get_nze,&
dbcsr_nfullcols_local,&
dbcsr_nfullrows_local
USE dbcsr_kinds, ONLY: default_string_length,&
int_4,&
int_8
USE dbcsr_kinds, ONLY: ${uselist(kind1)}$
USE dbcsr_mpiwrap, ONLY: mp_alltoall,&
mp_environ,&
mp_irecv,&
mp_isend,&
mp_waitall
USE dbcsr_dist_methods, ONLY: dbcsr_distribution_mp

USE dbcsr_mp_methods, ONLY: dbcsr_mp_group, dbcsr_mp_mynode, dbcsr_mp_numnodes
USE dbcsr_data_methods, only: dbcsr_data_set_pointer, dbcsr_data_init, &
dbcsr_data_new, dbcsr_data_release, &
dbcsr_scalar_one, dbcsr_data_clear_pointer
USE dbcsr_index_operations, ONLY: dbcsr_repoint_index
USE dbcsr_mm_dist_operations, ONLY: dbcsr_get_global_row_map

#include "base/dbcsr_base_uses.f90"

USE dbcsr_index_operations, only: dbcsr_repoint_index
!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num, omp_get_num_threads

IMPLICIT NONE
PRIVATE
CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_mm_reshape'
Expand Down Expand Up @@ -94,7 +97,8 @@ SUBROUTINE dbcsr_reshape(matrix_in, matrix_out)
TYPE(dbcsr_iterator) :: iter
INTEGER :: col, row, blk, row_size, col_size, &
handle, mynode, iproc, numnodes, &
data_type, blk_p, mp_comm
data_type, blk_p, mp_comm, &
size_work, nblks_guess
TYPE(dbcsr_mp_obj) :: mp_obj
INTEGER, ALLOCATABLE, DIMENSION(:, :) :: num_send, num_recv
INTEGER, DIMENSION(2) :: total_num
Expand All @@ -103,6 +107,7 @@ SUBROUTINE dbcsr_reshape(matrix_in, matrix_out)
TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: partial_matrices

INTEGER, PARAMETER :: imeta = 1, idata = 2
INTEGER(KIND=int_8) :: local_matrix_size

CALL timeset(routineN, handle)
!
Expand Down Expand Up @@ -168,14 +173,28 @@ SUBROUTINE dbcsr_reshape(matrix_in, matrix_out)
ENDDO
!
! Accumulate
CALL dbcsr_batched_add_anytype_begin(matrix_out, partial_matrices(0), total_num(idata), total_num(imeta))
!
! Pre-size work arrays of matrix_a to avoid continuous reallocation.
local_matrix_size = INT(dbcsr_nfullrows_local(matrix_out), KIND=int_8)* &
dbcsr_nfullcols_local(matrix_out)
size_work = MAX(0, INT(MIN(local_matrix_size-INT(dbcsr_get_nze(matrix_out), KIND=int_8), &
INT(total_num(idata), KIND=int_8)), KIND=int_4))
nblks_guess = MIN(matrix_in % nblkrows_total * matrix_in % nblkcols_total, total_num(idata))
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (matrix_out, size_work, nblks_guess)
CALL dbcsr_work_create(matrix_out, &
nblks_guess = nblks_guess, &
sizedata_guess = size_work, &
work_mutable=.FALSE.,create_hashes_par = .true.)

!$OMP END PARALLEL
DO iproc = 0, numnodes - 1
CALL dbcsr_batched_add_anytype(matrix_out, partial_matrices(iproc))
CALL dbcsr_accumulation_anytype(matrix_out, partial_matrices(iproc))
CALL dbcsr_release(partial_matrices(iproc))
CALL block_buffer_destroy(buffer_send(iproc))
CALL block_buffer_destroy(buffer_recv(iproc))
ENDDO
CALL dbcsr_batched_add_anytype_end(matrix_out)
CALL dbcsr_finalize(matrix_out)
CALL dbcsr_data_clear_pointer(blk_data)
CALL dbcsr_data_release(blk_data)

Expand Down Expand Up @@ -400,4 +419,63 @@ SUBROUTINE block_buffer_get_next_block_${dsuffix}$(buffer, ndata, row, col, bloc
END SUBROUTINE
#:endfor

! **************************************************************************************************
!> \brief Accumulation A += B
!> \param[in,out] matrix_a DBCSR matrix
!> \param[in] matrix_b DBCSR matrix
!> \param[in] alpha_scalar (optional) ...
!> \param[in] beta_scalar (optional) ...
!>
!> \param flop ...
! **************************************************************************************************
SUBROUTINE dbcsr_accumulation_anytype(matrix_a, matrix_b, flop)
TYPE(dbcsr_type), INTENT(INOUT) :: matrix_a
TYPE(dbcsr_type), INTENT(IN) :: matrix_b
INTEGER(KIND=int_8), INTENT(INOUT), OPTIONAL :: flop

CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_batched_add_anytype', &
routineP = moduleN//':'//routineN

INTEGER :: handle, iw
INTEGER(KIND=int_8) :: my_flop
LOGICAL :: do_scale
TYPE(dbcsr_iterator) :: iter
TYPE(dbcsr_scalar_type) :: my_beta_scalar
type(array_i1d_obj) :: map_row_g2l
!----------------------------------------------------------------------------

CALL timeset(routineN, handle)
my_beta_scalar = dbcsr_scalar_one(dbcsr_get_data_type(matrix_b))
do_scale = .FALSE.
IF (dbcsr_get_num_blocks(matrix_b) .GT. 0) THEN
! just to initialize if it was not initialized, map_row_g2l is not used
call dbcsr_get_global_row_map(matrix_a % dist, map_row_g2l)
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP PRIVATE (iter, iw) &
!$OMP SHARED (matrix_a, matrix_b) &
!$OMP SHARED (do_scale, my_beta_scalar) &
!$OMP REDUCTION (+ : my_flop)
iw = 1
!$ iw = omp_get_thread_num () + 1
CALL dbcsr_iterator_start(iter, matrix_b, &
shared=.TRUE., read_only=.TRUE., contiguous_pointers=.FALSE.)
SELECT CASE (dbcsr_get_data_type(matrix_b))
CASE (dbcsr_type_real_4)
CALL dbcsr_add_iter_s(matrix_a, matrix_b, iter, iw, do_scale, my_beta_scalar, my_flop, .true.)
CASE (dbcsr_type_real_8)
CALL dbcsr_add_iter_d(matrix_a, matrix_b, iter, iw, do_scale, my_beta_scalar, my_flop, .true.)
CASE (dbcsr_type_complex_4)
CALL dbcsr_add_iter_c(matrix_a, matrix_b, iter, iw, do_scale, my_beta_scalar, my_flop, .true.)
CASE (dbcsr_type_complex_8)
CALL dbcsr_add_iter_z(matrix_a, matrix_b, iter, iw, do_scale, my_beta_scalar, my_flop, .true.)
CASE default
DBCSR_ABORT("Invalid data type")
END SELECT
CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL
IF (PRESENT(flop)) flop = flop+my_flop
ENDIF
CALL timestop(handle)
END SUBROUTINE dbcsr_accumulation_anytype

END MODULE dbcsr_mm_reshape
Loading

0 comments on commit 334ef70

Please sign in to comment.