Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace a few instances of PtrToStructure with more efficient marshalling #70866

Merged
merged 2 commits into from Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -48,7 +48,7 @@ internal enum CryptOidInfoKeyType : int
CRYPT_OID_INFO_CNG_SIGN_KEY = 6,
}

internal static CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string key, OidGroup group, bool fallBackToAllGroups)
internal static unsafe CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string key, OidGroup group, bool fallBackToAllGroups)
{
const OidGroup CRYPT_OID_DISABLE_SEARCH_DS_FLAG = unchecked((OidGroup)0x80000000);
Debug.Assert(key != null);
Expand All @@ -75,28 +75,28 @@ internal static CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string k
if (!OidGroupWillNotUseActiveDirectory(group))
{
OidGroup localGroup = group | CRYPT_OID_DISABLE_SEARCH_DS_FLAG;
IntPtr localOidInfo = CryptFindOIDInfo(keyType, rawKey, localGroup);
if (localOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* localOidInfo = CryptFindOIDInfo(keyType, rawKey, localGroup);
if (localOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(localOidInfo);
return *(CRYPT_OID_INFO*)localOidInfo;
jkotas marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Attempt to query with a specific group, to make try to avoid an AD lookup if possible
IntPtr fullOidInfo = CryptFindOIDInfo(keyType, rawKey, group);
if (fullOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* fullOidInfo = CryptFindOIDInfo(keyType, rawKey, group);
if (fullOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(fullOidInfo);
return *(CRYPT_OID_INFO*)fullOidInfo;
}

if (fallBackToAllGroups && group != OidGroup.All)
{
// Finally, for compatibility with previous runtimes, if we have a group specified retry the
// query with no group
IntPtr allGroupOidInfo = CryptFindOIDInfo(keyType, rawKey, OidGroup.All);
if (allGroupOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* allGroupOidInfo = CryptFindOIDInfo(keyType, rawKey, OidGroup.All);
if (allGroupOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(allGroupOidInfo);
return *(CRYPT_OID_INFO*)allGroupOidInfo;
}
}

Expand Down Expand Up @@ -125,6 +125,6 @@ private static bool OidGroupWillNotUseActiveDirectory(OidGroup group)
}

[LibraryImport(Interop.Libraries.Crypt32)]
private static partial IntPtr CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, IntPtr pvKey, OidGroup group);
private static unsafe partial CRYPT_OID_INFO* CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, IntPtr pvKey, OidGroup group);
}
}
Expand Up @@ -13,23 +13,18 @@ internal static partial class Crypt32
/// Version used for a buffer containing a scalar integer (not an IntPtr)
/// </summary>
[LibraryImport(Libraries.Crypt32)]
private static unsafe partial IntPtr CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, int* pvKey, OidGroup group);
private static unsafe partial CRYPT_OID_INFO* CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, void* pvKey, OidGroup group);

