Skip to content

Commit

Permalink
Eliminate dead branches around typeof comparisons (#102248)
Browse files Browse the repository at this point in the history
RyuJIT will already do dead branch elimination for `typeof(X) == typeof(Y)` patterns, but we couldn't do elimination around `foo == typeof(X)`. This fixes that using whole program knowledge - if we never saw a constructed `MT` for `X`, the comparison is not going to be true. Because it needs whole program, we still scan this dead branch so in the end this doesn't save much. We can eventually do better.

I'm doing this in `SubstitutedILProvider` instead of in RyuJIT: this is because we currently only reap a small benefit from this optimization due to it only happening during compilation phase. We need to do this during scanning as well. I think I can extend it to scannig. But the extension will require the optimization to 100% guaranteed happen during codegen. We cannot rely on whether RyuJIT will feel like it. `SubstitutedILProvider` is our way to ensure the optimization will happen no matter what - the IL from the branch will be gone and RyuJIT can at most remove the comparison (we don't mind much if it's left).
  • Loading branch information
MichalStrehovsky committed Jun 19, 2024
1 parent a5c1c9f commit e0bd776
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ public bool CanInline(MethodDesc caller, MethodDesc callee)

public bool CanReferenceConstructedMethodTable(TypeDesc type)
{
return NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
return NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type.NormalizeInstantiation());
}

public bool CanReferenceConstructedTypeOrCanonicalFormOfType(TypeDesc type)
{
return NodeFactory.DevirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(type);
return NodeFactory.DevirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(type.NormalizeInstantiation());
}

public DelegateCreationInfo GetDelegateCtor(TypeDesc delegateType, MethodDesc target, TypeDesc constrainedType, bool followVirtualDispatch)
Expand Down Expand Up @@ -266,9 +266,7 @@ public bool NeedsRuntimeLookup(ReadyToRunHelperId lookupKind, object targetOfLoo

public ReadyToRunHelperId GetLdTokenHelperForType(TypeDesc type)
{
bool canConstructPerWholeProgramAnalysis = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
bool creationAllowed = ConstructedEETypeNode.CreationAllowed(type);
return (canConstructPerWholeProgramAnalysis && creationAllowed)
return (ConstructedEETypeNode.CreationAllowed(type) && NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type.NormalizeInstantiation()))
? ReadyToRunHelperId.TypeHandle
: ReadyToRunHelperId.NecessaryTypeHandle;
}
Expand Down
12 changes: 10 additions & 2 deletions src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -703,10 +703,18 @@ protected override MethodDesc ResolveVirtualMethod(MethodDesc declMethod, DefTyp
}

public override bool CanReferenceConstructedMethodTable(TypeDesc type)
=> _constructedMethodTables.Contains(type);
{
Debug.Assert(type.NormalizeInstantiation() == type);
Debug.Assert(ConstructedEETypeNode.CreationAllowed(type));
return _constructedMethodTables.Contains(type);
}

public override bool CanReferenceConstructedTypeOrCanonicalFormOfType(TypeDesc type)
=> _constructedMethodTables.Contains(type) || _canonConstructedMethodTables.Contains(type);
{
Debug.Assert(type.NormalizeInstantiation() == type);
Debug.Assert(ConstructedEETypeNode.CreationAllowed(type));
return _constructedMethodTables.Contains(type) || _canonConstructedMethodTables.Contains(type);
}

public override TypeDesc[] GetImplementingClasses(TypeDesc type)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ public class SubstitutedILProvider : ILProvider
{
private readonly ILProvider _nestedILProvider;
private readonly SubstitutionProvider _substitutionProvider;
private readonly DevirtualizationManager _devirtualizationManager;

public SubstitutedILProvider(ILProvider nestedILProvider, SubstitutionProvider substitutionProvider)
public SubstitutedILProvider(ILProvider nestedILProvider, SubstitutionProvider substitutionProvider, DevirtualizationManager devirtualizationManager)
{
_nestedILProvider = nestedILProvider;
_substitutionProvider = substitutionProvider;
_devirtualizationManager = devirtualizationManager;
}

public override MethodIL GetMethodIL(MethodDesc method)
Expand Down Expand Up @@ -871,7 +873,26 @@ private static bool TryExpandTypeIs(MethodIL methodIL, byte[] body, OpcodeFlags[
return true;
}

private static bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, string op, out int constant)
private bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, string op, out int constant)
{
if (TryExpandTypeEquality_TokenToken(methodIL, body, flags, offset, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 1, expectGetType: false, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 2, expectGetType: false, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 3, expectGetType: false, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 1, expectGetType: true, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 2, expectGetType: true, out constant)
|| TryExpandTypeEquality_TokenOther(methodIL, body, flags, offset, 3, expectGetType: true, out constant))
{
if (op == "op_Inequality")
constant ^= 1;

return true;
}

return false;
}

