Skip to content

Commit

Permalink
Implement MsQuicConfiguration cache (#99371)
Browse files Browse the repository at this point in the history
* Implement MsQuicConfiguration cache

* Fix creds with custom cipher suites

* Polishing

* Dispose discarded handle when racing to add into cache

* Shuffle code around, add AppCtx switch for disabling

* Code review feedback

* Add comments, minor fixes

* Fix failing test on Windows

* Code review feedback

* Apply suggestions from code review

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>

* Code review changes

---------

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>
  • Loading branch information
rzikm and MihaZupan committed Mar 21, 2024
1 parent 8288d6a commit e12e2fa
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Collections.ObjectModel;
using System.Security.Authentication;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using Microsoft.Quic;

namespace System.Net.Quic;

internal static partial class MsQuicConfiguration
{
private const int CheckExpiredModulo = 32;

private const string DisableCacheEnvironmentVariable = "DOTNET_SYSTEM_NET_QUIC_DISABLE_CONFIGURATION_CACHE";
private const string DisableCacheCtxSwitch = "System.Net.Quic.DisableConfigurationCache";

internal static bool ConfigurationCacheEnabled { get; } = GetConfigurationCacheEnabled();
private static bool GetConfigurationCacheEnabled()
{
// AppContext switch takes precedence
if (AppContext.TryGetSwitch(DisableCacheCtxSwitch, out bool value))
{
return !value;
}
else
{
// check environment variable
return
Environment.GetEnvironmentVariable(DisableCacheEnvironmentVariable) is string envVar &&
!(envVar == "1" || envVar.Equals("true", StringComparison.OrdinalIgnoreCase));
}
}

private static readonly ConcurrentDictionary<CacheKey, MsQuicConfigurationSafeHandle> s_configurationCache = new();

private readonly struct CacheKey : IEquatable<CacheKey>
{
public readonly List<byte[]> CertificateThumbprints;
public readonly QUIC_CREDENTIAL_FLAGS Flags;
public readonly QUIC_SETTINGS Settings;
public readonly List<SslApplicationProtocol> ApplicationProtocols;
public readonly QUIC_ALLOWED_CIPHER_SUITE_FLAGS AllowedCipherSuites;

public CacheKey(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol> alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites)
{
CertificateThumbprints = certificate == null ? new List<byte[]>() : new List<byte[]> { certificate.GetCertHash() };

if (intermediates != null)
{
foreach (X509Certificate2 intermediate in intermediates)
{
CertificateThumbprints.Add(intermediate.GetCertHash());
}
}

Flags = flags;
Settings = settings;
// make defensive copy to prevent modification (the list comes from user code)
ApplicationProtocols = new List<SslApplicationProtocol>(alpnProtocols);
AllowedCipherSuites = allowedCipherSuites;
}

public override bool Equals(object? obj) => obj is CacheKey key && Equals(key);

public bool Equals(CacheKey other)
{
if (CertificateThumbprints.Count != other.CertificateThumbprints.Count)
{
return false;
}

for (int i = 0; i < CertificateThumbprints.Count; i++)
{
if (!CertificateThumbprints[i].AsSpan().SequenceEqual(other.CertificateThumbprints[i]))
{
return false;
}
}

if (ApplicationProtocols.Count != other.ApplicationProtocols.Count)
{
return false;
}

for (int i = 0; i < ApplicationProtocols.Count; i++)
{
if (ApplicationProtocols[i] != other.ApplicationProtocols[i])
{
return false;
}
}

return
Flags == other.Flags &&
Settings.Equals(other.Settings) &&
AllowedCipherSuites == other.AllowedCipherSuites;
}

public override int GetHashCode()
{
HashCode hash = default;

foreach (var thumbprint in CertificateThumbprints)
{
hash.AddBytes(thumbprint);
}

hash.Add(Flags);
hash.Add(Settings);

foreach (var protocol in ApplicationProtocols)
{
hash.AddBytes(protocol.Protocol.Span);
}

hash.Add(AllowedCipherSuites);

return hash.ToHashCode();
}
}

private static MsQuicConfigurationSafeHandle GetCachedCredentialOrCreate(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol> alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites)
{
CacheKey key = new CacheKey(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites);

MsQuicConfigurationSafeHandle? handle;

if (s_configurationCache.TryGetValue(key, out handle) && handle.TryAddRentCount())
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Found cached MsQuicConfiguration: {handle}.");
}
return handle;
}

// if we get here, the handle is either not in the cache, or we lost the race between
// TryAddRentCount on this thread and MarkForDispose on another thread doing cache cleanup.
// In either case, we need to create a new handle.

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"MsQuicConfiguration not found in cache, creating new.");
}

