Skip to content

Commit

Permalink
share indexing values between areas
Browse files Browse the repository at this point in the history
  • Loading branch information
omlins committed Mar 21, 2023
1 parent d3dc138 commit 76959a0
Showing 1 changed file with 70 additions and 36 deletions.
106 changes: 70 additions & 36 deletions src/kernel_language.jl
Expand Up @@ -84,23 +84,24 @@ function loopopt(metadata_module::Module, caller::Module, indices::Union{Symbol,

if optdim == 3
oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, offset_spans, oz_spans, loopentrys = define_helper_variables(offset_mins, offset_maxs, optvars, optdim)
loopstart = minimum(values(loopentrys))
loopend = loopsize
shmem_symbols = define_shmem_symbols(oz_maxs, optvars, use_shmems, optdim)
shmem_exprs = define_shmem_exprs(shmem_symbols, optdim)
shmem_z_ranges = define_shmem_z_ranges(offsets_by_z, use_shmems, optdim)
shmem_loopentrys = define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, optdim)
shmem_loopexits = define_shmem_loopexits(loopend, shmem_z_ranges, offset_maxs, optdim)
mainloopstart = (optimize_halo_read && !isempty(shmem_loopentrys)) ? minimum(values(shmem_loopentrys)) : loopstart
mainloopend = loopend # TODO: the second loop split leads to wrong results, probably due to a compiler bug. # mainloopend = (optimize_halo_read && !isempty(shmem_loopexits) ) ? maximum(values(shmem_loopexits) ) : loopend
ix, iy, iz = indices
tz_g = THREADIDS_VARNAMES[3]
rangelength_z = RANGELENGTHS_VARNAMES[3]
ranges = RANGES_VARNAME
range_z = :(($ranges[3])[$tz_g])
range_z_start = :(($ranges[3])[1])
i = gensym_world("i", @__MODULE__)
loopoffset = gensym_world("loopoffset", @__MODULE__)
loopstart = minimum(values(loopentrys))
loopend = loopsize
shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, optdim)
shmem_symbols = define_shmem_symbols(oz_maxs, optvars, use_shmems, shmem_index_groups, optdim)
shmem_exprs = define_shmem_exprs(shmem_symbols, optdim)
shmem_z_ranges = define_shmem_z_ranges(offsets_by_z, use_shmems, optdim)
shmem_loopentrys = define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, optdim)
shmem_loopexits = define_shmem_loopexits(loopend, shmem_z_ranges, offset_maxs, optdim)
mainloopstart = (optimize_halo_read && !isempty(shmem_loopentrys)) ? minimum(values(shmem_loopentrys)) : loopstart
mainloopend = loopend # TODO: the second loop split leads to wrong results, probably due to a compiler bug. # mainloopend = (optimize_halo_read && !isempty(shmem_loopexits) ) ? maximum(values(shmem_loopexits) ) : loopend
ix, iy, iz = indices
tz_g = THREADIDS_VARNAMES[3]
rangelength_z = RANGELENGTHS_VARNAMES[3]
ranges = RANGES_VARNAME
range_z = :(($ranges[3])[$tz_g])
range_z_start = :(($ranges[3])[1])
i = gensym_world("i", @__MODULE__)
loopoffset = gensym_world("loopoffset", @__MODULE__)

