Skip to content

Commit

Permalink
Replace a few instances of PtrToStructure with more efficient marshal…
Browse files Browse the repository at this point in the history
…ling
  • Loading branch information
jkotas committed Jun 17, 2022
1 parent cfeab72 commit 57817bd
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 101 deletions.
Expand Up @@ -48,7 +48,7 @@ internal enum CryptOidInfoKeyType : int
CRYPT_OID_INFO_CNG_SIGN_KEY = 6,
}

internal static CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string key, OidGroup group, bool fallBackToAllGroups)
internal static unsafe CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string key, OidGroup group, bool fallBackToAllGroups)
{
const OidGroup CRYPT_OID_DISABLE_SEARCH_DS_FLAG = unchecked((OidGroup)0x80000000);
Debug.Assert(key != null);
Expand All @@ -75,28 +75,28 @@ internal static CRYPT_OID_INFO FindOidInfo(CryptOidInfoKeyType keyType, string k
if (!OidGroupWillNotUseActiveDirectory(group))
{
OidGroup localGroup = group | CRYPT_OID_DISABLE_SEARCH_DS_FLAG;
IntPtr localOidInfo = CryptFindOIDInfo(keyType, rawKey, localGroup);
if (localOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* localOidInfo = CryptFindOIDInfo(keyType, rawKey, localGroup);
if (localOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(localOidInfo);
return *(CRYPT_OID_INFO*)localOidInfo;
}
}

// Attempt to query with a specific group, to make try to avoid an AD lookup if possible
IntPtr fullOidInfo = CryptFindOIDInfo(keyType, rawKey, group);
if (fullOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* fullOidInfo = CryptFindOIDInfo(keyType, rawKey, group);
if (fullOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(fullOidInfo);
return *(CRYPT_OID_INFO*)fullOidInfo;
}

if (fallBackToAllGroups && group != OidGroup.All)
{
// Finally, for compatibility with previous runtimes, if we have a group specified retry the
// query with no group
IntPtr allGroupOidInfo = CryptFindOIDInfo(keyType, rawKey, OidGroup.All);
if (allGroupOidInfo != IntPtr.Zero)
CRYPT_OID_INFO* allGroupOidInfo = CryptFindOIDInfo(keyType, rawKey, OidGroup.All);
if (allGroupOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(allGroupOidInfo);
return *(CRYPT_OID_INFO*)allGroupOidInfo;
}
}

Expand Down Expand Up @@ -125,6 +125,6 @@ private static bool OidGroupWillNotUseActiveDirectory(OidGroup group)
}

[LibraryImport(Interop.Libraries.Crypt32)]
private static partial IntPtr CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, IntPtr pvKey, OidGroup group);
private static unsafe partial CRYPT_OID_INFO* CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, IntPtr pvKey, OidGroup group);
}
}
Expand Up @@ -13,23 +13,18 @@ internal static partial class Crypt32
/// Version used for a buffer containing a scalar integer (not an IntPtr)
/// </summary>
[LibraryImport(Libraries.Crypt32)]
private static unsafe partial IntPtr CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, int* pvKey, OidGroup group);
private static unsafe partial CRYPT_OID_INFO* CryptFindOIDInfo(CryptOidInfoKeyType dwKeyType, void* pvKey, OidGroup group);