public static CRYPT_OID_INFO FindAlgIdOidInfo(Interop.BCrypt.ECC_CURVE_ALG_ID_ENUM algId)
public static unsafe CRYPT_OID_INFO FindAlgIdOidInfo(Interop.BCrypt.ECC_CURVE_ALG_ID_ENUM algId)
{
int intAlgId = (int)algId;
IntPtr fullOidInfo;
unsafe
{
fullOidInfo = CryptFindOIDInfo(
CryptOidInfoKeyType.CRYPT_OID_INFO_ALGID_KEY,
&intAlgId,
OidGroup.HashAlgorithm);
}
CRYPT_OID_INFO* fullOidInfo = CryptFindOIDInfo(
CryptOidInfoKeyType.CRYPT_OID_INFO_ALGID_KEY,
&algId,
OidGroup.HashAlgorithm);

if (fullOidInfo != IntPtr.Zero)
if (fullOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(fullOidInfo);
return *(CRYPT_OID_INFO*)fullOidInfo;
}

// Otherwise the lookup failed.
Expand Down
Expand Up @@ -76,24 +76,22 @@ internal IPAddress MarshalIPAddress()
// IP_ADAPTER_WINS_SERVER_ADDRESS
// IP_ADAPTER_GATEWAY_ADDRESS
[StructLayout(LayoutKind.Sequential)]
internal struct IpAdapterAddress
internal unsafe struct IpAdapterAddress
{
internal uint length;
internal AdapterAddressFlags flags;
internal IntPtr next;
internal IpAdapterAddress* next;
internal IpSocketAddress address;

internal static InternalIPAddressCollection MarshalIpAddressCollection(IntPtr ptr)
{
InternalIPAddressCollection addressList = new InternalIPAddressCollection();

while (ptr != IntPtr.Zero)
IpAdapterAddress* pIpAdapterAddress = (IpAdapterAddress*)ptr;
while (pIpAdapterAddress != null)
{
IpAdapterAddress addressStructure = Marshal.PtrToStructure<IpAdapterAddress>(ptr);
IPAddress address = addressStructure.address.MarshalIPAddress();
addressList.InternalAdd(address);

ptr = addressStructure.next;
addressList.InternalAdd(pIpAdapterAddress->address.MarshalIPAddress());
pIpAdapterAddress = pIpAdapterAddress->next;
}

return addressList;
Expand All @@ -103,25 +101,24 @@ internal static IPAddressInformationCollection MarshalIpAddressInformationCollec
{
IPAddressInformationCollection addressList = new IPAddressInformationCollection();

while (ptr != IntPtr.Zero)
IpAdapterAddress* pIpAdapterAddress = (IpAdapterAddress*)ptr;
while (pIpAdapterAddress != null)
{
IpAdapterAddress addressStructure = Marshal.PtrToStructure<IpAdapterAddress>(ptr);
IPAddress address = addressStructure.address.MarshalIPAddress();
addressList.InternalAdd(new SystemIPAddressInformation(address, addressStructure.flags));

ptr = addressStructure.next;
addressList.InternalAdd(new SystemIPAddressInformation(
pIpAdapterAddress->address.MarshalIPAddress(), pIpAdapterAddress->flags));
pIpAdapterAddress = pIpAdapterAddress->next;
}

return addressList;
}
}

[StructLayout(LayoutKind.Sequential)]
internal struct IpAdapterUnicastAddress
internal unsafe struct IpAdapterUnicastAddress
{
internal uint length;
internal AdapterAddressFlags flags;
internal IntPtr next;
internal IpAdapterUnicastAddress* next;
internal IpSocketAddress address;
internal PrefixOrigin prefixOrigin;
internal SuffixOrigin suffixOrigin;
Expand All @@ -132,37 +129,59 @@ internal struct IpAdapterUnicastAddress
internal byte prefixLength;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
internal struct IpAdapterAddresses
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IpAdapterAddresses
{
internal const int MAX_ADAPTER_ADDRESS_LENGTH = 8;

internal uint length;
internal uint index;
internal IntPtr next;
internal IpAdapterAddresses* next;

// Needs to be ANSI.
[MarshalAs(UnmanagedType.LPStr)]
internal string AdapterName;
private IntPtr _adapterName; // ANSI string
internal string AdapterName => Marshal.PtrToStringAnsi(_adapterName)!;

internal IntPtr firstUnicastAddress;
internal IntPtr firstAnycastAddress;
internal IntPtr firstMulticastAddress;
internal IntPtr firstDnsServerAddress;

internal string dnsSuffix;
internal string description;
internal string friendlyName;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = MAX_ADAPTER_ADDRESS_LENGTH)]
internal byte[] address;
internal uint addressLength;
private IntPtr _dnsSuffix;
internal string DnsSuffix => Marshal.PtrToStringUni(_dnsSuffix)!;

private IntPtr _description;
internal string Description => Marshal.PtrToStringUni(_description)!;

private IntPtr _friendlyName;
internal string FriendlyName => Marshal.PtrToStringUni(_friendlyName)!;

private fixed byte _address[MAX_ADAPTER_ADDRESS_LENGTH];
private uint _addressLength;
internal byte[] Address
{
get
{
fixed (byte* pAddress = _address)
return new ReadOnlySpan<byte>(pAddress, (int)_addressLength).ToArray();
}
}

internal AdapterFlags flags;
internal uint mtu;
internal NetworkInterfaceType type;
internal OperationalStatus operStatus;
internal uint ipv6Index;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 16)]
internal uint[] zoneIndices;

