diff --git a/test/test_transform.py b/test/test_transform.py index cdd560dad..46d636900 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -1478,6 +1478,62 @@ class FooTag(Tag): assert t_unit.default_entrypoint.inames["i_0"].tags_of_type(FooTag) # fails +def test_reindexing_strided_access(ctx_factory): + import islpy as isl + + if not hasattr(isl.Set, "card"): + pytest.skip("No barvinok support") + + ctx = ctx_factory() + + tunit = lp.make_kernel( + "{[i, j]: 0<=j,i<10}", + """ + <> tmp[2*i, 2*j] = a[i, j] + out[i, j] = tmp[2*i, 2*j]**2 + """) + + tunit = lp.add_dtypes(tunit, {"a": "float64"}) + ref_tunit = tunit + + knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint, + "tmp") + tunit = tunit.with_kernel(knl) + + tv, = tunit.default_entrypoint.temporary_variables.values() + assert tv.shape == (100,) + + lp.auto_test_vs_ref(ref_tunit, ctx, tunit) + + +def test_reindexing_figurate(ctx_factory): + import islpy as isl + + if not hasattr(isl.Set, "card"): + pytest.skip("No barvinok support") + + ctx = ctx_factory() + + tunit = lp.make_kernel( + "{[i, j]: 0<=j<=i<10}", + """ + <> tmp[2*i, 2*j] = a[i, j] + out[i, j] = tmp[2*i, 2*j]**2 + """) + + tunit = lp.add_dtypes(tunit, {"a": "float64"}) + ref_tunit = tunit + + knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint, + "tmp") + tunit = tunit.with_kernel(knl) + + tv, = tunit.default_entrypoint.temporary_variables.values() + assert tv.shape == (55,) + + lp.auto_test_vs_ref(ref_tunit, ctx, tunit) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])