public static CRYPT_OID_INFO FindAlgIdOidInfo(Interop.BCrypt.ECC_CURVE_ALG_ID_ENUM algId)
public static unsafe CRYPT_OID_INFO FindAlgIdOidInfo(Interop.BCrypt.ECC_CURVE_ALG_ID_ENUM algId)
{
int intAlgId = (int)algId;
IntPtr fullOidInfo;
unsafe
{
fullOidInfo = CryptFindOIDInfo(
CryptOidInfoKeyType.CRYPT_OID_INFO_ALGID_KEY,
&intAlgId,
OidGroup.HashAlgorithm);
}
CRYPT_OID_INFO* fullOidInfo = CryptFindOIDInfo(
CryptOidInfoKeyType.CRYPT_OID_INFO_ALGID_KEY,
&algId,
OidGroup.HashAlgorithm);

if (fullOidInfo != IntPtr.Zero)
if (fullOidInfo != null)
{
return Marshal.PtrToStructure<CRYPT_OID_INFO>(fullOidInfo);
return *(CRYPT_OID_INFO*)fullOidInfo;
}

// Otherwise the lookup failed.
Expand Down
Expand Up @@ -76,24 +76,22 @@ internal IPAddress MarshalIPAddress()
// IP_ADAPTER_WINS_SERVER_ADDRESS
// IP_ADAPTER_GATEWAY_ADDRESS
[StructLayout(LayoutKind.Sequential)]
internal struct IpAdapterAddress
internal unsafe struct IpAdapterAddress
{
internal uint length;
internal AdapterAddressFlags flags;
internal IntPtr next;
internal IpAdapterAddress* next;
internal IpSocketAddress address;

internal static InternalIPAddressCollection MarshalIpAddressCollection(IntPtr ptr)
{
InternalIPAddressCollection addressList = new InternalIPAddressCollection();

while (ptr != IntPtr.Zero)
IpAdapterAddress* pIpAdapterAddress = (IpAdapterAddress*)ptr;
while (pIpAdapterAddress != null)
{
IpAdapterAddress addressStructure = Marshal.PtrToStructure<IpAdapterAddress>(ptr);
IPAddress address = addressStructure.address.MarshalIPAddress();
addressList.InternalAdd(address);

ptr = addressStructure.next;
addressList.InternalAdd(pIpAdapterAddress->address.MarshalIPAddress());
pIpAdapterAddress = pIpAdapterAddress->next;
}

return addressList;
Expand All @@ -103,25 +101,24 @@ internal static IPAddressInformationCollection MarshalIpAddressInformationCollec
{
IPAddressInformationCollection addressList = new IPAddressInformationCollection();

while (ptr != IntPtr.Zero)
IpAdapterAddress* pIpAdapterAddress = (IpAdapterAddress*)ptr;
while (pIpAdapterAddress != null)
{
IpAdapterAddress addressStructure = Marshal.PtrToStructure<IpAdapterAddress>(ptr);
IPAddress address = addressStructure.address.MarshalIPAddress();
addressList.InternalAdd(new SystemIPAddressInformation(address, addressStructure.flags));

ptr = addressStructure.next;
addressList.InternalAdd(new SystemIPAddressInformation(
pIpAdapterAddress->address.MarshalIPAddress(), pIpAdapterAddress->flags));
pIpAdapterAddress = pIpAdapterAddress->next;
}

return addressList;
}
}

[StructLayout(LayoutKind.Sequential)]
internal struct IpAdapterUnicastAddress
internal unsafe struct IpAdapterUnicastAddress
{
internal uint length;
internal AdapterAddressFlags flags;
internal IntPtr next;
internal IpAdapterUnicastAddress* next;
internal IpSocketAddress address;
internal PrefixOrigin prefixOrigin;
internal SuffixOrigin suffixOrigin;
Expand All @@ -132,37 +129,59 @@ internal struct IpAdapterUnicastAddress
internal byte prefixLength;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
internal struct IpAdapterAddresses
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IpAdapterAddresses
{
internal const int MAX_ADAPTER_ADDRESS_LENGTH = 8;

internal uint length;
internal uint index;
internal IntPtr next;
internal IpAdapterAddresses* next;

// Needs to be ANSI.
[MarshalAs(UnmanagedType.LPStr)]
internal string AdapterName;
private IntPtr _adapterName; // ANSI string
internal string AdapterName => Marshal.PtrToStringAnsi(_adapterName)!;

internal IntPtr firstUnicastAddress;
internal IntPtr firstAnycastAddress;
internal IntPtr firstMulticastAddress;
internal IntPtr firstDnsServerAddress;

internal string dnsSuffix;
internal string description;
internal string friendlyName;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = MAX_ADAPTER_ADDRESS_LENGTH)]
internal byte[] address;
internal uint addressLength;
private IntPtr _dnsSuffix;
internal string DnsSuffix => Marshal.PtrToStringUni(_dnsSuffix)!;

private IntPtr _description;
internal string Description => Marshal.PtrToStringUni(_description)!;

private IntPtr _friendlyName;
internal string FriendlyName => Marshal.PtrToStringUni(_friendlyName)!;

private fixed byte _address[MAX_ADAPTER_ADDRESS_LENGTH];
private uint _addressLength;
internal byte[] Address
{
get
{
fixed (byte* pAddress = _address)
return new ReadOnlySpan<byte>(pAddress, (int)_addressLength).ToArray();
}
}

internal AdapterFlags flags;
internal uint mtu;
internal NetworkInterfaceType type;
internal OperationalStatus operStatus;
internal uint ipv6Index;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 16)]
internal uint[] zoneIndices;