private fixed uint _zoneIndices[16];
internal uint[] ZoneIndices
{
get
{
fixed (uint* pZoneIndices = _zoneIndices)
return new ReadOnlySpan<uint>(pZoneIndices, 16).ToArray();
}
}

internal IntPtr firstPrefix;

internal ulong transmitLinkSpeed;
Expand All @@ -174,13 +193,11 @@ internal struct IpAdapterAddresses
internal ulong luid;
internal IpSocketAddress dhcpv4Server;
internal uint compartmentId;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 16)]
internal byte[] networkGuid;
internal fixed byte networkGuid[16];
internal InterfaceConnectionType connectionType;
internal InterfaceTunnelType tunnelType;
internal IpSocketAddress dhcpv6Server; // Never available in Windows.
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 130)]
internal byte[] dhcpv6ClientDuid;
internal fixed byte dhcpv6ClientDuid[130];
internal uint dhcpv6ClientDuidLength;
internal uint dhcpV6Iaid;

Expand Down Expand Up @@ -211,11 +228,11 @@ internal enum InterfaceTunnelType : int
/// <summary>
/// IP_PER_ADAPTER_INFO - per-adapter IP information such as DNS server list.
/// </summary>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)]
[StructLayout(LayoutKind.Sequential)]
internal struct IpPerAdapterInfo
{
internal bool autoconfigEnabled;
internal bool autoconfigActive;
internal uint autoconfigEnabled;
internal uint autoconfigActive;
internal IntPtr currentDnsServer; /* IpAddressList* */
internal IpAddrString dnsServerList;
};
Expand All @@ -224,14 +241,12 @@ internal struct IpPerAdapterInfo
/// Store an IP address with its corresponding subnet mask,
/// both as dotted decimal strings.
/// </summary>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)]
internal struct IpAddrString
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IpAddrString
{
internal IntPtr Next; /* struct _IpAddressList* */
[MarshalAs(UnmanagedType.ByValTStr, SizeConst = 16)]
internal string IpAddress;
[MarshalAs(UnmanagedType.ByValTStr, SizeConst = 16)]
internal string IpMask;
internal IpAddrString* Next; /* struct _IpAddressList* */
internal fixed byte IpAddress[16];
internal fixed byte IpMask[16];
internal uint Context;
}

Expand Down
Expand Up @@ -68,11 +68,11 @@ private static string ConvertTo8DigitHex(uint value)
return value.ToString("X8");
}

