diff --git a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs index ad07b3fe..253e68a8 100644 --- a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs +++ b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs @@ -61,6 +61,7 @@ public struct ExprNode private const uint MetaKeepWithoutNext = 0xFFFF0000u; // _data layout: bits [31:16]=ChildCount | [15:0]=ChildIdx (or full uint for inline constants) private const int DataCountShift = 16; + private const uint DataKeepWithoutChildIdx = 0xFFFF0000u; private const uint DataIdxMask = 0xFFFFu; private const int FlagsShift = 4; private const uint KindMask = 0x0Fu; @@ -139,6 +140,19 @@ internal void SetChildInfo(int childIdx, int childCount) => [MethodImpl(MethodImplOptions.AggressiveInlining)] internal bool HasFlag(byte flag) => (Flags & flag) != 0; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal bool HasSameHeaderExceptNext(ref ExprNode other) => + Type == other.Type && (_meta & MetaKeepWithoutNext) == (other._meta & MetaKeepWithoutNext); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal bool HasSameShapeExceptLinks(ref ExprNode other) => + HasSameHeaderExceptNext(ref other) && + (_data & DataKeepWithoutChildIdx) == (other._data & DataKeepWithoutChildIdx); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal bool HasSameShapeExceptNext(ref ExprNode other) => + HasSameHeaderExceptNext(ref other) && _data == other._data; + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal bool ShouldCloneWhenLinked() => ReferenceEquals(Obj, InlineValueMarker) || @@ -173,7 +187,7 @@ public LambdaClosureParameterUsage(ushort lambdaIdx, ushort parameterIdx, ushort } /// Stores an expression tree as flat nodes plus separate closure constants. -public struct ExprTree +public struct ExprTree : IEquatable { private static readonly object ClosureConstantMarker = new(); private const byte ParameterByRefFlag = 1; @@ -709,6 +723,29 @@ public SysExpr ToExpression() => [RequiresUnreferencedCode(FastExpressionCompiler.LightExpression.Trimming.Message)] public LightExpression ToLightExpression() => FastExpressionCompiler.LightExpression.FromSysExpressionConverter.ToLightExpression(ToExpression()); + /// Structurally compares two flat expression trees. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(ExprTree other) => + new StructuralComparer().Eq(ref this, ref other); + + /// Structurally compares this tree with another object. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override bool Equals(object obj) => + obj is ExprTree other && Equals(other); + + /// Computes a content-addressable hash for the flat expression tree. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override int GetHashCode() => + new StructuralComparer().Hash(ref this); + + /// Determines whether two flat expression trees are structurally equal. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool operator ==(ExprTree left, ExprTree right) => left.Equals(right); + + /// Determines whether two flat expression trees are not structurally equal. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool operator !=(ExprTree left, ExprTree right) => !left.Equals(right); + [MethodImpl(MethodImplOptions.AggressiveInlining)] private int AddFactoryExpressionNode(Type type, object obj, ExpressionType nodeType, int child) => AddNode(type, obj, nodeType, ExprNodeKind.Expression, 0, CloneChild(child)); @@ -1606,6 +1643,526 @@ private static bool Contains(ref SmallList [MethodImpl(MethodImplOptions.AggressiveInlining)] private static ushort ToStoredUShortIdx(int idx) => checked((ushort)idx); + private struct StructuralComparer + { + private SmallList, NoArrayPool> _xParameterIds, _yParameterIds; + private SmallList, NoArrayPool> _xLabelIds, _yLabelIds; + private SmallList, NoArrayPool> _eqFrames; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Eq(ref ExprTree xTree, ref ExprTree yTree) + { + if (xTree.Nodes.Count == 0 || yTree.Nodes.Count == 0) + return xTree.Nodes.Count == yTree.Nodes.Count; + + var xIdx = xTree.RootIdx; + var yIdx = yTree.RootIdx; + var remainingSiblings = 0; + while (true) + { + ref var x = ref xTree.Nodes.GetSurePresentRef(xIdx); + ref var y = ref yTree.Nodes.GetSurePresentRef(yIdx); + if (x.Kind == ExprNodeKind.UInt16Pair) + { + if (!x.HasSameShapeExceptNext(ref y)) + return false; + } + else if (x.NodeType == ExpressionType.Constant) + { + if (!x.HasSameHeaderExceptNext(ref y)) + return false; + } + else if (!x.HasSameShapeExceptLinks(ref y)) + return false; + + var descendX = 0; + var descendY = 0; + var descendChildCount = 0; + var restoreXParameterCount = -1; + var restoreYParameterCount = -1; + + if (x.Kind != ExprNodeKind.UInt16Pair) + { + if (x.Kind == ExprNodeKind.LabelTarget) + { + if (!EqLabelTarget(ref x, ref y)) + return false; + } + else if (x.Kind == ExprNodeKind.CatchBlock) + { + restoreXParameterCount = _xParameterIds.Count; + restoreYParameterCount = _yParameterIds.Count; + descendX = x.ChildIdx; + descendY = y.ChildIdx; + descendChildCount = x.ChildCount - (x.HasFlag(CatchHasVariableFlag) ? 1 : 0); + if (x.HasFlag(CatchHasVariableFlag)) + { + ref var xv = ref xTree.Nodes.GetSurePresentRef(descendX); + ref var yv = ref yTree.Nodes.GetSurePresentRef(descendY); + if (!AreEquivalentParameterDeclarations(ref xv, ref yv)) + return false; + _xParameterIds.Add(ToStoredUShortIdx(xv.ChildIdx)); + _yParameterIds.Add(ToStoredUShortIdx(yv.ChildIdx)); + descendX = xv.NextIdx; + descendY = yv.NextIdx; + } + } + else + { + switch (x.NodeType) + { + case ExpressionType.Parameter: + if (!EqParameter(ref x, ref y)) + return false; + break; + + case ExpressionType.Constant: + if (!AreConstantsEqual(ref xTree, ref x, ref yTree, ref y)) + return false; + break; + + case ExpressionType.Lambda: + if (x.ChildCount == 0) + return false; + + restoreXParameterCount = _xParameterIds.Count; + restoreYParameterCount = _yParameterIds.Count; + descendX = x.ChildIdx; + descendY = y.ChildIdx; + descendChildCount = 1; + var xParameterIdx = xTree.Nodes.GetSurePresentRef(descendX).NextIdx; + var yParameterIdx = yTree.Nodes.GetSurePresentRef(descendY).NextIdx; + for (var i = 1; i < x.ChildCount; ++i) + { + ref var xp = ref xTree.Nodes.GetSurePresentRef(xParameterIdx); + ref var yp = ref yTree.Nodes.GetSurePresentRef(yParameterIdx); + if (!AreEquivalentParameterDeclarations(ref xp, ref yp)) + return false; + _xParameterIds.Add(ToStoredUShortIdx(xp.ChildIdx)); + _yParameterIds.Add(ToStoredUShortIdx(yp.ChildIdx)); + xParameterIdx = xp.NextIdx; + yParameterIdx = yp.NextIdx; + } + break; + + case ExpressionType.Block: + if (x.ChildCount == 0) + return false; + + restoreXParameterCount = _xParameterIds.Count; + restoreYParameterCount = _yParameterIds.Count; + descendX = x.ChildIdx; + descendY = y.ChildIdx; + descendChildCount = 1; + if (x.ChildCount == 2) + { + ref var xVariables = ref xTree.Nodes.GetSurePresentRef(descendX); + ref var yVariables = ref yTree.Nodes.GetSurePresentRef(descendY); + if (xVariables.Kind != ExprNodeKind.ChildList || yVariables.Kind != ExprNodeKind.ChildList || + xVariables.ChildCount != yVariables.ChildCount) + return false; + + var xVariableIdx = xVariables.ChildIdx; + var yVariableIdx = yVariables.ChildIdx; + for (var i = 0; i < xVariables.ChildCount; ++i) + { + ref var xv = ref xTree.Nodes.GetSurePresentRef(xVariableIdx); + ref var yv = ref yTree.Nodes.GetSurePresentRef(yVariableIdx); + if (!AreEquivalentParameterDeclarations(ref xv, ref yv)) + return false; + _xParameterIds.Add(ToStoredUShortIdx(xv.ChildIdx)); + _yParameterIds.Add(ToStoredUShortIdx(yv.ChildIdx)); + xVariableIdx = xv.NextIdx; + yVariableIdx = yv.NextIdx; + } + + descendX = xVariables.NextIdx; + descendY = yVariables.NextIdx; + } + break; + + default: + if (!EqObj(ref x, ref y)) + return false; + if (x.ChildCount != 0) + { + descendX = x.ChildIdx; + descendY = y.ChildIdx; + descendChildCount = x.ChildCount; + } + break; + } + } + } + + if (descendChildCount != 0) + { + _eqFrames.Add(new TraversalFrame(x.NextIdx, y.NextIdx, remainingSiblings, restoreXParameterCount, restoreYParameterCount)); + xIdx = descendX; + yIdx = descendY; + remainingSiblings = descendChildCount - 1; + continue; + } + + var advanced = false; + while (true) + { + if (remainingSiblings != 0) + { + xIdx = x.NextIdx; + yIdx = y.NextIdx; + remainingSiblings--; + advanced = true; + break; + } + + if (_eqFrames.Count == 0) + return true; + + var frame = _eqFrames[_eqFrames.Count - 1]; + _eqFrames.Count--; + RestoreParameterScope(frame.XParameterCount, frame.YParameterCount); + if (frame.RemainingSiblingsAfterNode != 0) + { + xIdx = frame.XNextIdx; + yIdx = frame.YNextIdx; + remainingSiblings = frame.RemainingSiblingsAfterNode - 1; + advanced = true; + break; + } + } + if (advanced) + continue; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Hash(ref ExprTree tree) => + tree.Nodes.Count == 0 ? 0 : HashNode(ref tree, tree.RootIdx); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int Combine(int h1, int h2) => + unchecked(h1 ^ (h2 + (int)0x9e3779b9 + (h1 << 6) + (h1 >> 2))); + + private bool EqParameter(ref ExprNode x, ref ExprNode y) + { + var xId = ToStoredUShortIdx(x.ChildIdx); + for (var i = 0; i < _xParameterIds.Count; ++i) + if (_xParameterIds[i] == xId) + return _yParameterIds[i] == ToStoredUShortIdx(y.ChildIdx); + + return x.HasFlag(ParameterByRefFlag) == y.HasFlag(ParameterByRefFlag) && + Equals(x.Obj, y.Obj); + } + + private bool EqLabelTarget(ref ExprNode x, ref ExprNode y) + { + var xId = ToStoredUShortIdx(x.ChildIdx); + for (var i = 0; i < _xLabelIds.Count; ++i) + if (_xLabelIds[i] == xId) + return _yLabelIds[i] == ToStoredUShortIdx(y.ChildIdx); + + _xLabelIds.Add(xId); + _yLabelIds.Add(ToStoredUShortIdx(y.ChildIdx)); + return Equals(x.Obj, y.Obj); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AreEquivalentParameterDeclarations(ref ExprNode x, ref ExprNode y) => + x.NodeType == ExpressionType.Parameter && + y.NodeType == ExpressionType.Parameter && + x.HasSameShapeExceptLinks(ref y); + + private static bool EqObj(ref ExprNode x, ref ExprNode y) + { + return ReferenceEquals(x.Obj, y.Obj) || Equals(x.Obj, y.Obj); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void RestoreParameterScope(int xParameterCount, int yParameterCount) + { + if (xParameterCount >= 0) + _xParameterIds.Count = xParameterCount; + if (yParameterCount >= 0) + _yParameterIds.Count = yParameterCount; + } + + private int HashNode(ref ExprTree tree, int idx) + { + ref var node = ref tree.Nodes.GetSurePresentRef(idx); + if (node.Kind == ExprNodeKind.LabelTarget) + return Combine(Combine((int)node.Kind, node.Type?.GetHashCode() ?? 0), node.Obj?.GetHashCode() ?? 0); + + if (node.Kind == ExprNodeKind.CatchBlock) + return HashCatchBlock(ref tree, idx, ref node); + + if (node.Kind == ExprNodeKind.UInt16Pair) + return Combine(Combine((int)node.Kind, node.ChildIdx), node.ChildCount); + + var h = Combine(Combine((int)node.Kind, (int)node.NodeType), node.Type?.GetHashCode() ?? 0); + h = Combine(h, node.Flags); + + switch (node.NodeType) + { + case ExpressionType.Parameter: + { + var id = ToStoredUShortIdx(node.ChildIdx); + for (var i = 0; i < _xParameterIds.Count; ++i) + if (_xParameterIds[i] == id) + return Combine(h, i); + return Combine(h, node.Obj?.GetHashCode() ?? 0); + } + + case ExpressionType.Constant: + return Combine(h, GetConstantHashCode(ref tree, ref node)); + + case ExpressionType.Lambda: + return HashLambda(ref tree, idx, h); + + case ExpressionType.Block: + return HashBlock(ref tree, idx, h); + } + + h = Combine(h, node.Obj?.GetHashCode() ?? 0); + var childIdx = node.ChildIdx; + for (var i = 0; i < node.ChildCount; ++i) + { + h = Combine(h, HashNode(ref tree, childIdx)); + childIdx = tree.Nodes.GetSurePresentRef(childIdx).NextIdx; + } + return h; + } + + private int HashLambda(ref ExprTree tree, int idx, int h) + { + var scopeCount = _xParameterIds.Count; + ref var node = ref tree.Nodes.GetSurePresentRef(idx); + var bodyIdx = node.ChildIdx; + var parameterIdx = tree.Nodes.GetSurePresentRef(bodyIdx).NextIdx; + for (var i = 1; i < node.ChildCount; ++i) + { + ref var parameter = ref tree.Nodes.GetSurePresentRef(parameterIdx); + _xParameterIds.Add(ToStoredUShortIdx(parameter.ChildIdx)); + h = Combine(h, Combine(parameter.Type?.GetHashCode() ?? 0, parameter.HasFlag(ParameterByRefFlag) ? 1 : 0)); + parameterIdx = parameter.NextIdx; + } + + h = Combine(h, HashNode(ref tree, bodyIdx)); + _xParameterIds.Count = scopeCount; + return h; + } + + private int HashBlock(ref ExprTree tree, int idx, int h) + { + var scopeCount = _xParameterIds.Count; + ref var node = ref tree.Nodes.GetSurePresentRef(idx); + var bodyListIdx = node.ChildIdx; + if (node.ChildCount == 2) + { + ref var variables = ref tree.Nodes.GetSurePresentRef(bodyListIdx); + var variableIdx = variables.ChildIdx; + for (var i = 0; i < variables.ChildCount; ++i) + { + ref var variable = ref tree.Nodes.GetSurePresentRef(variableIdx); + _xParameterIds.Add(ToStoredUShortIdx(variable.ChildIdx)); + h = Combine(h, Combine(variable.Type?.GetHashCode() ?? 0, variable.HasFlag(ParameterByRefFlag) ? 1 : 0)); + variableIdx = variable.NextIdx; + } + bodyListIdx = variables.NextIdx; + } + + h = Combine(h, HashNode(ref tree, bodyListIdx)); + _xParameterIds.Count = scopeCount; + return h; + } + + private int HashCatchBlock(ref ExprTree tree, int idx, ref ExprNode node) + { + var h = Combine(Combine((int)node.Kind, node.Type?.GetHashCode() ?? 0), node.Flags); + var scopeCount = _xParameterIds.Count; + var childIdx = 0; + var catchChildIdx = node.ChildIdx; + if (node.HasFlag(CatchHasVariableFlag)) + { + ref var variable = ref tree.Nodes.GetSurePresentRef(catchChildIdx); + _xParameterIds.Add(ToStoredUShortIdx(variable.ChildIdx)); + h = Combine(h, Combine(variable.Type?.GetHashCode() ?? 0, variable.HasFlag(ParameterByRefFlag) ? 1 : 0)); + catchChildIdx = variable.NextIdx; + childIdx++; + } + + h = Combine(h, HashNode(ref tree, catchChildIdx)); + catchChildIdx = tree.Nodes.GetSurePresentRef(catchChildIdx).NextIdx; + childIdx++; + if (node.HasFlag(CatchHasFilterFlag)) + h = Combine(h, HashNode(ref tree, catchChildIdx)); + + _xParameterIds.Count = scopeCount; + return h; + } + + private static int GetConstantHashCode(ref ExprTree tree, ref ExprNode node) + { + if (ReferenceEquals(node.Obj, ExprNode.InlineValueMarker)) + return GetInlineConstantHashCode(node.Type, node.InlineValue); + return GetStoredConstantValue(ref tree, ref node)?.GetHashCode() ?? 0; + } + + private static bool AreConstantsEqual(ref ExprTree xTree, ref ExprNode x, ref ExprTree yTree, ref ExprNode y) + { + var xObj = GetStoredConstantValue(ref xTree, ref x); + var yObj = GetStoredConstantValue(ref yTree, ref y); + if (!ReferenceEquals(x.Obj, ExprNode.InlineValueMarker) && !ReferenceEquals(y.Obj, ExprNode.InlineValueMarker)) + return ReferenceEquals(xObj, yObj) || Equals(xObj, yObj); + + if (x.Type.IsEnum) + { + if (ReferenceEquals(x.Obj, ExprNode.InlineValueMarker) && ReferenceEquals(y.Obj, ExprNode.InlineValueMarker)) + return x.InlineValue == y.InlineValue; + return Type.GetTypeCode(Enum.GetUnderlyingType(x.Type)) switch + { + TypeCode.Byte => GetInlineOrConvertedByte(ref xTree, ref x) == GetInlineOrConvertedByte(ref yTree, ref y), + TypeCode.SByte => GetInlineOrConvertedSByte(ref xTree, ref x) == GetInlineOrConvertedSByte(ref yTree, ref y), + TypeCode.Char => GetInlineOrConvertedChar(ref xTree, ref x) == GetInlineOrConvertedChar(ref yTree, ref y), + TypeCode.Int16 => GetInlineOrConvertedInt16(ref xTree, ref x) == GetInlineOrConvertedInt16(ref yTree, ref y), + TypeCode.UInt16 => GetInlineOrConvertedUInt16(ref xTree, ref x) == GetInlineOrConvertedUInt16(ref yTree, ref y), + TypeCode.Int32 => GetInlineOrConvertedInt32(ref xTree, ref x) == GetInlineOrConvertedInt32(ref yTree, ref y), + TypeCode.UInt32 => GetInlineOrConvertedUInt32(ref xTree, ref x) == GetInlineOrConvertedUInt32(ref yTree, ref y), + var tc => FlatExpressionThrow.UnsupportedInlineConstantType(x.Type, tc) + }; + } + + return Type.GetTypeCode(x.Type) switch + { + TypeCode.Boolean => GetInlineOrStoredBoolean(ref xTree, ref x) == GetInlineOrStoredBoolean(ref yTree, ref y), + TypeCode.Byte => GetInlineOrStoredByte(ref xTree, ref x) == GetInlineOrStoredByte(ref yTree, ref y), + TypeCode.SByte => GetInlineOrStoredSByte(ref xTree, ref x) == GetInlineOrStoredSByte(ref yTree, ref y), + TypeCode.Char => GetInlineOrStoredChar(ref xTree, ref x) == GetInlineOrStoredChar(ref yTree, ref y), + TypeCode.Int16 => GetInlineOrStoredInt16(ref xTree, ref x) == GetInlineOrStoredInt16(ref yTree, ref y), + TypeCode.UInt16 => GetInlineOrStoredUInt16(ref xTree, ref x) == GetInlineOrStoredUInt16(ref yTree, ref y), + TypeCode.Int32 => GetInlineOrStoredInt32(ref xTree, ref x) == GetInlineOrStoredInt32(ref yTree, ref y), + TypeCode.UInt32 => GetInlineOrStoredUInt32(ref xTree, ref x) == GetInlineOrStoredUInt32(ref yTree, ref y), + TypeCode.Single => GetInlineOrStoredSingle(ref xTree, ref x).Equals(GetInlineOrStoredSingle(ref yTree, ref y)), + _ => ReferenceEquals(xObj, yObj) || Equals(xObj, yObj) + }; + } + + private static object GetStoredConstantValue(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ClosureConstantMarker) ? tree.ClosureConstants[node.ChildIdx] : node.Obj; + + private static int GetInlineConstantHashCode(Type type, uint data) + { + if (type.IsEnum) + return Type.GetTypeCode(Enum.GetUnderlyingType(type)) switch + { + TypeCode.Byte => ((byte)data).GetHashCode(), + TypeCode.SByte => ((sbyte)(byte)data).GetHashCode(), + TypeCode.Char => ((char)(ushort)data).GetHashCode(), + TypeCode.Int16 => ((short)(ushort)data).GetHashCode(), + TypeCode.UInt16 => ((ushort)data).GetHashCode(), + TypeCode.Int32 => ((int)data).GetHashCode(), + TypeCode.UInt32 => data.GetHashCode(), + var tc => FlatExpressionThrow.UnsupportedInlineConstantType(type, tc) + }; + + return Type.GetTypeCode(type) switch + { + TypeCode.Boolean => (data != 0).GetHashCode(), + TypeCode.Byte => ((byte)data).GetHashCode(), + TypeCode.SByte => ((sbyte)(byte)data).GetHashCode(), + TypeCode.Char => ((char)(ushort)data).GetHashCode(), + TypeCode.Int16 => ((short)(ushort)data).GetHashCode(), + TypeCode.UInt16 => ((ushort)data).GetHashCode(), + TypeCode.Int32 => ((int)data).GetHashCode(), + TypeCode.UInt32 => data.GetHashCode(), + TypeCode.Single => FloatBits.ToFloat(data).GetHashCode(), + _ => FlatExpressionThrow.UnsupportedInlineConstantType(type) + }; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool GetInlineOrStoredBoolean(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? node.InlineValue != 0 : (bool)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static byte GetInlineOrStoredByte(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (byte)node.InlineValue : (byte)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static sbyte GetInlineOrStoredSByte(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (sbyte)(byte)node.InlineValue : (sbyte)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static char GetInlineOrStoredChar(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (char)(ushort)node.InlineValue : (char)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static short GetInlineOrStoredInt16(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (short)(ushort)node.InlineValue : (short)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ushort GetInlineOrStoredUInt16(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (ushort)node.InlineValue : (ushort)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetInlineOrStoredInt32(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (int)node.InlineValue : (int)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint GetInlineOrStoredUInt32(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? node.InlineValue : (uint)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float GetInlineOrStoredSingle(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? FloatBits.ToFloat(node.InlineValue) : (float)GetStoredConstantValue(ref tree, ref node); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static byte GetInlineOrConvertedByte(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (byte)node.InlineValue : System.Convert.ToByte(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static sbyte GetInlineOrConvertedSByte(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (sbyte)(byte)node.InlineValue : System.Convert.ToSByte(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static char GetInlineOrConvertedChar(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (char)(ushort)node.InlineValue : System.Convert.ToChar(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static short GetInlineOrConvertedInt16(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (short)(ushort)node.InlineValue : System.Convert.ToInt16(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ushort GetInlineOrConvertedUInt16(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (ushort)node.InlineValue : System.Convert.ToUInt16(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetInlineOrConvertedInt32(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? (int)node.InlineValue : System.Convert.ToInt32(GetStoredConstantValue(ref tree, ref node)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint GetInlineOrConvertedUInt32(ref ExprTree tree, ref ExprNode node) => + ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) ? node.InlineValue : System.Convert.ToUInt32(GetStoredConstantValue(ref tree, ref node)); + + private struct TraversalFrame + { + public readonly int XNextIdx; + public readonly int YNextIdx; + public readonly int RemainingSiblingsAfterNode; + public readonly int XParameterCount; + public readonly int YParameterCount; + + public TraversalFrame(int xNextIdx, int yNextIdx, int remainingSiblingsAfterNode, int xParameterCount, int yParameterCount) + { + XNextIdx = xNextIdx; + YNextIdx = yNextIdx; + RemainingSiblingsAfterNode = remainingSiblingsAfterNode; + XParameterCount = xParameterCount; + YParameterCount = yParameterCount; + } + } + } + /// Reconstructs System.Linq nodes from the flat representation while reusing parameter and label identities. private struct Reader { diff --git a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs index 3e3ff694..98084bd8 100644 --- a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs +++ b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs @@ -55,7 +55,11 @@ public int Run() Flat_blocks_with_variables_tracked_from_expression_conversion(); Flat_goto_and_label_nodes_tracked_from_expression_conversion(); Flat_try_catch_nodes_tracked_from_expression_conversion(); - return 38; + Flat_equal_lambdas_with_different_parameter_names_are_structurally_equal_and_hash_equal(); + Flat_equal_nested_lambdas_with_captures_are_structurally_equal_and_hash_equal(); + Flat_standalone_parameters_use_name_in_structural_equality(); + Flat_structural_hash_supports_dictionary_lookup(); + return 42; } @@ -1023,5 +1027,56 @@ public void Flat_try_catch_nodes_tracked_from_expression_conversion() Asserts.AreEqual(1, fe.TryCatchNodes.Count); } + + public void Flat_equal_lambdas_with_different_parameter_names_are_structurally_equal_and_hash_equal() + { + var x = Parameter(typeof(int), "x"); + var left = Lambda>(Add(x, Constant(1)), x).ToFlatExpression(); + + var y = Parameter(typeof(int), "y"); + var right = Lambda>(Add(y, Constant(1)), y).ToFlatExpression(); + + Asserts.IsTrue(left.Equals(right)); + Asserts.IsTrue(left == right); + Asserts.AreEqual(left.GetHashCode(), right.GetHashCode()); + } + + public void Flat_equal_nested_lambdas_with_captures_are_structurally_equal_and_hash_equal() + { + var x = Parameter(typeof(int), "x"); + var left = Lambda>>( + Lambda>(Add(x, Constant(1))), + x).ToFlatExpression(); + + var y = Parameter(typeof(int), "value"); + var right = Lambda>>( + Lambda>(Add(y, Constant(1))), + y).ToFlatExpression(); + + Asserts.IsTrue(left.Equals(right)); + Asserts.AreEqual(left.GetHashCode(), right.GetHashCode()); + } + + public void Flat_standalone_parameters_use_name_in_structural_equality() + { + var left = Parameter(typeof(int), "x").ToFlatExpression(); + var right = Parameter(typeof(int), "y").ToFlatExpression(); + + Asserts.IsFalse(left.Equals(right)); + } + + public void Flat_structural_hash_supports_dictionary_lookup() + { + var x = Parameter(typeof(int), "x"); + var key = Lambda>(Add(x, Constant(1)), x).ToFlatExpression(); + var dict = new Dictionary { [key] = "found" }; + + var lookup = default(ExprTree); + var y = lookup.ParameterOf("arg"); + lookup.RootIdx = lookup.Lambda>(lookup.Add(y, lookup.ConstantInt(1)), y); + + Asserts.IsTrue(dict.TryGetValue(lookup, out var value)); + Asserts.AreEqual("found", value); + } } }