Skip to content

Commit

Permalink
Preserve custom operators (#2125)
Browse files Browse the repository at this point in the history
* Preserve custom operators

This will keep custom operators on marked types whenever System.Linq.Expressions
is used, and the operator input types are marked.

The behavior is enabled by default, and can be disabled by passing
--disable-operator-discovery.

Addresses #1821

* Fix behavior for operators on nullable types

* Cleanup and PR feedback

- Avoid processing pending operators Dictionary if Linq.Expressions is unused
- Allocate this possibly-unused Dictionary lazily
- Use readonly field for always-used HashSet
- Rename markOperators -> seenLinqExpressions
- Clean up ProcessCustomOperators call to make intent more clear
- Add comments
- Check MetadataType.Int32 instead of searching BCL for Int32

* Remove unnecessary parens

* PR feedback

- seenLinqExpressions -> _seenLinqExpressions
- use List for pending operators instead of HashSet
  • Loading branch information
sbomer committed Jul 8, 2021
1 parent 35a1c74 commit 6b0da00
Show file tree
Hide file tree
Showing 12 changed files with 577 additions and 0 deletions.
215 changes: 215 additions & 0 deletions src/linker/Linker.Steps/DiscoverCustomOperatorsHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using Mono.Cecil;

namespace Mono.Linker.Steps
{
public class DiscoverOperatorsHandler : IMarkHandler
{
LinkContext _context;
bool _seenLinqExpressions;
readonly HashSet<TypeDefinition> _trackedTypesWithOperators;
Dictionary<TypeDefinition, List<MethodDefinition>> _pendingOperatorsForType;

Dictionary<TypeDefinition, List<MethodDefinition>> PendingOperatorsForType {
get {
if (_pendingOperatorsForType == null)
_pendingOperatorsForType = new Dictionary<TypeDefinition, List<MethodDefinition>> ();
return _pendingOperatorsForType;
}
}

public DiscoverOperatorsHandler ()
{
_trackedTypesWithOperators = new HashSet<TypeDefinition> ();
}

public void Initialize (LinkContext context, MarkContext markContext)
{
_context = context;
markContext.RegisterMarkTypeAction (ProcessType);
}

void ProcessType (TypeDefinition type)
{
CheckForLinqExpressions (type);

// Check for custom operators and either:
// - mark them, if Linq.Expressions was already marked, or
// - track them to be marked in case Linq.Expressions is marked later
var hasOperators = ProcessCustomOperators (type, mark: _seenLinqExpressions);
if (!_seenLinqExpressions) {
if (hasOperators)
_trackedTypesWithOperators.Add (type);
return;
}

// Mark pending operators defined on other types that reference this type
// (these are only tracked if we have already seen Linq.Expressions)
if (PendingOperatorsForType.TryGetValue (type, out var pendingOperators)) {
foreach (var customOperator in pendingOperators)
MarkOperator (customOperator);
PendingOperatorsForType.Remove (type);
}
}

void CheckForLinqExpressions (TypeDefinition type)
{
if (_seenLinqExpressions)
return;

if (type.Namespace != "System.Linq.Expressions" || type.Name != "Expression")
return;

_seenLinqExpressions = true;

foreach (var markedType in _trackedTypesWithOperators)
ProcessCustomOperators (markedType, mark: true);

_trackedTypesWithOperators.Clear ();
}

void MarkOperator (MethodDefinition method)
{
_context.Annotations.Mark (method, new DependencyInfo (DependencyKind.PreservedOperator, method.DeclaringType));
}

bool ProcessCustomOperators (TypeDefinition type, bool mark)
{
if (!type.HasMethods)
return false;

bool hasCustomOperators = false;
foreach (var method in type.Methods) {
if (!IsOperator (method, out var otherType))
continue;

if (!mark)
return true;

Debug.Assert (_seenLinqExpressions);
hasCustomOperators = true;

if (otherType == null || _context.Annotations.IsMarked (otherType)) {
MarkOperator (method);
continue;
}

// Wait until otherType gets marked to mark the operator.
if (!PendingOperatorsForType.TryGetValue (otherType, out var pendingOperators)) {
pendingOperators = new List<MethodDefinition> ();
PendingOperatorsForType.Add (otherType, pendingOperators);
}
pendingOperators.Add (method);
}
return hasCustomOperators;
}

TypeDefinition _nullableOfT;
TypeDefinition NullableOfT {
get {
if (_nullableOfT == null)
_nullableOfT = BCL.FindPredefinedType ("System", "Nullable`1", _context);
return _nullableOfT;
}
}

TypeDefinition NonNullableType (TypeReference type)
{
var typeDef = _context.TryResolve (type);
if (typeDef == null)
return null;

if (!typeDef.IsValueType || typeDef != NullableOfT)
return typeDef;

// Unwrap Nullable<T>
Debug.Assert (typeDef.HasGenericParameters);
var nullableType = type as GenericInstanceType;
Debug.Assert (nullableType != null && nullableType.HasGenericArguments && nullableType.GenericArguments.Count == 1);
return _context.TryResolve (nullableType.GenericArguments[0]);
}

bool IsOperator (MethodDefinition method, out TypeDefinition otherType)
{
otherType = null;

if (!method.IsStatic || !method.IsPublic || !method.IsSpecialName || !method.Name.StartsWith ("op_"))
return false;

var operatorName = method.Name.Substring (3);
var self = method.DeclaringType;

switch (operatorName) {
// Unary operators
case "UnaryPlus":
case "UnaryNegation":
case "LogicalNot":
case "OnesComplement":
case "Increment":
case "Decrement":
case "True":
case "False":
// Parameter type of a unary operator must be the declaring type
if (method.Parameters.Count != 1 || NonNullableType (method.Parameters[0].ParameterType) != self)
return false;
// ++ and -- must return the declaring type
if (operatorName is "Increment" or "Decrement" && NonNullableType (method.ReturnType) != self)
return false;
return true;
// Binary operators
case "Addition":
case "Subtraction":
case "Multiply":
case "Division":
case "Modulus":
case "BitwiseAnd":
case "BitwiseOr":
case "ExclusiveOr":
case "LeftShift":
case "RightShift":
case "Equality":
case "Inequality":
case "LessThan":
case "GreaterThan":
case "LessThanOrEqual":
case "GreaterThanOrEqual":
if (method.Parameters.Count != 2)
return false;
var nnLeft = NonNullableType (method.Parameters[0].ParameterType);
var nnRight = NonNullableType (method.Parameters[1].ParameterType);
if (nnLeft == null || nnRight == null)
return false;
// << and >> must take the declaring type and int
if (operatorName is "LeftShift" or "RightShift" && (nnLeft != self || nnRight.MetadataType != MetadataType.Int32))
return false;
// At least one argument must be the declaring type
if (nnLeft != self && nnRight != self)
return false;
if (nnLeft != self)
otherType = nnLeft;
if (nnRight != self)
otherType = nnRight;
return true;
// Conversion operators
case "Implicit":
case "Explicit":
if (method.Parameters.Count != 1)
return false;
var nnSource = NonNullableType (method.Parameters[0].ParameterType);
var nnTarget = NonNullableType (method.ReturnType);
// Exactly one of source/target must be the declaring type
if (nnSource == self == (nnTarget == self))
return false;
otherType = nnSource == self ? nnTarget : nnSource;
return true;
default:
return false;
}
}
}
}
2 changes: 2 additions & 0 deletions src/linker/Linker/DependencyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ public enum DependencyKind
XmlSerialized = 84, // entry type or member for XML serialization
SerializedRecursiveType = 85, // recursive type kept due to serialization handling
SerializedMember = 86, // field or property kept on a type for serialization

PreservedOperator = 87 // operator method preserved on a type
}

