Skip to content

Commit

Permalink
Construct initial happens_afters using isl.PwAff.lt_map instead of lt…
Browse files Browse the repository at this point in the history
…_set
  • Loading branch information
a-alveyblanc committed Mar 26, 2024
1 parent 0b5ce62 commit a9d1775
Showing 1 changed file with 29 additions and 43 deletions.
72 changes: 29 additions & 43 deletions loopy/kernel/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,76 +41,62 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel:
"""

new_insns = []

for iafter, insn_after in enumerate(knl.instructions):

if iafter == 0:
new_insns.append(insn_after)

else:

insn_before = knl.instructions[iafter - 1]
shared_inames = insn_after.within_inames & insn_before.within_inames
insn_before = knl.instructions[iafter-1]

domain_before = knl.get_inames_domain(insn_before.within_inames)
domain_after = knl.get_inames_domain(insn_after.within_inames)
happens_before = isl.Map.from_domain_and_range(
domain_before, domain_after
)

for idim in range(happens_before.dim(dim_type.out)):
happens_before = happens_before.set_dim_name(
dim_type.out, idim,
happens_before.get_dim_name(dim_type.out, idim) + "'"
)

n_inames_before = happens_before.dim(dim_type.in_)
happens_before_set = happens_before.move_dims(
dim_type.out, 0,
dim_type.in_, 0,
n_inames_before).range()

shared_inames = insn_before.within_inames & insn_after.within_inames

happens_after = isl.Map.from_domain_and_range(
domain_before,
domain_after)

for idim in range(happens_after.dim(dim_type.out)):
happens_after = happens_after.set_dim_name(
dim_type.out,
idim,
happens_after.get_dim_name(dim_type.out, idim) + "'")

shared_inames_order_before = [
domain_before.get_dim_name(dim_type.out, idim)
for idim in range(domain_before.dim(dim_type.out))
if domain_before.get_dim_name(dim_type.out, idim)
in shared_inames
]
in shared_inames]

shared_inames_order_after = [
domain_after.get_dim_name(dim_type.out, idim)
for idim in range(domain_after.dim(dim_type.out))
if domain_after.get_dim_name(dim_type.out, idim)
in shared_inames
]
in shared_inames]

assert shared_inames_order_after == shared_inames_order_before
shared_inames_order = shared_inames_order_after

affs = isl.affs_from_space(happens_before_set.space)
affs_in = isl.affs_from_space(happens_after.domain().space)
affs_out = isl.affs_from_space(happens_after.range().space)

lex_set = isl.Set.empty(happens_before_set.space)
for iinnermost, innermost_iname in enumerate(shared_inames_order):

innermost_set = affs[innermost_iname].lt_set(
affs[innermost_iname+"'"]
)
lex_map = isl.Map.empty(happens_after.space)
for iinnermost, innermost_iname in enumerate(shared_inames):
innermost_map = affs_in[innermost_iname].lt_map(
affs_out[innermost_iname + "'"])

for outer_iname in shared_inames_order[:iinnermost]:
innermost_set = innermost_set & (
affs[outer_iname].eq_set(affs[outer_iname + "'"])
)

lex_set = lex_set | innermost_set
innermost_map = innermost_map & (
affs_in[outer_iname].eq_map(
affs_out[outer_iname + "'"]))

lex_map = isl.Map.from_range(lex_set).move_dims(
dim_type.in_, 0,
dim_type.out, 0,
n_inames_before)
lex_map = lex_map | innermost_map

happens_before = happens_before & lex_map
happens_after = happens_after & lex_map

new_happens_after = {
insn_before.id: HappensAfter(None, happens_before)
}
insn_before.id: HappensAfter(None, happens_after)}

insn_after = insn_after.copy(happens_after=new_happens_after)

Expand Down

0 comments on commit a9d1775

Please sign in to comment.