private static bool TryExpandTypeEquality_TokenToken(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, out int constant)
{
// We expect to see a sequence:
// ldtoken Foo
Expand Down Expand Up @@ -919,9 +940,108 @@ private static bool TryExpandTypeEquality(MethodIL methodIL, byte[] body, Opcode

constant = equality.Value ? 1 : 0;

if (op == "op_Inequality")
constant ^= 1;
return true;
}

private bool TryExpandTypeEquality_TokenOther(MethodIL methodIL, byte[] body, OpcodeFlags[] flags, int offset, int ldInstructionSize, bool expectGetType, out int constant)
{
// We expect to see a sequence:
// ldtoken Foo
// call GetTypeFromHandle
// ldloc.X/ldloc_s X/ldarg.X/ldarg_s X
// [optional] call Object.GetType
// -> offset points here
//
// The ldtoken part can potentially be in the second argument position

constant = 0;
int sequenceLength = 5 + 5 + ldInstructionSize + (expectGetType ? 5 : 0);
if (offset < sequenceLength)
return false;

if ((flags[offset - sequenceLength] & OpcodeFlags.InstructionStart) == 0)
return false;

ILReader reader = new ILReader(body, offset - sequenceLength);

TypeDesc knownType = null;

// Is the ldtoken in the first position?
if (reader.PeekILOpcode() == ILOpcode.ldtoken)
{
knownType = ReadLdToken(ref reader, methodIL, flags);
if (knownType == null)
return false;

if (!ReadGetTypeFromHandle(ref reader, methodIL, flags))
return false;
}

ILOpcode opcode = reader.ReadILOpcode();
if (ldInstructionSize == 1 && opcode is (>= ILOpcode.ldloc_0 and <= ILOpcode.ldloc_3) or (>= ILOpcode.ldarg_0 and <= ILOpcode.ldarg_3))
{
// Nothing to read
}
else if (ldInstructionSize == 2 && opcode is ILOpcode.ldloc_s or ILOpcode.ldarg_s)
{
reader.ReadILByte();
}
else if (ldInstructionSize == 3 && opcode is ILOpcode.ldloc or ILOpcode.ldarg)
{
reader.ReadILUInt16();
}
else
{
return false;
}

if ((flags[reader.Offset] & OpcodeFlags.BasicBlockStart) != 0)
return false;

if (expectGetType)
{
if (reader.ReadILOpcode() is not ILOpcode.callvirt and not ILOpcode.call)
return false;

// We don't actually mind if this is not Object.GetType
reader.ReadILToken();

if ((flags[reader.Offset] & OpcodeFlags.BasicBlockStart) != 0)
return false;
}

// If the ldtoken wasn't in the first position, it must be in the other
if (knownType == null)
{
knownType = ReadLdToken(ref reader, methodIL, flags);
if (knownType == null)
return false;

if (!ReadGetTypeFromHandle(ref reader, methodIL, flags))
return false;
}

// No value in making this work for definitions
if (knownType.IsGenericDefinition)
return false;

// Dataflow runs on top of uninstantiated IL and we can't answer some questions there.
// Unfortunately this means dataflow will still see code that the rest of the system
// might have optimized away. It should not be a problem in practice.
if (knownType.ContainsSignatureVariables())
return false;

if (knownType.IsCanonicalDefinitionType(CanonicalFormKind.Any))
return false;

// We don't track types without a constructed MethodTable very well.
if (!ConstructedEETypeNode.CreationAllowed(knownType))
return false;

if (_devirtualizationManager.CanReferenceConstructedTypeOrCanonicalFormOfType(knownType.NormalizeInstantiation()))
return false;

constant = 0;
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ public override IEETypeNode NecessaryTypeSymbolIfPossible(TypeDesc type)
// information proving that it isn't, give RyuJIT the constructed symbol even
// though we just need the unconstructed one.
// https://github.com/dotnet/runtimelab/issues/1128
bool canPotentiallyConstruct = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
bool canPotentiallyConstruct = ConstructedEETypeNode.CreationAllowed(type)
&& NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
if (canPotentiallyConstruct)
return _nodeFactory.MaximallyConstructableType(type);

Expand All @@ -81,7 +82,8 @@ public override IEETypeNode NecessaryTypeSymbolIfPossible(TypeDesc type)

public FrozenRuntimeTypeNode NecessaryRuntimeTypeIfPossible(TypeDesc type)
{
bool canPotentiallyConstruct = NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
bool canPotentiallyConstruct = ConstructedEETypeNode.CreationAllowed(type)
&& NodeFactory.DevirtualizationManager.CanReferenceConstructedMethodTable(type);
if (canPotentiallyConstruct)
return _nodeFactory.SerializedMaximallyConstructableRuntimeTypeObject(type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public ILScanResults Trim (ILCompilerOptions options, TrimmingCustomizations? cu
}

SubstitutionProvider substitutionProvider = new SubstitutionProvider(logger, featureSwitches, substitutions);
ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider);
ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider, new DevirtualizationManager());

CompilerGeneratedState compilerGeneratedState = new CompilerGeneratedState (ilProvider, logger);

Expand Down
12 changes: 10 additions & 2 deletions src/coreclr/tools/aot/ILCompiler/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ public int Run()
}

SubstitutionProvider substitutionProvider = new SubstitutionProvider(logger, featureSwitches, substitutions);
ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider);
ILProvider unsubstitutedILProvider = ilProvider;
ilProvider = new SubstitutedILProvider(ilProvider, substitutionProvider, new DevirtualizationManager());