public readonly struct DependencyInfo : IEquatable<DependencyInfo>
Expand Down
9 changes: 9 additions & 0 deletions src/linker/Linker/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,12 @@ protected int SetupContext (ILogger customLogger = null)

continue;

case "--disable-operator-discovery":
if (!GetBoolParam (token, l => context.DisableOperatorDiscovery = l))
return -1;

continue;

case "--ignore-descriptors":
if (!GetBoolParam (token, l => context.IgnoreDescriptors = l))
return -1;
Expand Down Expand Up @@ -732,6 +738,9 @@ protected int SetupContext (ILogger customLogger = null)
if (!context.DisableSerializationDiscovery)
p.MarkHandlers.Add (new DiscoverSerializationHandler ());

if (!context.DisableOperatorDiscovery)
p.MarkHandlers.Add (new DiscoverOperatorsHandler ());

foreach (string custom_step in custom_steps) {
if (!AddCustomStep (p, custom_step))
return -1;
Expand Down
2 changes: 2 additions & 0 deletions src/linker/Linker/LinkContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ public class LinkContext : IMetadataResolver, IDisposable

public bool DisableSerializationDiscovery { get; set; }

public bool DisableOperatorDiscovery { get; set; }

public bool IgnoreDescriptors { get; set; }

public bool IgnoreSubstitutions { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
[SetupLinkerArgument ("--disable-operator-discovery")]
public class CanDisableOperatorDiscovery
{
public static void Main ()
{
var c = new CustomOperators ();
var expression = typeof (System.Linq.Expressions.Expression);
c = -c;
var t = typeof (TargetType);
}

[KeptMember (".ctor()")]
class CustomOperators
{
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;

public static CustomOperators operator + (CustomOperators c) => null;
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
public static explicit operator TargetType (CustomOperators self) => null;
}

[Kept]
class TargetType { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
public class CanPreserveCustomOperators
{
public static void Main ()
{
var t = typeof (CustomOperators);
var expression = typeof (System.Linq.Expressions.Expression);

var t3 = typeof (TargetTypeImplicit);
var t4 = typeof (SourceTypeImplicit);
var t5 = typeof (TargetTypeExplicit);
var t6 = typeof (SourceTypeExplicit);
}

class CustomOperators
{
// Unary operators
[Kept]
public static CustomOperators operator + (CustomOperators c) => null;
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ! (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ~ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ++ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator -- (CustomOperators c) => null;
[Kept]
public static bool operator true (CustomOperators c) => true;
[Kept]
public static bool operator false (CustomOperators c) => true;

// Binary operators
[Kept]
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator - (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator * (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator / (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator % (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator & (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator | (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator ^ (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator << (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator >> (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator == (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator != (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator < (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator > (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator <= (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator >= (CustomOperators left, CustomOperators right) => null;

// conversion operators
[Kept]
public static implicit operator TargetTypeImplicit (CustomOperators self) => null;
[Kept]
public static implicit operator CustomOperators (SourceTypeImplicit other) => null;
[Kept]
public static explicit operator TargetTypeExplicit (CustomOperators self) => null;
[Kept]
public static explicit operator CustomOperators (SourceTypeExplicit other) => null;
}

[Kept]
class TargetTypeImplicit { }
[Kept]
class SourceTypeImplicit { }
[Kept]
class TargetTypeExplicit { }
[Kept]
class SourceTypeExplicit { }
}
}
Loading

0 comments on commit 6b0da00

Please sign in to comment.