private fixed uint _zoneIndices[16];
internal uint[] ZoneIndices
{
get
{
fixed (uint* pZoneIndices = _zoneIndices)
return new ReadOnlySpan<uint>(pZoneIndices, 16).ToArray();
}
}

internal IntPtr firstPrefix;

internal ulong transmitLinkSpeed;
Expand All @@ -174,13 +193,11 @@ internal struct IpAdapterAddresses
internal ulong luid;
internal IpSocketAddress dhcpv4Server;
internal uint compartmentId;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 16)]
internal byte[] networkGuid;
internal fixed byte networkGuid[16];
internal InterfaceConnectionType connectionType;
internal InterfaceTunnelType tunnelType;
internal IpSocketAddress dhcpv6Server; // Never available in Windows.
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 130)]
internal byte[] dhcpv6ClientDuid;
internal fixed byte dhcpv6ClientDuid[130];
internal uint dhcpv6ClientDuidLength;
internal uint dhcpV6Iaid;

Expand Down Expand Up @@ -211,11 +228,11 @@ internal enum InterfaceTunnelType : int
/// <summary>
/// IP_PER_ADAPTER_INFO - per-adapter IP information such as DNS server list.
/// </summary>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)]
[StructLayout(LayoutKind.Sequential)]
internal struct IpPerAdapterInfo
{
internal bool autoconfigEnabled;
internal bool autoconfigActive;
internal uint autoconfigEnabled;
internal uint autoconfigActive;
internal IntPtr currentDnsServer; /* IpAddressList* */
internal IpAddrString dnsServerList;
};
Expand All @@ -224,14 +241,12 @@ internal struct IpPerAdapterInfo
/// Store an IP address with its corresponding subnet mask,
/// both as dotted decimal strings.
/// </summary>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)]
internal struct IpAddrString
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IpAddrString
{
internal IntPtr Next; /* struct _IpAddressList* */
[MarshalAs(UnmanagedType.ByValTStr, SizeConst = 16)]
internal string IpAddress;
[MarshalAs(UnmanagedType.ByValTStr, SizeConst = 16)]
internal string IpMask;
internal IpAddrString* Next; /* struct _IpAddressList* */
internal fixed byte IpAddress[16];
internal fixed byte IpMask[16];
internal uint Context;
}

Expand Down
Expand Up @@ -68,11 +68,11 @@ private static string ConvertTo8DigitHex(uint value)
return value.ToString("X8");
}

