Skip to content

Commit

Permalink
Fix a couple of shared-memory errors
Browse files Browse the repository at this point in the history
  • Loading branch information
johnomotani committed Nov 23, 2023
1 parent 7210a0a commit c6b081b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 43 deletions.
9 changes: 8 additions & 1 deletion src/fokker_planck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ using LinearAlgebra: lu
using ..initial_conditions: enforce_boundary_conditions!
using ..type_definitions: mk_float, mk_int
using ..array_allocation: allocate_float, allocate_shared_float
using ..communication: MPISharedArray, global_rank
using ..communication: MPISharedArray, global_rank, _block_synchronize
using ..velocity_moments: integrate_over_vspace
using ..velocity_moments: get_density, get_upar, get_ppar, get_pperp, get_qpar, get_pressure, get_rmom
using ..calculus: derivative!, second_derivative!
Expand Down Expand Up @@ -312,6 +312,9 @@ function fokker_planck_collision_operator_weak_form!(ffs_in,ffsp_in,ms,msp,nussp
dHdvpa[ivpa,ivperp] = dHdvpa_Maxwellian(dens,upar,vth,vpa,vperp,ivpa,ivperp)
dHdvperp[ivpa,ivperp] = dHdvperp_Maxwellian(dens,upar,vth,vpa,vperp,ivpa,ivperp)
end
# Need to synchronize as these arrays may be read outside the locally-owned set of
# ivperp, ivpa indices in assemble_explicit_collision_operator_rhs_parallel!()
_block_synchronize()

This comment has been minimized.

Copy link
@mrhardman

mrhardman Nov 24, 2023

Collaborator

Nice, makes sense, this was a bug : )