private static Interop.Version.VS_FIXEDFILEINFO GetFixedFileInfo(IntPtr memPtr)
private static unsafe Interop.Version.VS_FIXEDFILEINFO GetFixedFileInfo(IntPtr memPtr)
{
if (Interop.Version.VerQueryValue(memPtr, "\\", out IntPtr memRef, out _))
{
return (Interop.Version.VS_FIXEDFILEINFO)Marshal.PtrToStructure<Interop.Version.VS_FIXEDFILEINFO>(memRef);
return *(Interop.Version.VS_FIXEDFILEINFO*)memRef;
}

return default;
Expand Down
Expand Up @@ -29,7 +29,7 @@ internal sealed class SystemIPInterfaceProperties : IPInterfaceProperties
internal SystemIPInterfaceProperties(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Interop.IpHlpApi.IpAdapterAddresses ipAdapterAddresses)
{
_adapterFlags = ipAdapterAddresses.flags;
_dnsSuffix = ipAdapterAddresses.dnsSuffix;
_dnsSuffix = ipAdapterAddresses.DnsSuffix;
_dnsEnabled = fixedInfo.enableDns;
_dynamicDnsEnabled = ((ipAdapterAddresses.flags & Interop.IpHlpApi.AdapterFlags.DnsEnabled) > 0);

Expand Down Expand Up @@ -64,7 +64,7 @@ internal SystemIPInterfaceProperties(in Interop.IpHlpApi.FIXED_INFO fixedInfo, i
if ((_adapterFlags & Interop.IpHlpApi.AdapterFlags.IPv6Enabled) != 0)
{
_ipv6Properties = new SystemIPv6InterfaceProperties(ipAdapterAddresses.ipv6Index,
ipAdapterAddresses.mtu, ipAdapterAddresses.zoneIndices);
ipAdapterAddresses.mtu, ipAdapterAddresses.ZoneIndices);
}
}

Expand Down
Expand Up @@ -91,11 +91,10 @@ private unsafe void GetPerAdapterInfo(uint index)
result = Interop.IpHlpApi.GetPerAdapterInfo(index, buffer, &size);
if (result == Interop.IpHlpApi.ERROR_SUCCESS)
{
Interop.IpHlpApi.IpPerAdapterInfo ipPerAdapterInfo =
Marshal.PtrToStructure<Interop.IpHlpApi.IpPerAdapterInfo>(buffer);
Interop.IpHlpApi.IpPerAdapterInfo* ipPerAdapterInfo = (Interop.IpHlpApi.IpPerAdapterInfo*)buffer;

_autoConfigEnabled = ipPerAdapterInfo.autoconfigEnabled;
_autoConfigActive = ipPerAdapterInfo.autoconfigActive;
_autoConfigEnabled = ipPerAdapterInfo->autoconfigEnabled != 0;
_autoConfigActive = ipPerAdapterInfo->autoconfigActive != 0;
}
}
finally
Expand Down
Expand Up @@ -15,7 +15,6 @@ internal sealed class SystemNetworkInterface : NetworkInterface
private readonly string _id;
private readonly string _description;
private readonly byte[] _physicalAddress;
private readonly uint _addressLength;
private readonly NetworkInterfaceType _type;
private readonly OperationalStatus _operStatus;
private readonly long _speed;
Expand Down Expand Up @@ -110,14 +109,12 @@ internal static unsafe NetworkInterface[] GetNetworkInterfaces()
if (result == Interop.IpHlpApi.ERROR_SUCCESS)
{
// Linked list of interfaces.
IntPtr ptr = buffer;
while (ptr != IntPtr.Zero)
Interop.IpHlpApi.IpAdapterAddresses* adapterAddresses = (Interop.IpHlpApi.IpAdapterAddresses*)buffer;
while (adapterAddresses != null)
{
// Traverse the list, marshal in the native structures, and create new NetworkInterfaces.
Interop.IpHlpApi.IpAdapterAddresses adapterAddresses = Marshal.PtrToStructure<Interop.IpHlpApi.IpAdapterAddresses>(ptr);
interfaceList.Add(new SystemNetworkInterface(in fixedInfo, in adapterAddresses));

ptr = adapterAddresses.next;
interfaceList.Add(new SystemNetworkInterface(in fixedInfo, in *adapterAddresses));
adapterAddresses = adapterAddresses->next;
}
}
}
Expand Down Expand Up @@ -146,12 +143,11 @@ internal SystemNetworkInterface(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Int
{
// Store the common API information.
_id = ipAdapterAddresses.AdapterName;
_name = ipAdapterAddresses.friendlyName;
_description = ipAdapterAddresses.description;
_name = ipAdapterAddresses.FriendlyName;
_description = ipAdapterAddresses.Description;
_index = ipAdapterAddresses.index;

_physicalAddress = ipAdapterAddresses.address;
_addressLength = ipAdapterAddresses.addressLength;
_physicalAddress = ipAdapterAddresses.Address;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine; we'll just need to make sure in future code that touches this stuff that we don't inadvertently access Address multiple times, as it's now allocating a new byte[] on every access whereas previously it wasn't.


_type = ipAdapterAddresses.type;
_operStatus = ipAdapterAddresses.operStatus;
Expand All @@ -172,12 +168,7 @@ internal SystemNetworkInterface(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Int

public override PhysicalAddress GetPhysicalAddress()
{
byte[] newAddr = new byte[_addressLength];

// Buffer.BlockCopy only supports int while addressLength is uint (see IpAdapterAddresses).
// Will throw OverflowException if addressLength > Int32.MaxValue.
Buffer.BlockCopy(_physicalAddress, 0, newAddr, 0, checked((int)_addressLength));
return new PhysicalAddress(newAddr);
return new PhysicalAddress(_physicalAddress);
}

public override NetworkInterfaceType NetworkInterfaceType { get { return _type; } }
Expand Down