Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 296fa15e1bc7bd05c7e7c3a6333182777f8abf79
Author: Mason Remy <masonr@microsoft.com>
Date:   Tue Jan 17 18:52:45 2023 +0000

    Merged PR 3029: Work around constraint resolution issues with dynamic split size 1

    Work around constraint resolution issues with dynamic split size 1
  • Loading branch information
Lisa Ong committed Jan 18, 2023
1 parent 4e60e4a commit d92e34e
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
33 changes: 33 additions & 0 deletions accera/ir/src/AffineConstraintsHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,39 @@ namespace util
// an offset from that position, whereas in mlir::FlatAffineConstraints::getSliceBounds, the opposite appears to be done...
// It's not clear if this is a bug or just a bad choice of arg names.
auto [lbMap, ubMap] = cst.getLowerAndUpperBound(0 /* pos */, dimId /* offset */, 1 /* num */, cst.getNumDimIds(), localExprs, _context);

lbMap = mlir::removeDuplicateExprs(lbMap);
ubMap = mlir::removeDuplicateExprs(ubMap);

// TODO : we may need to write our own custom lower / upper bound resolution that ports portions of the builtin tools
// For now we depend on the builtin tools and some workarounds
// In some situations, seemingly more common with split sizes of 1, the constraints can simplify away some inequalities that are
// useful for detecting lower and upper bounds via the rather simple checks that getLowerAndUpperBound() performs.
// (Put another way: the code that simplifies constraints is smarter than the code that detects lower/upper bounds,
// and if the latter was smarter then a single bound might have been detected)
// When a single lower bound or a single upper bound isn't found, then multiple can be returned. This can cause issues for how
// Accera uses these bounds.
// As a workaround, if we have multiple results we attempt to find a single constant result and use that instead
if (lbMap.getNumResults() > 1)
{
auto lbConstOpt = cst.getConstantBound(mlir::IntegerPolyhedron::BoundType::LB, dimId);
if (lbConstOpt.hasValue())
{
auto lbConstExpr = mlir::getAffineConstantExpr(lbConstOpt.getValue(), _context);
lbMap = mlir::AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbConstExpr);
}
}

if (ubMap.getNumResults() > 1)
{
auto ubConstOpt = cst.getConstantBound(mlir::IntegerPolyhedron::BoundType::UB, dimId);
if (ubConstOpt.hasValue())
{
auto ubConstExpr = mlir::getAffineConstantExpr(ubConstOpt.getValue(), _context);
ubMap = mlir::AffineMap::get(ubMap.getNumDims(), ubMap.getNumSymbols(), ubConstExpr);
}
}

auto constraintOperands = tmpResolveCst.GetConstraintValuesForDimId(id);
auto simplifiedLBMap = lbMap;
auto simplifiedUBMap = ubMap;
Expand Down
133 changes: 133 additions & 0 deletions accera/python/accera/test/smoke_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5677,5 +5677,138 @@ def _():
after=correctness_check_values["post"],
)

def test_dynamic_size_redundant_split_1(self) -> None:
package_name = "test_dynamic_size_redundant_split_1"
split_size = 1

m_extent = Dimension()
input_arr = Array(role=Array.Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,))
output_arr = Array(role=Array.Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,))

nest = Nest((m_extent,))
i = nest.get_indices()
@nest.iteration_logic
def _():
output_arr[i] += input_arr[i]

sched = nest.create_schedule()
ii = sched.split(i, split_size)
iii = sched.split(ii, split_size)
sched.reorder(i, ii, iii)
plan = sched.create_plan()

# Create a package and add our function definition to it
package = Package()

fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name)

M_test = np.int64(1)
input_test = np.random.random((M_test,)).astype(np.float32)
output_test = np.random.random((M_test,)).astype(np.float32)
correctness_check_values = {
"pre": [M_test, input_test, output_test],
"post": [M_test, input_test, output_test + input_test],
}

# Build the HAT package
output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name
with verifiers.VerifyPackage(self, package_name, output_dir) as v:
package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False)

v.check_correctness(
fn.name,
before=correctness_check_values["pre"],
after=correctness_check_values["post"],
)

def test_dynamic_size_split_1(self) -> None:
package_name = "test_dynamic_size_split_1"
split_size = 1

m_extent = Dimension()
input_arr = Array(role=Array.Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,))
output_arr = Array(role=Array.Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,))

nest = Nest((m_extent,))
i = nest.get_indices()
@nest.iteration_logic
def _():
output_arr[i] += input_arr[i]

sched = nest.create_schedule()
ii = sched.split(i, split_size)
sched.reorder(i, ii)
plan = sched.create_plan()

# Create a package and add our function definition to it
package = Package()

fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name)

M_test = np.int64(1)
input_test = np.random.random((M_test,)).astype(np.float32)
output_test = np.random.random((M_test,)).astype(np.float32)
correctness_check_values = {
"pre": [M_test, input_test, output_test],
"post": [M_test, input_test, output_test + input_test],
}

# Build the HAT package
output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name
with verifiers.VerifyPackage(self, package_name, output_dir) as v:
package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False)

v.check_correctness(
fn.name,
before=correctness_check_values["pre"],
after=correctness_check_values["post"],
)

def test_dynamic_size_split_and_redundant_split_1(self) -> None:
package_name = "test_dynamic_size_split_and_redundant_split_1"
outer_split_size = 16
inner_split_size = 1

m_extent = Dimension()
input_arr = Array(role=Array.Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,))
output_arr = Array(role=Array.Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,))

nest = Nest((m_extent,))
i = nest.get_indices()
@nest.iteration_logic
def _():
output_arr[i] += input_arr[i]

sched = nest.create_schedule()
ii = sched.split(i, outer_split_size)
iii = sched.split(ii, inner_split_size)
iiii = sched.split(iii, inner_split_size)
sched.reorder(i, ii, iii, iiii)
plan = sched.create_plan()

# Create a package and add our function definition to it
package = Package()

fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name)

M_test = np.int64(37)
input_test = np.random.random((M_test,)).astype(np.float32)
output_test = np.random.random((M_test,)).astype(np.float32)
correctness_check_values = {
"pre": [M_test, input_test, output_test],
"post": [M_test, input_test, output_test + input_test],
}

# Build the HAT package
output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name
with verifiers.VerifyPackage(self, package_name, output_dir) as v:
package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False)

v.check_correctness(
fn.name,
before=correctness_check_values["pre"],
after=correctness_check_values["post"],
)

if __name__ == '__main__':
unittest.main(verbosity=10)

0 comments on commit d92e34e

Please sign in to comment.