for A in optvars
regqueue_tail = regqueue_tails[A]
Expand Down Expand Up @@ -135,9 +136,13 @@ $((quote
$ix_h2 = (@blockIdx().x-1)*@blockDim().x + $tx_h2 - $hx1 # ...
$iy_h = (@blockIdx().y-1)*@blockDim().y + $ty_h - $hy1 # ...
$iy_h2 = (@blockIdx().y-1)*@blockDim().y + $ty_h2 - $hy1 # ...
$A_head = @sharedMem(eltype($A), ($nx_l, $ny_l), $shmem_offset) # e.g. A_izp3 = @sharedMem(eltype(A), (nx_l, ny_l), +(nx_l_A * ny_l_A)*eltype(A))
end
for (A, s) in shmem_symbols for (shmem_offset, hx1, hx2, hy1, hy2, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((shmem_exprs[A][:offset], hx1s[A], hx2s[A], hy1s[A], hy2s[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),)
for vars in values(shmem_index_groups) for A in (vars[1],) for s in (shmem_symbols[A],) for (shmem_offset, hx1, hx2, hy1, hy2, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((shmem_exprs[A][:offset], hx1s[A], hx2s[A], hy1s[A], hy2s[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),)
)...
)
$((:( $A_head = @sharedMem(eltype($A), ($nx_l, $ny_l), $shmem_offset) # e.g. A_izp3 = @sharedMem(eltype(A), (nx_l, ny_l), +(nx_l_A * ny_l_A)*eltype(A))
)
for (A, s) in shmem_symbols for (shmem_offset, nx_l, ny_l, A_head) = ((shmem_exprs[A][:offset], s[:nx_l], s[:ny_l], s[:A_head]),)
)...
)
$((:( $reg = 0.0 # e.g. A_ixm1_iyp2_izp2 = 0.0
Expand Down Expand Up @@ -568,26 +573,55 @@ function define_helper_variables(offset_mins::Dict{Symbol, <:NTuple{3,Integer}},
return oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, offset_spans, oz_spans, loopentrys
end

function define_shmem_symbols(oz_maxs::Dict{Any, Any}, optvars::NTuple{N,Symbol} where N, use_shmems::Dict{Any, Any}, optdim::Integer)
sym = Dict(A => Dict() for A in optvars if use_shmems[A])
"""
    define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, optdim)

Group the variables of `optvars` that have `use_shmems[A]` set by the key
`(hx1s[A], hy1s[A], hx2s[A], hy2s[A])`.

# Arguments
- `hx1s`, `hy1s`, `hx2s`, `hy2s`: lookup tables indexed by the symbols in `optvars`
  (presumably the per-variable halo widths in x and y - TODO confirm against
  `define_helper_variables`).
- `optvars`: tuple of the variable symbols to consider.
- `use_shmems`: per-variable flag; variables whose flag is `false` are left out.
- `optdim`: groups are only built when `optdim == 3`.

# Returns
A `Dict` mapping each 4-tuple key to a tuple of the variables that share it, in the order
in which they appear in `optvars`. The dictionary is empty when `optdim != 3` or when no
variable has its `use_shmems` flag set.
"""
function define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars::NTuple{N,Symbol} where N, use_shmems::Dict{Any, Any}, optdim::Integer)
    shmem_index_groups = Dict()
    if optdim == 3
        for A in optvars
            use_shmems[A] || continue
            k = (hx1s[A], hy1s[A], hx2s[A], hy2s[A])
            # Append A to the group stored under k; start from an empty tuple if k is new.
            shmem_index_groups[k] = (get(shmem_index_groups, k, ())..., A)
        end
    end
    return shmem_index_groups
end

function define_shmem_symbols(oz_maxs::Dict{Any, Any}, optvars::NTuple{N,Symbol} where N, use_shmems::Dict{Any, Any}, shmem_index_groups, optdim::Integer)
sym = Dict(A => Dict() for A in optvars if use_shmems[A])
if optdim == 3
for vars in values(shmem_index_groups)
suffix = join(string.(vars), "_")
tx = gensym_world("tx_$suffix", @__MODULE__)
ty = gensym_world("ty_$suffix", @__MODULE__)
nx_l = gensym_world("nx_l_$suffix", @__MODULE__)
ny_l = gensym_world("ny_l_$suffix", @__MODULE__)
t_h = gensym_world("t_h_$suffix", @__MODULE__)
t_h2 = gensym_world("t_h2_$suffix", @__MODULE__)
tx_h = gensym_world("tx_h_$suffix", @__MODULE__)
tx_h2 = gensym_world("tx_h2_$suffix", @__MODULE__)
ty_h = gensym_world("ty_h_$suffix", @__MODULE__)
ty_h2 = gensym_world("ty_h2_$suffix", @__MODULE__)
ix_h = gensym_world("ix_h_$suffix", @__MODULE__)
ix_h2 = gensym_world("ix_h2_$suffix", @__MODULE__)
iy_h = gensym_world("iy_h_$suffix", @__MODULE__)
iy_h2 = gensym_world("iy_h2_$suffix", @__MODULE__)
for A in vars
sym[A][:tx] = tx
sym[A][:ty] = ty
sym[A][:nx_l] = nx_l
sym[A][:ny_l] = ny_l
sym[A][:t_h] = t_h
sym[A][:t_h2] = t_h2
sym[A][:tx_h] = tx_h
sym[A][:tx_h2] = tx_h2
sym[A][:ty_h] = ty_h
sym[A][:ty_h2] = ty_h2
sym[A][:ix_h] = ix_h
sym[A][:ix_h2] = ix_h2
sym[A][:iy_h] = iy_h
sym[A][:iy_h2] = iy_h2
sym[A][:A_head] = gensym_world(varname(A, (oz_maxs[A],); i="iz"), @__MODULE__)
end
end
Expand Down

0 comments on commit 76959a0

Please sign in to comment.