Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit d6e136246b0dfc824c8111827cc6a0f166d3e2ea
Author: Mason Remy <masonr@microsoft.com>
Date:   Sat Mar 25 02:32:07 2023 +0000

    Merged PR 3181: Fix bug with reinterpret_cast of partially-dynamic array

    Fix bug with reinterpret_cast of partially-dynamic array

commit a4e8fa8874191c5c58475727937cfab4951427fc
Author: Mason Remy <masonr@microsoft.com>
Date:   Fri Mar 24 22:19:01 2023 +0000

    Merged PR 3180: Enable getting a memref shape from a memref_cast result

    Enable getting a memref shape from a memref_cast result

commit f3df546c84fd1c86d93c357811d43e34e35ae215
Author: Lisa Ong <onglisa@microsoft.com>
Date:   Fri Mar 24 17:28:30 2023 +0000

    Merged PR 3179: Fix vulkan-specific smoke test break

    Missing an import for test_vulkan_gpu_matmul(). This test code path is only exercised when vulkan is installed.

    ```
            format = self.PACKAGE_FORMAT if "VULKAN_SDK" in os.environ else Package.Format.HAT_STATIC
            with verifiers.VerifyPackage(self, "test_vulkan_gpu_matmul", TEST_PACKAGE_DIR):
                package.build(
                    name="test_vulkan_gpu_matmul", format=format, mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR
                )
    ```
  • Loading branch information
Lisa Ong committed Mar 27, 2023
1 parent 1127a60 commit c42ca38
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
7 changes: 7 additions & 0 deletions accera/ir/src/IRUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,13 @@ namespace util
return shape;
}

// TODO : more generalized dynamic shape detection
// If the memref is a memref_cast, then check the memref_cast source operand for a shape because the memref_cast does not change the shape
if (auto memrefCastOp = memref.getDefiningOp<ir::value::MemRefCastOp>())
{
memref = memrefCastOp.source();
}

// Currently this utility only supports dynamic memrefs that are alloc ops with shape args or
// function arguments with dimension size handles which are also function arguments
if (auto allocOp = memref.getDefiningOp<ir::value::AllocOp>())
Expand Down
29 changes: 29 additions & 0 deletions accera/python/accera/test/dsl_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,35 @@ def _():
_quiet=False
)

def test_reinterpret_cast_partially_dynamic_shape(self) -> None:
from accera import create_dimensions
test_name = "test_reinterpret_cast_partially_dynamic_shape"
package = Package()

M, N = create_dimensions()
A = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(5, M, N))
B = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(5, M, N))

nest = Nest((5, M, N))
indices = nest.get_indices()

@nest.iteration_logic
def _():
float_A = A._reinterpret_cast(ScalarType.float32)
B[indices] = float_A[indices]

package.add(nest, args=(M, N, A, B), base_name=test_name)

output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name
with verifiers.VerifyPackage(self, test_name, output_dir):
package.build(
test_name,
format=TEST_FORMAT,
mode=Package.Mode.RELEASE,
output_dir=output_dir,
_quiet=False
)

def test_subarray(self) -> None:
package = Package()

Expand Down
2 changes: 1 addition & 1 deletion accera/python/accera/test/smoke_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

from accera._lang_python import _MemoryLayout
from accera._lang_python._lang import Array as NativeArray
from accera._lang_python._lang import Dimension, _MemorySpace, _MMAShape, _If
from accera._lang_python._lang import Dimension, _MemorySpace, _MMAShape, _If, as_index
from accera._lang_python._lang._gpu import Barrier
from accera.samples import MatrixMultiplication
from accera.Targets import KNOWN_DEVICES
Expand Down
5 changes: 4 additions & 1 deletion accera/transforms/src/value/ValueToLLVMLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ struct ValueMemRefCastOpLowering : public ValueLLVMOpConversionPattern<MemRefCas
}
else
{
size = rewriter.create<LLVM::MulOp>(loc, size, rewriter.create<memref::DimOp>(loc, op.getViewSource(), s.index()));
mlir::Value dimSize = rewriter.create<memref::DimOp>(loc, op.getViewSource(), s.index());
auto castOp = rewriter.create<mlir::UnrealizedConversionCastOp>(loc, llvmIndexType, dimSize);
mlir::Value intDimSize = castOp.getResult(0);
size = rewriter.create<LLVM::MulOp>(loc, size, intDimSize);
}
}
targetMemrefDesc.setSize(rewriter, loc, 0, size);
Expand Down

0 comments on commit c42ca38

Please sign in to comment.