Skip to content

Commit

Permalink
test reindex_using_seghir_loechner_scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Jun 28, 2022
1 parent 8fad4f8 commit 794b7cc
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,95 @@ class FooTag(Tag):
assert t_unit.default_entrypoint.inames["i_0"].tags_of_type(FooTag) # fails


def test_reindexing_strided_access(ctx_factory):
import islpy as isl

if not hasattr(isl.Set, "card"):
pytest.skip("No barvinok support")

ctx = ctx_factory()

tunit = lp.make_kernel(
"{[i, j]: 0<=j,i<10}",
"""
<> tmp[2*i, 2*j] = a[i, j]
out[i, j] = tmp[2*i, 2*j]**2
""")

tunit = lp.add_dtypes(tunit, {"a": "float64"})
ref_tunit = tunit

knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint,
"tmp")
tunit = tunit.with_kernel(knl)

tv, = tunit.default_entrypoint.temporary_variables.values()
assert tv.shape == (100,)

lp.auto_test_vs_ref(ref_tunit, ctx, tunit)


def test_reindexing_figurate(ctx_factory):
import islpy as isl

if not hasattr(isl.Set, "card"):
pytest.skip("No barvinok support")

ctx = ctx_factory()

tunit = lp.make_kernel(
"{[i, j]: 0<=j<=i<10}",
"""
<> tmp[2*i, 2*j] = a[i, j]
out[i, j] = tmp[2*i, 2*j]**2
""")

tunit = lp.add_dtypes(tunit, {"a": "float64"})
ref_tunit = tunit

knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint,
"tmp")
tunit = tunit.with_kernel(knl)

tv, = tunit.default_entrypoint.temporary_variables.values()
assert tv.shape == (55,)

lp.auto_test_vs_ref(ref_tunit, ctx, tunit)


def test_reindexing_figurate_parametric_shape(ctx_factory):
import islpy as isl
from loopy.symbolic import parse

if not hasattr(isl.Set, "card"):
pytest.skip("No barvinok support")

ctx = ctx_factory()

tunit = lp.make_kernel(
"{[i, j]: 0<=j<=i<n}",
"""
<> tmp[i, j] = a[i, j]
out[i, j] = tmp[i, j]**2
""",
assumptions="n > 0",
)

tunit = lp.add_dtypes(tunit, {"a": "float64"})
tunit = lp.set_temporary_address_space(tunit, "tmp",
lp.AddressSpace.GLOBAL)
ref_tunit = tunit

knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint,
"tmp")
tunit = tunit.with_kernel(knl)

tv, = tunit.default_entrypoint.temporary_variables.values()
assert tv.shape == (parse("(n + n**2) // 2"),)

lp.auto_test_vs_ref(ref_tunit, ctx, tunit, parameters={"n": 20})


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 794b7cc

Please sign in to comment.