diff --git a/src/Compilers/CSharp/Portable/Symbols/Metadata/PE/PENamedTypeSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/Metadata/PE/PENamedTypeSymbol.cs
index e1e7eb35eb7a6..6d4c7c5655851 100644
--- a/src/Compilers/CSharp/Portable/Symbols/Metadata/PE/PENamedTypeSymbol.cs
+++ b/src/Compilers/CSharp/Portable/Symbols/Metadata/PE/PENamedTypeSymbol.cs
@@ -1871,7 +1871,7 @@ public override TypeKind TypeKind
}
// PROTOTYPE consider caching/optimizing this computation
- if (!TryGetExtensionMarkerMethod().IsNil)
+ if (!TryGetExtensionInfo().MarkerMethod.IsNil)
{
// Extension
result = TypeKind.Extension;
@@ -1895,10 +1895,12 @@ public override TypeKind TypeKind
///
/// Superficially checks whether this is a valid extension type
- /// and returns the extension marker method (to be validated later)
+ /// and returns the extension marker method and underlying instance field if applicable
/// if it is.
+ /// Both will be validated later.
///
- private MethodDefinitionHandle TryGetExtensionMarkerMethod()
+ private (MethodDefinitionHandle MarkerMethod, FieldDefinitionHandle UnderlyingInstanceField)
+ TryGetExtensionInfo()
{
var moduleSymbol = this.ContainingPEModule;
var module = moduleSymbol.Module;
@@ -1913,15 +1915,39 @@ private MethodDefinitionHandle TryGetExtensionMarkerMethod()
try
{
- // They must not contain any instance state
+ // The only expected instance state is the underlying instance field
+ FieldDefinitionHandle foundUnderlyingInstanceField = default;
foreach (var field in module.GetFieldsOfTypeOrThrow(_handle))
{
- if ((module.GetFieldDefFlagsOrThrow(field) & FieldAttributes.Static) == 0)
+ if (module.GetFieldDefNameOrThrow(field) == WellKnownMemberNames.ExtensionFieldName)
+ {
+ if ((module.GetFieldDefFlagsOrThrow(field) & FieldAttributes.Static) != 0)
+ {
+ // It must be an instance field
+ return default;
+ }
+
+ if (this.IsStatic)
+ {
+ // It's only allowed in non-static extension types
+ return default;
+ }
+
+ Debug.Assert(foundUnderlyingInstanceField.IsNil);
+ foundUnderlyingInstanceField = field;
+ }
+ else if ((module.GetFieldDefFlagsOrThrow(field) & FieldAttributes.Static) == 0)
{
return default;
}
}
+ if (!this.IsStatic && foundUnderlyingInstanceField.IsNil)
+ {
+ // Non-static extensions must have an underlying instance field (to be validated later)
+ return default;
+ }
+
// They must have a single marker method (to be validated later)
MethodDefinitionHandle foundMarkerMethod = default;
foreach (var methodHandle in module.GetMethodsOfTypeOrThrow(_handle))
@@ -1940,7 +1966,7 @@ private MethodDefinitionHandle TryGetExtensionMarkerMethod()
}
}
- return foundMarkerMethod;
+ return (foundMarkerMethod, foundUnderlyingInstanceField);
}
catch (BadImageFormatException)
{
@@ -1968,10 +1994,11 @@ private void DecodeExtensionType(out bool isExplicit, out TypeSymbol underlyingT
bool tryDecodeExtensionType(out bool isExplicit, [NotNullWhen(true)] out TypeSymbol? underlyingType)
{
- var markerMethod = TryGetExtensionMarkerMethod();
+ var (markerMethod, underlyingInstanceField) = TryGetExtensionInfo();
Debug.Assert(!markerMethod.IsNil);
var moduleSymbol = this.ContainingPEModule;
+ // Decode and validate marker method
isExplicit = false;
underlyingType = null;
@@ -1987,7 +2014,7 @@ bool tryDecodeExtensionType(out bool isExplicit, [NotNullWhen(true)] out TypeSym
}
// PROTOTYPE do we want to tighten the flags check further? (require that type be sealed?)
- if ((localFlags & MethodAttributes.Private) == 0 ||
+ if ((localFlags & MethodAttributes.MemberAccessMask) != MethodAttributes.Private ||
(localFlags & MethodAttributes.Static) == 0)
{
return false;
@@ -2010,7 +2037,9 @@ bool tryDecodeExtensionType(out bool isExplicit, [NotNullWhen(true)] out TypeSym
// PROTOTYPE need to decode extension type references (may be some cycle issues)
type = ApplyTransforms(type, paramInfo.Handle, moduleSymbol);
+ ImmutableArray customModifiers = CSharpCustomModifier.Convert(paramInfo.CustomModifiers);
+ // PROTOTYPE consider checking top-level nullability annotation
if (paramInfo.IsByRef || !paramInfo.CustomModifiers.IsDefault)
{
var info = new CSDiagnosticInfo(ErrorCode.ERR_MalformedExtensionInMetadata, this); // PROTOTYPE need to report use-site diagnostic
@@ -2031,10 +2060,22 @@ bool tryDecodeExtensionType(out bool isExplicit, [NotNullWhen(true)] out TypeSym
{
var info = new CSDiagnosticInfo(ErrorCode.ERR_MalformedExtensionInMetadata, this); // PROTOTYPE need to report use-site diagnostic
underlyingType = new ExtendedErrorTypeSymbol(type, LookupResultKind.NotReferencable, info, unreported: true);
+ return false;
}
else
{
- underlyingType = type;
+ // Validate instance field
+ if (!underlyingInstanceField.IsNil
+ && !validateUnderlyingInstanceField(underlyingInstanceField, moduleSymbol, type))
+ {
+ var info = new CSDiagnosticInfo(ErrorCode.ERR_MalformedExtensionInMetadata, this); // PROTOTYPE need to report use-site diagnostic
+ underlyingType = new ExtendedErrorTypeSymbol(type, LookupResultKind.NotReferencable, info, unreported: true);
+ return false;
+ }
+ else
+ {
+ underlyingType = type;
+ }
}
}
else
@@ -2046,6 +2087,27 @@ bool tryDecodeExtensionType(out bool isExplicit, [NotNullWhen(true)] out TypeSym
Debug.Assert(underlyingType is not null);
return true;
}
+
+ bool validateUnderlyingInstanceField(FieldDefinitionHandle underlyingInstanceFieldHandle, PEModuleSymbol moduleSymbol, TypeSymbol underlyingType)
+ {
+ var fieldSymbol = new PEFieldSymbol(moduleSymbol, this, underlyingInstanceFieldHandle);
+
+ if (fieldSymbol.DeclaredAccessibility != Accessibility.Private
+ || fieldSymbol.IsStatic
+ || fieldSymbol.RefKind != RefKind.None
+ || fieldSymbol.IsReadOnly)
+ {
+ return false;
+ }
+
+ if (!fieldSymbol.TypeWithAnnotations.Equals(TypeWithAnnotations.Create(underlyingType), TypeCompareKind.CLRSignatureCompareOptions))
+ {
+ return false;
+ }
+
+ // PROTOTYPE do we want to tighten the checks further? (required)
+ return true;
+ }
}
#nullable disable
@@ -2186,12 +2248,18 @@ private IEnumerable CreateNestedTypes()
}
}
+ var underlyingInstanceField = TryGetExtensionInfo().UnderlyingInstanceField;
try
{
foreach (var fieldRid in module.GetFieldsOfTypeOrThrow(_handle))
{
try
{
+ if (!underlyingInstanceField.IsNil && fieldRid == underlyingInstanceField)
+ {
+ continue;
+ }
+
if (!(isOrdinaryEmbeddableStruct ||
(isOrdinaryStruct && (module.GetFieldDefFlagsOrThrow(fieldRid) & FieldAttributes.Static) == 0) ||
module.ShouldImportField(fieldRid, moduleSymbol.ImportOptions)))
@@ -2231,7 +2299,7 @@ private IEnumerable CreateNestedTypes()
// PROTOTYPE are extensions embeddable?
// for ordinary embeddable struct types we import private members so that we can report appropriate errors if the structure is used
var isOrdinaryEmbeddableStruct = (this.TypeKind == TypeKind.Struct) && (this.SpecialType == Microsoft.CodeAnalysis.SpecialType.None) && this.ContainingAssembly.IsLinked;
- var extensionMarkerMethod = TryGetExtensionMarkerMethod();
+ var extensionMarkerMethod = TryGetExtensionInfo().MarkerMethod;
try
{
diff --git a/src/Compilers/CSharp/Portable/Symbols/Source/SourceExtensionTypeSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/Source/SourceExtensionTypeSymbol.cs
index 5c43d08fd5c14..f69fcb9368b14 100644
--- a/src/Compilers/CSharp/Portable/Symbols/Source/SourceExtensionTypeSymbol.cs
+++ b/src/Compilers/CSharp/Portable/Symbols/Source/SourceExtensionTypeSymbol.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
@@ -18,6 +19,8 @@ internal sealed class SourceExtensionTypeSymbol : SourceNamedTypeSymbol
private ExtensionInfo _lazyDeclaredExtensionInfo = ExtensionInfo.Sentinel;
// PROTOTYPE consider renaming ExtensionUnderlyingType->ExtendedType (here and elsewhere)
private TypeSymbol? _lazyExtensionUnderlyingType = ErrorTypeSymbol.UnknownResultType;
+ // For non-static extensions, we emit a field of the underlying type
+ private FieldSymbol? _lazyUnderlyingInstanceField = null;
internal SourceExtensionTypeSymbol(NamespaceOrTypeSymbol containingSymbol, MergedTypeDeclaration declaration, BindingDiagnosticBag diagnostics)
: base(containingSymbol, declaration, diagnostics)
@@ -395,5 +398,37 @@ internal static bool IsRestrictedExtensionUnderlyingType(TypeSymbol type)
return false;
}
+
+ private FieldSymbol? UnderlyingInstanceField
+ {
+ get
+ {
+ if (IsStatic)
+ {
+ throw ExceptionUtilities.Unreachable();
+ }
+
+ var extendedType = GetExtendedTypeNoUseSiteDiagnostics(null);
+ if (extendedType is null)
+ {
+ return null;
+ }
+
+ if (_lazyUnderlyingInstanceField is null)
+ {
+ var field = new SynthesizedFieldSymbol(this, extendedType, WellKnownMemberNames.ExtensionFieldName);
+ Interlocked.CompareExchange(ref _lazyUnderlyingInstanceField, field, comparand: null);
+ }
+
+ return _lazyUnderlyingInstanceField;
+ }
+ }
+
+ internal override IEnumerable GetFieldsToEmit()
+ {
+ return !IsStatic && UnderlyingInstanceField is { } underlyingField
+ ? [underlyingField, .. base.GetFieldsToEmit()]
+ : base.GetFieldsToEmit();
+ }
}
}
diff --git a/src/Compilers/CSharp/Test/Emit3/ExtensionTypeTests.cs b/src/Compilers/CSharp/Test/Emit3/ExtensionTypeTests.cs
index 45d2b51c592bc..3d29c76b46f29 100644
--- a/src/Compilers/CSharp/Test/Emit3/ExtensionTypeTests.cs
+++ b/src/Compilers/CSharp/Test/Emit3/ExtensionTypeTests.cs
@@ -8,6 +8,8 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
+using System.Reflection.Metadata;
+using System.Reflection.Metadata.Ecma335;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Symbols.Metadata.PE;
using Microsoft.CodeAnalysis.CSharp.Symbols.Retargeting;
@@ -18,6 +20,7 @@
using Roslyn.Test.Utilities;
using Roslyn.Utilities;
using Xunit;
+using static Roslyn.Test.Utilities.MetadataReaderUtils;
namespace Microsoft.CodeAnalysis.CSharp.UnitTests.Semantics;
@@ -54,7 +57,7 @@ private static void AssertSetStrictlyEqual(string[] expected, string[] actual)
}
// Verify things that are common for all extension types
- private static void VerifyExtension(TypeSymbol type, bool? isExplicit, SpecialType specialType = SpecialType.None) where T : TypeSymbol
+ private static void VerifyExtension(TypeSymbol type, bool? isExplicit, SpecialType specialType = SpecialType.None, string fieldType = null) where T : TypeSymbol
{
var namedType = (NamedTypeSymbol)type;
Assert.True(namedType is T);
@@ -127,6 +130,28 @@ private static void AssertSetStrictlyEqual(string[] expected, string[] actual)
Assert.False(sourceNamedType.IsImplicitlyDeclared);
}
+ if (type is PENamedTypeSymbol peType)
+ {
+ var module = (PEModuleSymbol)type.ContainingModule;
+ var reader = module.Module.GetMetadataReader();
+ var fieldDefHandle = reader.GetTypeDefinition(peType.Handle).GetFields()
+ .Where(f => reader.GetString(reader.GetFieldDefinition(f).Name) == WellKnownMemberNames.ExtensionFieldName).SingleOrDefault();
+
+ // Static extensions don't have this field, but non-static extensions have it
+ Assert.Equal(fieldDefHandle.IsNil, type.IsStatic);
+
+ if (!type.IsStatic)
+ {
+ var fieldDef = reader.GetFieldDefinition(fieldDefHandle);
+ var blob = reader.GetBlobReader(fieldDef.Signature);
+ var decoder = new SignatureDecoder(ConstantSignatureVisualizer.Instance, reader, genericContext: null);
+ var fieldTypeDisplay = decoder.DecodeFieldSignature(ref blob);
+
+ // The instance value field has the expected type
+ Assert.Equal(fieldType, fieldTypeDisplay);
+ }
+ }
+
static void checkBaseExtension(NamedTypeSymbol baseExtension)
{
if (baseExtension.IsExtension)
@@ -243,7 +268,7 @@ void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: isExplicit);
+ VerifyExtension(r, isExplicit: isExplicit, fieldType: "UnderlyingClass");
}
Assert.Equal("UnderlyingClass", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -398,6 +423,7 @@ public void ForClass_Metadata_RefStruct(bool isExplicit)
{
IL_0000: ret
}
+ .field private object '{{WellKnownMemberNames.ExtensionFieldName}}'
}
""";
@@ -644,7 +670,7 @@ static void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "UnderlyingClass");
}
AssertEx.Equal(new[]
@@ -2032,7 +2058,7 @@ static void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "UnderlyingStruct");
}
Assert.Equal("UnderlyingStruct", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -2074,7 +2100,7 @@ void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: isExplicit);
+ VerifyExtension(r, isExplicit: isExplicit, fieldType: "!0");
}
Assert.Equal("T", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -2116,7 +2142,7 @@ void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: isExplicit);
+ VerifyExtension(r, isExplicit: isExplicit, fieldType: "E");
}
Assert.Equal("E", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -2184,7 +2210,7 @@ static void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "C`1{!0}");
}
Assert.Equal("C", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -2225,7 +2251,7 @@ static void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "System.ValueTuple`2{Int32, Int32}");
}
Assert.Equal("(System.Int32, System.Int32)", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -2266,7 +2292,7 @@ static void validate(ModuleSymbol module)
}
else
{
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "Int32[]");
}
Assert.Equal("System.Int32[]", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
@@ -3095,7 +3121,7 @@ public void UnderlyingType_NativeInt(bool useImageReference)
static void validate(ModuleSymbol module)
{
var r = module.GlobalNamespace.GetTypeMember("R");
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "IntPtr");
Assert.Equal("nint", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
Assert.Empty(r.BaseExtensionsNoUseSiteDiagnostics);
Assert.Empty(r.AllBaseExtensionsNoUseSiteDiagnostics);
@@ -3137,7 +3163,7 @@ public void UnderlyingType_NativeInt_OlderFramework(bool useImageReference)
static void validate(ModuleSymbol module)
{
var r = module.GlobalNamespace.GetTypeMember("R");
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "IntPtr");
Assert.Equal("nint", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
Assert.Empty(r.BaseExtensionsNoUseSiteDiagnostics);
Assert.Empty(r.AllBaseExtensionsNoUseSiteDiagnostics);
@@ -3181,7 +3207,7 @@ public class C { }
static void validate(ModuleSymbol module)
{
var r = module.GlobalNamespace.GetTypeMember("R");
- VerifyExtension(r, isExplicit: true);
+ VerifyExtension(r, isExplicit: true, fieldType: "C`1{IntPtr}");
Assert.Equal("C", r.GetExtendedTypeNoUseSiteDiagnostics(null).ToTestDisplayString());
Assert.Empty(r.BaseExtensionsNoUseSiteDiagnostics);
Assert.Empty(r.AllBaseExtensionsNoUseSiteDiagnostics);
@@ -3198,7 +3224,45 @@ public class C { }
var comp = CreateCompilation(src, targetFramework: TargetFramework.Net70);
comp.VerifyDiagnostics();
- CompileAndVerify(comp, symbolValidator: validate, sourceSymbolValidator: validate, verify: Verification.FailsPEVerify);
+ var verifier = CompileAndVerify(comp, symbolValidator: validate, sourceSymbolValidator: validate, verify: Verification.FailsPEVerify);
+ // Note: we don't emit a DynamicAttribute on synthesized field
+ verifier.VerifyTypeIL("R", """
+.class public sequential ansi sealed beforefieldinit R
+ extends [System.Runtime]System.ValueType
+{
+ .custom instance void [System.Runtime]System.ObsoleteAttribute::.ctor(string, bool) = (
+ 01 00 43 45 78 74 65 6e 73 69 6f 6e 20 74 79 70
+ 65 73 20 61 72 65 20 6e 6f 74 20 73 75 70 70 6f
+ 72 74 65 64 20 69 6e 20 74 68 69 73 20 76 65 72
+ 73 69 6f 6e 20 6f 66 20 79 6f 75 72 20 63 6f 6d
+ 70 69 6c 65 72 2e 01 00 00
+ )
+ .custom instance void [System.Runtime]System.Runtime.CompilerServices.CompilerFeatureRequiredAttribute::.ctor(string) = (
+ 01 00 0e 45 78 74 65 6e 73 69 6f 6e 54 79 70 65
+ 73 00 00
+ )
+ // Fields
+ .field private class C`1