handle = CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites);
handle.TryAddRentCount(); // we are the first renter

MsQuicConfigurationSafeHandle cached;
do
{
cached = s_configurationCache.GetOrAdd(key, handle);
}
// If we get the same handle back, we successfully added it to the cache and we are done.
// If we get a different handle back, we need to increase the rent count.
// If we fail to add the rent count, then the existing/cached handle is in process of
// being removed from the cache and we can try again, eventually either succeeding to add our
// new handle or getting a fresh handle inserted by another thread meanwhile.
while (cached != handle && !cached.TryAddRentCount());

if (cached != handle)
{
// we lost a race with another thread to insert new handle into the cache
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Discarding MsQuicConfiguration {handle} (preferring cached {cached}).");
}

// First dispose decrements the rent count we added before attempting the cache insertion
// and second closes the handle
handle.Dispose();
handle.Dispose();
Debug.Assert(handle.IsClosed);

return cached;
}

// we added a new handle, check if we need to cleanup
var count = s_configurationCache.Count;
if (count % CheckExpiredModulo == 0)
{
// let only one thread perform cleanup at a time
lock (s_configurationCache)
{
// check again, if another thread just cleaned up (and cached count went down) we are unlikely
// to clean anything
if (s_configurationCache.Count >= count)
{
CleanupCache();
}
}
}

return handle;
}

private static void CleanupCache()
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, current size: {s_configurationCache.Count}.");
}

foreach ((CacheKey key, MsQuicConfigurationSafeHandle handle) in s_configurationCache)
{
if (!handle.TryMarkForDispose())
{
// handle in use
continue;
}

// the handle is not in use and has been marked such that no new rents can be added.
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Removing cached MsQuicConfiguration {handle}.");
}

bool removed = s_configurationCache.TryRemove(key, out _);
Debug.Assert(removed);
handle.Dispose();
Debug.Assert(handle.IsClosed);
}

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, new size: {s_configurationCache.Count}.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@

namespace System.Net.Quic;

internal static class MsQuicConfiguration
internal static partial class MsQuicConfiguration
{
private static bool HasPrivateKey(this X509Certificate certificate)
=> certificate is X509Certificate2 certificate2 && certificate2.Handle != IntPtr.Zero && certificate2.HasPrivateKey;

public static MsQuicSafeHandle Create(QuicClientConnectionOptions options)
public static MsQuicConfigurationSafeHandle Create(QuicClientConnectionOptions options)
{
SslClientAuthenticationOptions authenticationOptions = options.ClientAuthenticationOptions;

Expand Down Expand Up @@ -79,7 +79,7 @@ public static MsQuicSafeHandle Create(QuicClientConnectionOptions options)
return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy);
}

public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, string? targetHost)
public static MsQuicConfigurationSafeHandle Create(QuicServerConnectionOptions options, string? targetHost)
{
SslServerAuthenticationOptions authenticationOptions = options.ServerAuthenticationOptions;

Expand Down Expand Up @@ -117,7 +117,7 @@ public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, strin
return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy);
}

