Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into python-binding
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu302 committed May 22, 2024
2 parents 57e0efa + 2b16815 commit cd68df3
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ Value createLayerNorm(PatternRewriter &rewriter, Location loc, Value input,
return customCallOp.getResults()[0];
}

// Builds a splat DenseFPElementsAttr of `type` whose every element is `v`,
// after converting `v` into the float semantics of `type`'s element type.
// Returns nullptr when the conversion is lossy or does not finish with opOK,
// so callers can assert the constant was representable exactly.
DenseFPElementsAttr getSplatFpElementsAttr(ShapedType type, float v) {
  auto floatTy = type.getElementType().cast<FloatType>();
  APFloat splatValue(v);
  bool lostPrecision = false;
  // Convert into the target semantics (e.g. f16/bf16/f64); round-to-nearest
  // ties-to-even matches the default IEEE rounding mode.
  auto convStatus =
      splatValue.convert(floatTy.getFloatSemantics(),
                         APFloat::rmNearestTiesToEven, &lostPrecision);
  if (convStatus != llvm::APFloatBase::opStatus::opOK || lostPrecision)
    return nullptr;
  return DenseFPElementsAttr::get(type, splatValue);
}

Value createLayerNormMultiDim(PatternRewriter &rewriter, Location loc,
Value inputMD, Value gammaMD, Value betaMD,
ElementsAttr epsilon, ElementsAttr axis) {
Expand All @@ -189,11 +201,13 @@ Value createLayerNormMultiDim(PatternRewriter &rewriter, Location loc,
assert(inputMDRank == gammaMDType.getRank());
assert(inputMDRank == betaMDType.getRank());

auto gammaAttr = DenseElementsAttr::get(
auto gammaAttr = getSplatFpElementsAttr(
inputMDType.clone(llvm::SmallVector<int64_t>({cDim})), 1.0f);
assert(gammaAttr);
Value gamma = rewriter.create<TF::ConstOp>(loc, gammaAttr);
auto betaAttr = DenseElementsAttr::get(
auto betaAttr = getSplatFpElementsAttr(
inputMDType.clone(llvm::SmallVector<int64_t>({cDim})), 0.0f);
assert(betaAttr);
Value beta = rewriter.create<TF::ConstOp>(loc, betaAttr);

llvm::SmallVector<int64_t> inputShape = {1, cDim};
Expand Down Expand Up @@ -221,7 +235,8 @@ Value createLayerNormWithoutBeta(PatternRewriter &rewriter, Location loc,
Value input, Value gama, ElementsAttr epsilon,
ElementsAttr axis) {
auto gamaShapedType = gama.getType().cast<ShapedType>();
auto betaAttr = DenseElementsAttr::get(gamaShapedType, 0.0f);
auto betaAttr = getSplatFpElementsAttr(gamaShapedType, 0.0f);
assert(betaAttr);
auto betaOp = rewriter.create<TF::ConstOp>(loc, betaAttr);
Value beta = betaOp.getOutput();
return createLayerNorm(rewriter, loc, input, gama, beta, epsilon, axis);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index aee8251b..d157225a 100644
index 381f8f9a..75059770 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -927,6 +927,19 @@ class ContextCache:
@@ -52,6 +52,10 @@ from torch._subclasses import (
FakeTensor as TorchFakeTensor,
)

+from torch.distributed._functional_collectives import (
+ AsyncCollectiveTensor as TorchAsyncCollectiveTensor
+)
+
from torch.fx import (
Graph,
GraphModule,
@@ -924,6 +928,19 @@ class ContextCache:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparsity = node.meta.get("sparsity", None)
Expand All @@ -22,15 +33,15 @@ index aee8251b..d157225a 100644
except KeyError as e:
raise RuntimeError(
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
@@ -1038,6 +1051,7 @@ class GraphNodeImporter:
@@ -1035,6 +1052,7 @@ class GraphNodeImporter:
"_on_node_produced",
"_v",
"_multi_result_nodes",
+ "_list_return_nodes",
"fx_importer",
]

@@ -1061,6 +1075,9 @@ class GraphNodeImporter:
@@ -1058,6 +1076,9 @@ class GraphNodeImporter:
# They will have their getitem calls short-circuited.
self._multi_result_nodes: Set[torch_fx.Node] = set()

Expand All @@ -40,7 +51,7 @@ index aee8251b..d157225a 100644
def bind_node_value(
self,
node: Node,
@@ -1216,6 +1233,23 @@ class GraphNodeImporter:
@@ -1213,6 +1234,23 @@ class GraphNodeImporter:
f"notify developers if this case happens "
f"(at {loc})."
)
Expand All @@ -64,7 +75,7 @@ index aee8251b..d157225a 100644
else:
raise NotImplementedError(
f"General getitem access to non-multi-result ops"
@@ -1642,6 +1676,10 @@ class GraphNodeImporter:
@@ -1676,6 +1714,10 @@ class GraphNodeImporter:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
# if "tensor_meta" is None, this will throw unsupported placeholder node error
result_types = [self._cc.node_val_to_type(node)]
Expand All @@ -75,3 +86,27 @@ index aee8251b..d157225a 100644
elif return_count == 0:
# Some torch ops do have 0 returns, and these are supported with ZeroResults
# op trait. Python bindings for IR creation allow us to pass empty result_types
@@ -1717,6 +1759,8 @@ def _make_vtensor_literal_op(
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
+ # unwrap from TorchAsyncCollectiveTensor
+ tensor = tensor.elem if isinstance(tensor, TorchAsyncCollectiveTensor) else tensor
# check support for bfloat16
assert not (
tensor.dtype == torch.bfloat16 and ml_dtypes is None
@@ -1732,7 +1776,13 @@ def _make_vtensor_literal_op(
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
- np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
+
+ # NOTE: if we torch.export a torch.nn.Module under fake mode, the parameters in the fx.GraphModule will be FakeTensor.
+ # So we specifically handle FakeTensor here by randomly generating a tensor of the same shape and dtype.
+ if isinstance(tensor, TorchFakeTensor):
+ np_tensor = np.random.rand(*list(tensor.shape)).astype(npy_dtype)
+ else:
+ np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
# One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
# support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
# 0d tensors.

0 comments on commit cd68df3

Please sign in to comment.