else
calculate_rosenbluth_potentials_via_elliptic_solve!(GG,HH,dHdvpa,dHdvperp,
d2Gdvpa2,dGdvperp,d2Gdvperpdvpa,d2Gdvperp2,@view(ffsp_in[:,:]),
Expand All @@ -334,6 +337,10 @@ function fokker_planck_collision_operator_weak_form!(ffs_in,ffsp_in,ms,msp,nussp
dFdvpa[ivpa,ivperp] = dFdvpa_Maxwellian(dens,upar,vth,vpa,vperp,ivpa,ivperp)
dFdvperp[ivpa,ivperp] = dFdvperp_Maxwellian(dens,upar,vth,vpa,vperp,ivpa,ivperp)
end
# Need to synchronize as FF, dFdvpa, dFdvperp may be read outside the
# locally-owned set of ivperp, ivpa indices in
# assemble_explicit_collision_operator_rhs_parallel_analytical_inputs!()
_block_synchronize()

This comment has been minimized.

Copy link
@mrhardman

mrhardman Nov 24, 2023

Collaborator

Nice catch : )

assemble_explicit_collision_operator_rhs_parallel_analytical_inputs!(rhsc,rhsvpavperp,
FF,dFdvpa,dFdvperp,
d2Gdvpa2,d2Gdvperpdvpa,d2Gdvperp2,
Expand Down
87 changes: 45 additions & 42 deletions src/fokker_planck_calculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1804,56 +1804,59 @@ end
function assemble_explicit_collision_operator_rhs_serial!(rhsc,pdfs,d2Gspdvpa2,d2Gspdvperpdvpa,
d2Gspdvperp2,dHspdvpa,dHspdvperp,ms,msp,nussp,
vpa,vperp,YY_arrays::YY_collision_operator_arrays)
# assemble RHS of collision operator
@. rhsc = 0.0

# loop over elements
for ielement_vperp in 1:vperp.nelement_local
YY0perp = YY_arrays.YY0perp[:,:,:,ielement_vperp]
YY1perp = YY_arrays.YY1perp[:,:,:,ielement_vperp]
YY2perp = YY_arrays.YY2perp[:,:,:,ielement_vperp]
YY3perp = YY_arrays.YY3perp[:,:,:,ielement_vperp]
begin_serial_region()
@serial_region begin
# assemble RHS of collision operator
@. rhsc = 0.0

for ielement_vpa in 1:vpa.nelement_local
YY0par = YY_arrays.YY0par[:,:,:,ielement_vpa]
YY1par = YY_arrays.YY1par[:,:,:,ielement_vpa]
YY2par = YY_arrays.YY2par[:,:,:,ielement_vpa]
YY3par = YY_arrays.YY3par[:,:,:,ielement_vpa]
# loop over elements
for ielement_vperp in 1:vperp.nelement_local
YY0perp = YY_arrays.YY0perp[:,:,:,ielement_vperp]
YY1perp = YY_arrays.YY1perp[:,:,:,ielement_vperp]
YY2perp = YY_arrays.YY2perp[:,:,:,ielement_vperp]
YY3perp = YY_arrays.YY3perp[:,:,:,ielement_vperp]

# loop over field positions in each element
for ivperp_local in 1:vperp.ngrid
for ivpa_local in 1:vpa.ngrid
ic_global, ivpa_global, ivperp_global = get_global_compound_index(vpa,vperp,ielement_vpa,ielement_vperp,ivpa_local,ivperp_local)
# carry out the matrix sum on each 2D element
for jvperpp_local in 1:vperp.ngrid
jvperpp = vperp.igrid_full[jvperpp_local,ielement_vperp]
for kvperpp_local in 1:vperp.ngrid
kvperpp = vperp.igrid_full[kvperpp_local,ielement_vperp]
for jvpap_local in 1:vpa.ngrid
jvpap = vpa.igrid_full[jvpap_local,ielement_vpa]
pdfjj = pdfs[jvpap,jvperpp]
for kvpap_local in 1:vpa.ngrid
kvpap = vpa.igrid_full[kvpap_local,ielement_vpa]
# first three lines represent parallel flux terms
# second three lines represent perpendicular flux terms
rhsc[ic_global] += (YY0perp[kvperpp_local,jvperpp_local,ivperp_local]*YY2par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvpa2[kvpap,kvperpp] +
YY3perp[kvperpp_local,jvperpp_local,ivperp_local]*YY1par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperpdvpa[kvpap,kvperpp] -
2.0*(ms/msp)*YY0perp[kvperpp_local,jvperpp_local,ivperp_local]*YY1par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*dHspdvpa[kvpap,kvperpp] +
# end parallel flux, start of perpendicular flux
YY1perp[kvperpp_local,jvperpp_local,ivperp_local]*YY3par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperpdvpa[kvpap,kvperpp] +
YY2perp[kvperpp_local,jvperpp_local,ivperp_local]*YY0par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperp2[kvpap,kvperpp] -
2.0*(ms/msp)*YY1perp[kvperpp_local,jvperpp_local,ivperp_local]*YY0par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*dHspdvperp[kvpap,kvperpp])
for ielement_vpa in 1:vpa.nelement_local
YY0par = YY_arrays.YY0par[:,:,:,ielement_vpa]
YY1par = YY_arrays.YY1par[:,:,:,ielement_vpa]
YY2par = YY_arrays.YY2par[:,:,:,ielement_vpa]
YY3par = YY_arrays.YY3par[:,:,:,ielement_vpa]

# loop over field positions in each element
for ivperp_local in 1:vperp.ngrid
for ivpa_local in 1:vpa.ngrid
ic_global, ivpa_global, ivperp_global = get_global_compound_index(vpa,vperp,ielement_vpa,ielement_vperp,ivpa_local,ivperp_local)
# carry out the matrix sum on each 2D element
for jvperpp_local in 1:vperp.ngrid
jvperpp = vperp.igrid_full[jvperpp_local,ielement_vperp]
for kvperpp_local in 1:vperp.ngrid
kvperpp = vperp.igrid_full[kvperpp_local,ielement_vperp]
for jvpap_local in 1:vpa.ngrid
jvpap = vpa.igrid_full[jvpap_local,ielement_vpa]
pdfjj = pdfs[jvpap,jvperpp]
for kvpap_local in 1:vpa.ngrid
kvpap = vpa.igrid_full[kvpap_local,ielement_vpa]
# first three lines represent parallel flux terms
# second three lines represent perpendicular flux terms
rhsc[ic_global] += (YY0perp[kvperpp_local,jvperpp_local,ivperp_local]*YY2par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvpa2[kvpap,kvperpp] +
YY3perp[kvperpp_local,jvperpp_local,ivperp_local]*YY1par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperpdvpa[kvpap,kvperpp] -
2.0*(ms/msp)*YY0perp[kvperpp_local,jvperpp_local,ivperp_local]*YY1par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*dHspdvpa[kvpap,kvperpp] +
# end parallel flux, start of perpendicular flux
YY1perp[kvperpp_local,jvperpp_local,ivperp_local]*YY3par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperpdvpa[kvpap,kvperpp] +
YY2perp[kvperpp_local,jvperpp_local,ivperp_local]*YY0par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*d2Gspdvperp2[kvpap,kvperpp] -
2.0*(ms/msp)*YY1perp[kvperpp_local,jvperpp_local,ivperp_local]*YY0par[kvpap_local,jvpap_local,ivpa_local]*pdfjj*dHspdvperp[kvpap,kvperpp])
end
end
end
end
end
end
end
end
end
end
# correct for minus sign due to integration by parts
# and multiply by the normalised collision frequency
@. rhsc *= -nussp
end
# correct for minus sign due to integration by parts
# and multiply by the normalised collision frequency
@. rhsc *= -nussp
return nothing
end

Expand Down

0 comments on commit c6b081b

Please sign in to comment.