private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol>? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy, EncryptionPolicy encryptionPolicy)
private static MsQuicConfigurationSafeHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol>? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy, EncryptionPolicy encryptionPolicy)
{
// Validate options and SSL parameters.
if (alpnProtocols is null || alpnProtocols.Count <= 0)
Expand Down Expand Up @@ -176,31 +176,51 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI
: 0; // 0 disables the timeout
}

QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites = QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE;

if (cipherSuitesPolicy != null)
{
flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES;
allowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy);
}

if (!MsQuicApi.UsesSChannelBackend)
{
flags |= QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES;
}

if (ConfigurationCacheEnabled)
{
return GetCachedCredentialOrCreate(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites);
}

return CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites);
}

private static unsafe MsQuicConfigurationSafeHandle CreateInternal(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol> alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites)
{
QUIC_HANDLE* handle;

using MsQuicBuffers msquicBuffers = new MsQuicBuffers();
msquicBuffers.Initialize(alpnProtocols, alpnProtocol => alpnProtocol.Protocol);
ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ConfigurationOpen(
MsQuicApi.Api.Registration,
msquicBuffers.Buffers,
(uint)alpnProtocols.Count,
(uint)msquicBuffers.Count,
&settings,
(uint)sizeof(QUIC_SETTINGS),
(void*)IntPtr.Zero,
&handle),
"ConfigurationOpen failed");
MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, SafeHandleType.Configuration);
MsQuicConfigurationSafeHandle configurationHandle = new MsQuicConfigurationSafeHandle(handle);

try
{
QUIC_CREDENTIAL_CONFIG config = new QUIC_CREDENTIAL_CONFIG { Flags = flags };
config.Flags |= (MsQuicApi.UsesSChannelBackend ? QUIC_CREDENTIAL_FLAGS.NONE : QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES);

if (cipherSuitesPolicy != null)
QUIC_CREDENTIAL_CONFIG config = new QUIC_CREDENTIAL_CONFIG
{
config.Flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES;
config.AllowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy);
}
Flags = flags,
AllowedCipherSuites = allowedCipherSuites
};

int status;
if (certificate is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using Microsoft.Quic;

namespace System.Net.Quic;
Expand Down Expand Up @@ -52,7 +53,8 @@ safeHandleType switch
SafeHandleType.Stream => MsQuicApi.Api.ApiTable->StreamClose,
_ => throw new ArgumentException($"Unexpected value: {safeHandleType}", nameof(safeHandleType))
},
safeHandleType) { }
safeHandleType)
{ }

protected override bool ReleaseHandle()
{
Expand Down Expand Up @@ -142,3 +144,46 @@ protected override unsafe bool ReleaseHandle()
return true;
}
}

internal sealed class MsQuicConfigurationSafeHandle : MsQuicSafeHandle
{
// MsQuicConfiguration handles are cached, so we need to keep track of the
// number of times a handle is rented. Once we decide to dispose the handle,
// we set the _rentCount to -1.
private volatile int _rentCount;

public unsafe MsQuicConfigurationSafeHandle(QUIC_HANDLE* handle)
: base(handle, SafeHandleType.Configuration) { }

public bool TryAddRentCount()
{
int oldCount;

do
{
oldCount = _rentCount;
if (oldCount < 0)
{
// The handle is already disposed.
return false;
}
} while (Interlocked.CompareExchange(ref _rentCount, oldCount + 1, oldCount) != oldCount);

return true;
}

public bool TryMarkForDispose()
{
return Interlocked.CompareExchange(ref _rentCount, -1, 0) == 0;
}

protected override void Dispose(bool disposing)
{
if (Interlocked.Decrement(ref _rentCount) < 0)
{
// _rentCount is 0 if the handle was never rented (e.g. failure during creation),
// and is -1 when evicted from cache.
base.Dispose(disposing);
}
}
}

0 comments on commit e12e2fa

Please sign in to comment.