private static Interop.Version.VS_FIXEDFILEINFO GetFixedFileInfo(IntPtr memPtr)
private static unsafe Interop.Version.VS_FIXEDFILEINFO GetFixedFileInfo(IntPtr memPtr)
{
if (Interop.Version.VerQueryValue(memPtr, "\\", out IntPtr memRef, out _))
{
return (Interop.Version.VS_FIXEDFILEINFO)Marshal.PtrToStructure<Interop.Version.VS_FIXEDFILEINFO>(memRef);
return *(Interop.Version.VS_FIXEDFILEINFO*)memRef;
}

return default;
Expand Down
Expand Up @@ -29,7 +29,7 @@ internal sealed class SystemIPInterfaceProperties : IPInterfaceProperties
internal SystemIPInterfaceProperties(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Interop.IpHlpApi.IpAdapterAddresses ipAdapterAddresses)
{
_adapterFlags = ipAdapterAddresses.flags;
_dnsSuffix = ipAdapterAddresses.dnsSuffix;
_dnsSuffix = ipAdapterAddresses.DnsSuffix;
_dnsEnabled = fixedInfo.enableDns;
_dynamicDnsEnabled = ((ipAdapterAddresses.flags & Interop.IpHlpApi.AdapterFlags.DnsEnabled) > 0);

Expand Down Expand Up @@ -64,7 +64,7 @@ internal SystemIPInterfaceProperties(in Interop.IpHlpApi.FIXED_INFO fixedInfo, i
if ((_adapterFlags & Interop.IpHlpApi.AdapterFlags.IPv6Enabled) != 0)
{
_ipv6Properties = new SystemIPv6InterfaceProperties(ipAdapterAddresses.ipv6Index,
ipAdapterAddresses.mtu, ipAdapterAddresses.zoneIndices);
ipAdapterAddresses.mtu, ipAdapterAddresses.ZoneIndices);
}
}

Expand Down
Expand Up @@ -91,11 +91,10 @@ private unsafe void GetPerAdapterInfo(uint index)
result = Interop.IpHlpApi.GetPerAdapterInfo(index, buffer, &size);
if (result == Interop.IpHlpApi.ERROR_SUCCESS)
{
Interop.IpHlpApi.IpPerAdapterInfo ipPerAdapterInfo =
Marshal.PtrToStructure<Interop.IpHlpApi.IpPerAdapterInfo>(buffer);
Interop.IpHlpApi.IpPerAdapterInfo* ipPerAdapterInfo = (Interop.IpHlpApi.IpPerAdapterInfo*)buffer;

_autoConfigEnabled = ipPerAdapterInfo.autoconfigEnabled;
_autoConfigActive = ipPerAdapterInfo.autoconfigActive;
_autoConfigEnabled = ipPerAdapterInfo->autoconfigEnabled != 0;
_autoConfigActive = ipPerAdapterInfo->autoconfigActive != 0;
}
}
finally
Expand Down
Expand Up @@ -15,7 +15,6 @@ internal sealed class SystemNetworkInterface : NetworkInterface
private readonly string _id;
private readonly string _description;
private readonly byte[] _physicalAddress;
private readonly uint _addressLength;
private readonly NetworkInterfaceType _type;
private readonly OperationalStatus _operStatus;
private readonly long _speed;
Expand Down Expand Up @@ -110,14 +109,12 @@ internal static unsafe NetworkInterface[] GetNetworkInterfaces()
if (result == Interop.IpHlpApi.ERROR_SUCCESS)
{
// Linked list of interfaces.
IntPtr ptr = buffer;
while (ptr != IntPtr.Zero)
Interop.IpHlpApi.IpAdapterAddresses* adapterAddresses = (Interop.IpHlpApi.IpAdapterAddresses*)buffer;
while (adapterAddresses != null)
{
// Traverse the list, marshal in the native structures, and create new NetworkInterfaces.
Interop.IpHlpApi.IpAdapterAddresses adapterAddresses = Marshal.PtrToStructure<Interop.IpHlpApi.IpAdapterAddresses>(ptr);
interfaceList.Add(new SystemNetworkInterface(in fixedInfo, in adapterAddresses));

ptr = adapterAddresses.next;
interfaceList.Add(new SystemNetworkInterface(in fixedInfo, in *adapterAddresses));
adapterAddresses = adapterAddresses->next;
}
}
}
Expand Down Expand Up @@ -146,12 +143,11 @@ internal SystemNetworkInterface(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Int
{
// Store the common API information.
_id = ipAdapterAddresses.AdapterName;
_name = ipAdapterAddresses.friendlyName;
_description = ipAdapterAddresses.description;
_name = ipAdapterAddresses.FriendlyName;
_description = ipAdapterAddresses.Description;
_index = ipAdapterAddresses.index;

_physicalAddress = ipAdapterAddresses.address;
_addressLength = ipAdapterAddresses.addressLength;
_physicalAddress = ipAdapterAddresses.Address;

_type = ipAdapterAddresses.type;
_operStatus = ipAdapterAddresses.operStatus;
Expand All @@ -172,12 +168,7 @@ internal SystemNetworkInterface(in Interop.IpHlpApi.FIXED_INFO fixedInfo, in Int

public override PhysicalAddress GetPhysicalAddress()
{
byte[] newAddr = new byte[_addressLength];

// Buffer.BlockCopy only supports int while addressLength is uint (see IpAdapterAddresses).
// Will throw OverflowException if addressLength > Int32.MaxValue.
Buffer.BlockCopy(_physicalAddress, 0, newAddr, 0, checked((int)_addressLength));
return new PhysicalAddress(newAddr);
return new PhysicalAddress(_physicalAddress);
}

public override NetworkInterfaceType NetworkInterfaceType { get { return _type; } }
Expand Down

0 comments on commit 57817bd

Please sign in to comment.