CompilerGeneratedState compilerGeneratedState = new CompilerGeneratedState(ilProvider, logger);

Expand Down Expand Up @@ -492,10 +493,17 @@ void RunScanner()
if (scanDgmlLogFileName != null)
scanResults.WriteDependencyLog(scanDgmlLogFileName);

DevirtualizationManager devirtualizationManager = scanResults.GetDevirtualizationManager();

metadataManager = ((UsageBasedMetadataManager)metadataManager).ToAnalysisBasedMetadataManager();

interopStubManager = scanResults.GetInteropStubManager(interopStateManager, pinvokePolicy);

ilProvider = new SubstitutedILProvider(unsubstitutedILProvider, substitutionProvider, devirtualizationManager);

// Use a more precise IL provider that uses whole program analysis for dead branch elimination
builder.UseILProvider(ilProvider);

// If we have a scanner, feed the vtable analysis results to the compilation.
// This could be a command line switch if we really wanted to.
builder.UseVTableSliceProvider(scanResults.GetVTableLayoutInfo());
Expand All @@ -507,7 +515,7 @@ void RunScanner()
// If we have a scanner, we can drive devirtualization using the information
// we collected at scanning time (effectively sealing unsealed types if possible).
// This could be a command line switch if we really wanted to.
builder.UseDevirtualizationManager(scanResults.GetDevirtualizationManager());
builder.UseDevirtualizationManager(devirtualizationManager);

// If we use the scanner's result, we need to consult it to drive inlining.
// This prevents e.g. devirtualizing and inlining methods on types that were
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,19 +346,91 @@ sealed class Gen<T> { }

sealed class Never { }

static Type s_type = null;
class Never2 { }
class Canary2 { }
class Never3 { }
class Canary3 { }

class Maybe1<T, U> { }

[MethodImpl(MethodImplOptions.NoInlining)]
static Type GetTheType() => null;

[MethodImpl(MethodImplOptions.NoInlining)]
static Type GetThePointerType() => typeof(void*);

[MethodImpl(MethodImplOptions.NoInlining)]
static object GetTheObject() => new object();

static volatile object s_sink;

public static void Run()
{
// This was asserting the BCL because Never would not have reflection metadata
// despite the typeof
Console.WriteLine(s_type == typeof(Never));
Console.WriteLine(GetTheType() == typeof(Never));

// This was a compiler crash
Console.WriteLine(typeof(object) == typeof(Gen<>));

#if !DEBUG
ThrowIfPresent(typeof(TestTypeEquals), nameof(Never));

{
RunCheck(GetTheType());

static void RunCheck(Type t)
{
if (t == typeof(Never2))
{
s_sink = new Canary2();
}
}

ThrowIfPresentWithUsableMethodTable(typeof(TestTypeEquals), nameof(Canary2));
}

{

RunCheck(GetTheObject());

static void RunCheck(object o)
{
if (o.GetType() == typeof(Never3))
{
s_sink = new Canary3();
}
}

ThrowIfPresentWithUsableMethodTable(typeof(TestTypeEquals), nameof(Canary3));
}

{
RunCheck(GetThePointerType());

static void RunCheck(Type t)
{
if (t == typeof(void*))
{
return;
}
throw new Exception();
}
}

{
RunCheck<object>(typeof(Maybe1<object, string>));

[MethodImpl(MethodImplOptions.NoInlining)]
static void RunCheck<T>(Type t)
{
if (t == typeof(Maybe1<T, string>))
{
return;
}
throw new Exception();
}
}
#endif
}
}
Expand Down

0 comments on commit e0bd776

Please sign in to comment.