From 3e5517beb897faf4592d23f036446561da1e5c23 Mon Sep 17 00:00:00 2001 From: Tomas Weinfurt Date: Sun, 8 May 2022 23:32:22 -0700 Subject: [PATCH] refactor SslStream internals (#68678) * refactor SslStream internals * fix validation and certs * update fakes * feedback from review --- .../src/System.Net.Security.csproj | 4 +- .../Net/Security/NetEventSource.Security.cs | 144 +++++------ .../Net/Security/SslAuthenticationOptions.cs | 130 +++++----- .../Net/Security/SslConnectionInfo.Android.cs | 5 +- .../Net/Security/SslConnectionInfo.Linux.cs | 5 +- .../Net/Security/SslConnectionInfo.OSX.cs | 5 +- .../Net/Security/SslConnectionInfo.Unix.cs | 2 +- .../Net/Security/SslConnectionInfo.Windows.cs | 41 ++- .../System/Net/Security/SslConnectionInfo.cs | 6 +- ...ream.Implementation.cs => SslStream.IO.cs} | 96 ++----- ...SecureChannel.cs => SslStream.Protocol.cs} | 113 ++------- .../src/System/Net/Security/SslStream.cs | 240 ++++-------------- .../Net/Security/SslStreamPal.Android.cs | 18 +- .../System/Net/Security/SslStreamPal.OSX.cs | 30 +-- .../System/Net/Security/SslStreamPal.Unix.cs | 32 +-- .../Net/Security/SslStreamPal.Windows.cs | 45 +--- .../Fakes/FakeSslStream.Implementation.cs | 55 ++-- 17 files changed, 347 insertions(+), 624 deletions(-) rename src/libraries/System.Net.Security/src/System/Net/Security/{SslStream.Implementation.cs => SslStream.IO.cs} (92%) rename src/libraries/System.Net.Security/src/System/Net/Security/{SecureChannel.cs => SslStream.Protocol.cs} (93%) diff --git a/src/libraries/System.Net.Security/src/System.Net.Security.csproj b/src/libraries/System.Net.Security/src/System.Net.Security.csproj index 8d40a7e3c467f..d3fd152787f15 100644 --- a/src/libraries/System.Net.Security/src/System.Net.Security.csproj +++ b/src/libraries/System.Net.Security/src/System.Net.Security.csproj @@ -32,10 +32,10 @@ - - + + diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/NetEventSource.Security.cs b/src/libraries/System.Net.Security/src/System/Net/Security/NetEventSource.Security.cs index d1d7a0af41851..684c481ca798b 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/NetEventSource.Security.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/NetEventSource.Security.cs @@ -14,8 +14,7 @@ namespace System.Net [EventSource(Name = "Private.InternalDiagnostics.System.Net.Security", LocalizationResources = "FxResources.System.Net.Security.SR")] internal sealed partial class NetEventSource { - private const int SecureChannelCtorId = NextAvailableEventId; - private const int LocatingPrivateKeyId = SecureChannelCtorId + 1; + private const int LocatingPrivateKeyId = NextAvailableEventId + 1; private const int CertIsType2Id = LocatingPrivateKeyId + 1; private const int FoundCertInStoreId = CertIsType2Id + 1; private const int NotFoundCertInStoreId = FoundCertInStoreId + 1; @@ -90,19 +89,6 @@ public void SslStreamCtor(SslStream sslStream, Stream innerStream) private void SslStreamCtor(string thisOrContextObject, string? localId, string? remoteId) => WriteEvent(SslStreamCtorId, thisOrContextObject, localId, remoteId); - [NonEvent] - public void SecureChannelCtor(SecureChannel secureChannel, SslStream sslStream, string hostname, X509CertificateCollection? clientCertificates, EncryptionPolicy encryptionPolicy) - { - if (IsEnabled()) - { - SecureChannelCtor(IdOf(secureChannel), hostname, GetHashCode(secureChannel), clientCertificates?.Count ?? 0, encryptionPolicy); - } - } - - [Event(SecureChannelCtorId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void SecureChannelCtor(string sslStream, string hostname, int secureChannelHash, int clientCertificatesCount, EncryptionPolicy encryptionPolicy) => - WriteEvent(SecureChannelCtorId, sslStream, hostname, secureChannelHash, clientCertificatesCount, (int)encryptionPolicy); - [NonEvent] public void LocatingPrivateKey(X509Certificate x509Certificate, object instance) { @@ -112,8 +98,8 @@ public void LocatingPrivateKey(X509Certificate x509Certificate, object instance) } } [Event(LocatingPrivateKeyId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void LocatingPrivateKey(string x509Certificate, int secureChannelHash) => - WriteEvent(LocatingPrivateKeyId, x509Certificate, secureChannelHash); + private void LocatingPrivateKey(string x509Certificate, int sslStreamHash) => + WriteEvent(LocatingPrivateKeyId, x509Certificate, sslStreamHash); [NonEvent] public void CertIsType2(object instance) @@ -124,8 +110,8 @@ public void CertIsType2(object instance) } } [Event(CertIsType2Id, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void CertIsType2(int secureChannelHash) => - WriteEvent(CertIsType2Id, secureChannelHash); + private void CertIsType2(int sslStreamHash) => + WriteEvent(CertIsType2Id, sslStreamHash); [NonEvent] public void FoundCertInStore(bool serverMode, object instance) @@ -136,8 +122,8 @@ public void FoundCertInStore(bool serverMode, object instance) } } [Event(FoundCertInStoreId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void FoundCertInStore(string store, int secureChannelHash) => - WriteEvent(FoundCertInStoreId, store, secureChannelHash); + private void FoundCertInStore(string store, int sslStreamHash) => + WriteEvent(FoundCertInStoreId, store, sslStreamHash); [NonEvent] public void NotFoundCertInStore(object instance) @@ -148,8 +134,8 @@ public void NotFoundCertInStore(object instance) } } [Event(NotFoundCertInStoreId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void NotFoundCertInStore(int secureChannelHash) => - WriteEvent(NotFoundCertInStoreId, secureChannelHash); + private void NotFoundCertInStore(int sslStreamHash) => + WriteEvent(NotFoundCertInStoreId, sslStreamHash); [NonEvent] public void RemoteCertificate(X509Certificate? remoteCertificate) @@ -164,124 +150,124 @@ public void RemoteCertificate(X509Certificate? remoteCertificate) WriteEvent(RemoteCertificateId, remoteCertificate); [NonEvent] - public void CertificateFromDelegate(SecureChannel secureChannel) + public void CertificateFromDelegate(SslStream SslStream) { if (IsEnabled()) { - CertificateFromDelegate(GetHashCode(secureChannel)); + CertificateFromDelegate(GetHashCode(SslStream)); } } [Event(CertificateFromDelegateId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void CertificateFromDelegate(int secureChannelHash) => - WriteEvent(CertificateFromDelegateId, secureChannelHash); + private void CertificateFromDelegate(int sslStreamHash) => + WriteEvent(CertificateFromDelegateId, sslStreamHash); [NonEvent] - public void NoDelegateNoClientCert(SecureChannel secureChannel) + public void NoDelegateNoClientCert(SslStream SslStream) { if (IsEnabled()) { - NoDelegateNoClientCert(GetHashCode(secureChannel)); + NoDelegateNoClientCert(GetHashCode(SslStream)); } } [Event(NoDelegateNoClientCertId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void NoDelegateNoClientCert(int secureChannelHash) => - WriteEvent(NoDelegateNoClientCertId, secureChannelHash); + private void NoDelegateNoClientCert(int sslStreamHash) => + WriteEvent(NoDelegateNoClientCertId, sslStreamHash); [NonEvent] - public void NoDelegateButClientCert(SecureChannel secureChannel) + public void NoDelegateButClientCert(SslStream SslStream) { if (IsEnabled()) { - NoDelegateButClientCert(GetHashCode(secureChannel)); + NoDelegateButClientCert(GetHashCode(SslStream)); } } [Event(NoDelegateButClientCertId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void NoDelegateButClientCert(int secureChannelHash) => - WriteEvent(NoDelegateButClientCertId, secureChannelHash); + private void NoDelegateButClientCert(int sslStreamHash) => + WriteEvent(NoDelegateButClientCertId, sslStreamHash); [NonEvent] - public void AttemptingRestartUsingCert(X509Certificate? clientCertificate, SecureChannel secureChannel) + public void AttemptingRestartUsingCert(X509Certificate? clientCertificate, SslStream SslStream) { if (IsEnabled()) { - AttemptingRestartUsingCert(clientCertificate?.ToString(true), GetHashCode(secureChannel)); + AttemptingRestartUsingCert(clientCertificate?.ToString(true), GetHashCode(SslStream)); } } [Event(AttemptingRestartUsingCertId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void AttemptingRestartUsingCert(string? clientCertificate, int secureChannelHash) => - WriteEvent(AttemptingRestartUsingCertId, clientCertificate, secureChannelHash); + private void AttemptingRestartUsingCert(string? clientCertificate, int sslStreamHash) => + WriteEvent(AttemptingRestartUsingCertId, clientCertificate, sslStreamHash); [NonEvent] - public void NoIssuersTryAllCerts(SecureChannel secureChannel) + public void NoIssuersTryAllCerts(SslStream SslStream) { if (IsEnabled()) { - NoIssuersTryAllCerts(GetHashCode(secureChannel)); + NoIssuersTryAllCerts(GetHashCode(SslStream)); } } [Event(NoIssuersTryAllCertsId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void NoIssuersTryAllCerts(int secureChannelHash) => - WriteEvent(NoIssuersTryAllCertsId, secureChannelHash); + private void NoIssuersTryAllCerts(int sslStreamHash) => + WriteEvent(NoIssuersTryAllCertsId, sslStreamHash); [NonEvent] - public void LookForMatchingCerts(int issuersCount, SecureChannel secureChannel) + public void LookForMatchingCerts(int issuersCount, SslStream SslStream) { if (IsEnabled()) { - LookForMatchingCerts(issuersCount, GetHashCode(secureChannel)); + LookForMatchingCerts(issuersCount, GetHashCode(SslStream)); } } [Event(LookForMatchingCertsId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void LookForMatchingCerts(int issuersCount, int secureChannelHash) => - WriteEvent(LookForMatchingCertsId, issuersCount, secureChannelHash); + private void LookForMatchingCerts(int issuersCount, int sslStreamHash) => + WriteEvent(LookForMatchingCertsId, issuersCount, sslStreamHash); [NonEvent] - public void SelectedCert(X509Certificate clientCertificate, SecureChannel secureChannel) + public void SelectedCert(X509Certificate clientCertificate, SslStream SslStream) { if (IsEnabled()) { - SelectedCert(clientCertificate?.ToString(true), GetHashCode(secureChannel)); + SelectedCert(clientCertificate?.ToString(true), GetHashCode(SslStream)); } } [Event(SelectedCertId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void SelectedCert(string? clientCertificate, int secureChannelHash) => - WriteEvent(SelectedCertId, clientCertificate, secureChannelHash); + private void SelectedCert(string? clientCertificate, int sslStreamHash) => + WriteEvent(SelectedCertId, clientCertificate, sslStreamHash); [NonEvent] - public void CertsAfterFiltering(int filteredCertsCount, SecureChannel secureChannel) + public void CertsAfterFiltering(int filteredCertsCount, SslStream SslStream) { if (IsEnabled()) { - CertsAfterFiltering(filteredCertsCount, GetHashCode(secureChannel)); + CertsAfterFiltering(filteredCertsCount, GetHashCode(SslStream)); } } [Event(CertsAfterFilteringId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void CertsAfterFiltering(int filteredCertsCount, int secureChannelHash) => - WriteEvent(CertsAfterFilteringId, filteredCertsCount, secureChannelHash); + private void CertsAfterFiltering(int filteredCertsCount, int sslStreamHash) => + WriteEvent(CertsAfterFilteringId, filteredCertsCount, sslStreamHash); [NonEvent] - public void FindingMatchingCerts(SecureChannel secureChannel) + public void FindingMatchingCerts(SslStream SslStream) { if (IsEnabled()) { - FindingMatchingCerts(GetHashCode(secureChannel)); + FindingMatchingCerts(GetHashCode(SslStream)); } } [Event(FindingMatchingCertsId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void FindingMatchingCerts(int secureChannelHash) => - WriteEvent(FindingMatchingCertsId, secureChannelHash); + private void FindingMatchingCerts(int sslStreamHash) => + WriteEvent(FindingMatchingCertsId, sslStreamHash); [NonEvent] - public void UsingCachedCredential(SecureChannel secureChannel) + public void UsingCachedCredential(SslStream SslStream) { if (IsEnabled()) { - UsingCachedCredential(GetHashCode(secureChannel)); + UsingCachedCredential(GetHashCode(SslStream)); } } [Event(UsingCachedCredentialId, Keywords = Keywords.Default, Level = EventLevel.Informational)] - private void UsingCachedCredential(int secureChannelHash) => - WriteEvent(UsingCachedCredentialId, secureChannelHash); + private void UsingCachedCredential(int sslStreamHash) => + WriteEvent(UsingCachedCredentialId, sslStreamHash); [Event(SspiSelectedCipherSuitId, Keywords = Keywords.Default, Level = EventLevel.Informational)] public void SspiSelectedCipherSuite( @@ -303,52 +289,52 @@ public void UsingCachedCredential(SecureChannel secureChannel) } [NonEvent] - public void RemoteCertificateError(SecureChannel secureChannel, string message) + public void RemoteCertificateError(SslStream SslStream, string message) { if (IsEnabled()) { - RemoteCertificateError(GetHashCode(secureChannel), message); + RemoteCertificateError(GetHashCode(SslStream), message); } } [Event(RemoteCertificateErrorId, Keywords = Keywords.Default, Level = EventLevel.Verbose)] - private void RemoteCertificateError(int secureChannelHash, string message) => - WriteEvent(RemoteCertificateErrorId, secureChannelHash, message); + private void RemoteCertificateError(int sslStreamHash, string message) => + WriteEvent(RemoteCertificateErrorId, sslStreamHash, message); [NonEvent] - public void RemoteCertDeclaredValid(SecureChannel secureChannel) + public void RemoteCertDeclaredValid(SslStream SslStream) { if (IsEnabled()) { - RemoteCertDeclaredValid(GetHashCode(secureChannel)); + RemoteCertDeclaredValid(GetHashCode(SslStream)); } } [Event(RemoteVertificateValidId, Keywords = Keywords.Default, Level = EventLevel.Verbose)] - private void RemoteCertDeclaredValid(int secureChannelHash) => - WriteEvent(RemoteVertificateValidId, secureChannelHash); + private void RemoteCertDeclaredValid(int sslStreamHash) => + WriteEvent(RemoteVertificateValidId, sslStreamHash); [NonEvent] - public void RemoteCertHasNoErrors(SecureChannel secureChannel) + public void RemoteCertHasNoErrors(SslStream SslStream) { if (IsEnabled()) { - RemoteCertHasNoErrors(GetHashCode(secureChannel)); + RemoteCertHasNoErrors(GetHashCode(SslStream)); } } [Event(RemoteCertificateSuccesId, Keywords = Keywords.Default, Level = EventLevel.Verbose)] - private void RemoteCertHasNoErrors(int secureChannelHash) => - WriteEvent(RemoteCertificateSuccesId, secureChannelHash); + private void RemoteCertHasNoErrors(int sslStreamHash) => + WriteEvent(RemoteCertificateSuccesId, sslStreamHash); [NonEvent] - public void RemoteCertUserDeclaredInvalid(SecureChannel secureChannel) + public void RemoteCertUserDeclaredInvalid(SslStream SslStream) { if (IsEnabled()) { - RemoteCertUserDeclaredInvalid(GetHashCode(secureChannel)); + RemoteCertUserDeclaredInvalid(GetHashCode(SslStream)); } } [Event(RemoteCertificateInvalidId, Keywords = Keywords.Default, Level = EventLevel.Verbose)] - private void RemoteCertUserDeclaredInvalid(int secureChannelHash) => - WriteEvent(RemoteCertificateInvalidId, secureChannelHash); + private void RemoteCertUserDeclaredInvalid(int sslStreamHash) => + WriteEvent(RemoteCertificateInvalidId, sslStreamHash); [NonEvent] public void SentFrame(SslStream sslStream, ReadOnlySpan frame) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs index 896da65eb110e..af45303e999f6 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs @@ -10,14 +10,39 @@ namespace System.Net.Security { internal sealed class SslAuthenticationOptions { - internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback) + internal SslAuthenticationOptions() + { + TargetHost = string.Empty; + } + + internal void UpdateOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions) { Debug.Assert(sslClientAuthenticationOptions.TargetHost != null); + if (CertValidationDelegate == null) + { + CertValidationDelegate = sslClientAuthenticationOptions.RemoteCertificateValidationCallback; + } + else if (sslClientAuthenticationOptions.RemoteCertificateValidationCallback != null && + CertValidationDelegate != sslClientAuthenticationOptions.RemoteCertificateValidationCallback) + { + // Callback was set in constructor to differet value. + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback))); + } + + if (CertSelectionDelegate == null) + { + CertSelectionDelegate = sslClientAuthenticationOptions.LocalCertificateSelectionCallback; + } + else if (sslClientAuthenticationOptions.LocalCertificateSelectionCallback != null && + CertSelectionDelegate != sslClientAuthenticationOptions.LocalCertificateSelectionCallback) + { + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(LocalCertificateSelectionCallback))); + } + // Common options. AllowRenegotiation = sslClientAuthenticationOptions.AllowRenegotiation; ApplicationProtocols = sslClientAuthenticationOptions.ApplicationProtocols; - CertValidationDelegate = remoteCallback; CheckCertName = true; EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslClientAuthenticationOptions.EnabledSslProtocols); EncryptionPolicy = sslClientAuthenticationOptions.EncryptionPolicy; @@ -27,32 +52,57 @@ internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthen TargetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.'); // Client specific options. - CertSelectionDelegate = localCallback; CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode; ClientCertificates = sslClientAuthenticationOptions.ClientCertificates; CipherSuitesPolicy = sslClientAuthenticationOptions.CipherSuitesPolicy; } - internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) + internal void UpdateOptions(ServerOptionsSelectionCallback optionCallback, object? state) { - // Common options. - AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation; - ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols; CheckCertName = false; - EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols); - EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy; + TargetHost = string.Empty; IsServer = true; - RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired; - if (NetEventSource.Log.IsEnabled()) + UserState = state; + ServerOptionDelegate = optionCallback; + } + + internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + if (sslServerAuthenticationOptions.ServerCertificate == null && + sslServerAuthenticationOptions.ServerCertificateContext == null && + sslServerAuthenticationOptions.ServerCertificateSelectionCallback == null && + CertSelectionDelegate == null) { - NetEventSource.Info(this, $"Server RemoteCertRequired: {RemoteCertRequired}."); + throw new NotSupportedException(SR.net_ssl_io_no_server_cert); + } + + if ((sslServerAuthenticationOptions.ServerCertificate != null || + sslServerAuthenticationOptions.ServerCertificateContext != null || + CertSelectionDelegate != null) && + sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null) + { + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(ServerCertificateSelectionCallback))); + } + + if (CertValidationDelegate == null) + { + CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback; + } + else if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null && + CertValidationDelegate != sslServerAuthenticationOptions.RemoteCertificateValidationCallback) + { + // Callback was set in constructor to differet value. + throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback))); } - TargetHost = string.Empty; - // Server specific options. + IsServer = true; + AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation; + ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols; + EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols); + EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy; + RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired; CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy; CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode; - if (sslServerAuthenticationOptions.ServerCertificateContext != null) { CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext; @@ -70,7 +120,7 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen { // This is legacy fix-up. If the Certificate did not have key, we will search stores and we // will try to find one with matching hash. - certificateWithKey = SecureChannel.FindCertificateWithPrivateKey(this, true, sslServerAuthenticationOptions.ServerCertificate); + certificateWithKey = SslStream.FindCertificateWithPrivateKey(this, true, sslServerAuthenticationOptions.ServerCertificate); if (certificateWithKey == null) { throw new AuthenticationException(SR.net_ssl_io_no_server_cert); @@ -80,45 +130,9 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen } } - if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null) + if (sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null) { - CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback; - } - } - - internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state, RemoteCertificateValidationCallback? remoteCallback) - { - CheckCertName = false; - TargetHost = string.Empty; - IsServer = true; - UserState = state; - ServerOptionDelegate = optionCallback; - CertValidationDelegate = remoteCallback; - } - - internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) - { - AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation; - ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols; - EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols); - EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy; - RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired; - CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy; - CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode; - if (sslServerAuthenticationOptions.ServerCertificateContext != null) - { - CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext; - } - else if (sslServerAuthenticationOptions.ServerCertificate is X509Certificate2 certificateWithKey && - certificateWithKey.HasPrivateKey) - { - // given cert is X509Certificate2 with key. We can use it directly. - CertificateContext = SslStreamCertificateContext.Create(certificateWithKey); - } - - if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null) - { - CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback; + ServerCertSelectionDelegate = sslServerAuthenticationOptions.ServerCertificateSelectionCallback; } } @@ -150,10 +164,10 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto internal bool RemoteCertRequired { get; set; } internal bool CheckCertName { get; set; } internal RemoteCertificateValidationCallback? CertValidationDelegate { get; set; } - internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; } - internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; } + internal LocalCertificateSelectionCallback? CertSelectionDelegate { get; set; } + internal ServerCertificateSelectionCallback? ServerCertSelectionDelegate { get; set; } internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; } - internal object? UserState { get; } - internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; } + internal object? UserState { get; set; } + internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; set; } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs index d8effb9489533..5445860ff5eec 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Android.cs @@ -7,9 +7,9 @@ namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { - public SslConnectionInfo(SafeSslHandle sslContext) + public void UpdateSslConnectionInfo(SafeSslHandle sslContext) { string protocolString = Interop.AndroidCrypto.SSLStreamGetProtocol(sslContext); SslProtocols protocol = protocolString switch @@ -26,6 +26,7 @@ public SslConnectionInfo(SafeSslHandle sslContext) _ => SslProtocols.None, }; Protocol = (int)protocol; + ApplicationProtocol = Interop.AndroidCrypto.SSLStreamGetApplicationProtocol(sslContext); // Enum value names should match the cipher suite name, so we just parse the string cipherSuite = Interop.AndroidCrypto.SSLStreamGetCipherSuite(sslContext); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Linux.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Linux.cs index 6621d88d1092d..92adbb251ccf1 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Linux.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Linux.cs @@ -6,11 +6,12 @@ namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { - public SslConnectionInfo(SafeSslHandle sslContext) + public void UpdateSslConnectionInfo(SafeSslHandle sslContext) { Protocol = (int)MapProtocolVersion(Interop.Ssl.SslGetVersion(sslContext)); + ApplicationProtocol = Interop.Ssl.SslGetAlpnSelected(sslContext); MapCipherSuite(SslGetCurrentCipherSuite(sslContext)); } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.OSX.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.OSX.cs index 67707d981aa0f..ba883b7a16972 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.OSX.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.OSX.cs @@ -7,9 +7,9 @@ namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { - public SslConnectionInfo(SafeSslHandle sslContext) + public void UpdateSslConnectionInfo(SafeSslHandle sslContext) { SslProtocols protocol; TlsCipherSuite cipherSuite; @@ -26,6 +26,7 @@ public SslConnectionInfo(SafeSslHandle sslContext) Protocol = (int)protocol; TlsCipherSuite = cipherSuite; + ApplicationProtocol = Interop.AppleCrypto.SslGetAlpnSelected(sslContext); MapCipherSuite(cipherSuite); } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Unix.cs index b13ba361ff326..049557b5446ac 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Unix.cs @@ -7,7 +7,7 @@ namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { private void MapCipherSuite(TlsCipherSuite cipherSuite) { diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Windows.cs index 3741d4eaa6f12..82f2da4737c4b 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.Windows.cs @@ -1,12 +1,47 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; + namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { - public SslConnectionInfo(SecPkgContext_ConnectionInfo interopConnectionInfo, TlsCipherSuite cipherSuite) + private static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteContext context) + { + Interop.SecPkgContext_ApplicationProtocol alpnContext = default; + bool success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, context, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL, ref alpnContext); + + // Check if the context returned is alpn data, with successful negotiation. + if (success && + alpnContext.ProtoNegoExt == Interop.ApplicationProtocolNegotiationExt.ALPN && + alpnContext.ProtoNegoStatus == Interop.ApplicationProtocolNegotiationStatus.Success) + { + return alpnContext.Protocol; + } + + return null; + } + + public void UpdateSslConnectionInfo(SafeDeleteContext securityContext) { + SecPkgContext_ConnectionInfo interopConnectionInfo = default; + bool success = SSPIWrapper.QueryBlittableContextAttributes( + GlobalSSPI.SSPISecureChannel, + securityContext, + Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CONNECTION_INFO, + ref interopConnectionInfo); + Debug.Assert(success); + + TlsCipherSuite cipherSuite = default; + SecPkgContext_CipherInfo cipherInfo = default; + + success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CIPHER_INFO, ref cipherInfo); + if (success) + { + cipherSuite = (TlsCipherSuite)cipherInfo.dwCipherSuite; + } + Protocol = interopConnectionInfo.Protocol; DataCipherAlg = interopConnectionInfo.DataCipherAlg; DataKeySize = interopConnectionInfo.DataKeySize; @@ -16,6 +51,8 @@ public SslConnectionInfo(SecPkgContext_ConnectionInfo interopConnectionInfo, Tls KeyExchKeySize = interopConnectionInfo.KeyExchKeySize; TlsCipherSuite = cipherSuite; + + ApplicationProtocol = GetNegotiatedApplicationProtocol(securityContext); } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.cs index 52a0b00d43bb2..a217e6b77c18e 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslConnectionInfo.cs @@ -3,9 +3,9 @@ namespace System.Net.Security { - internal sealed partial class SslConnectionInfo + internal partial struct SslConnectionInfo { - public int Protocol { get; } + public int Protocol { get; private set; } public TlsCipherSuite TlsCipherSuite { get; private set; } public int DataCipherAlg { get; private set; } public int DataKeySize { get; private set; } @@ -13,5 +13,7 @@ internal sealed partial class SslConnectionInfo public int DataHashKeySize { get; private set; } public int KeyExchangeAlg { get; private set; } public int KeyExchKeySize { get; private set; } + + public byte[]? ApplicationProtocol { get; internal set; } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs similarity index 92% rename from src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs rename to src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs index 37cddae7013f5..23943caaa22ee 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs @@ -15,14 +15,13 @@ namespace System.Net.Security { public partial class SslStream { - private SslAuthenticationOptions? _sslAuthenticationOptions; - + private readonly SslAuthenticationOptions _sslAuthenticationOptions = new SslAuthenticationOptions(); private int _nestedAuth; private bool _isRenego; private TlsFrameHelper.TlsFrameInfo _lastFrame; - private object _handshakeLock => _sslAuthenticationOptions!; + private object _handshakeLock => _sslAuthenticationOptions; private volatile TaskCompletionSource? _handshakeWaiter; private bool _receivedEOF; @@ -33,67 +32,6 @@ public partial class SslStream // 2 = SslStream disposed, connection closed private int _connectionOpenedStatus; - private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback) - { - ThrowIfExceptional(); - - if (_context != null && _context.IsValidContext) - { - throw new InvalidOperationException(SR.net_auth_reauth); - } - - if (_context != null && IsServer) - { - throw new InvalidOperationException(SR.net_auth_client_server); - } - - ArgumentNullException.ThrowIfNull(sslClientAuthenticationOptions.TargetHost, nameof(sslClientAuthenticationOptions.TargetHost)); - - _exception = null; - try - { - _sslAuthenticationOptions = new SslAuthenticationOptions(sslClientAuthenticationOptions, remoteCallback, localCallback); - _context = new SecureChannel(_sslAuthenticationOptions, this); - } - catch (Win32Exception e) - { - throw new AuthenticationException(SR.net_auth_SSPI, e); - } - } - - private void ValidateCreateContext(SslAuthenticationOptions sslAuthenticationOptions) - { - ThrowIfExceptional(); - - if (_context != null && _context.IsValidContext) - { - throw new InvalidOperationException(SR.net_auth_reauth); - } - - if (_context != null && !IsServer) - { - throw new InvalidOperationException(SR.net_auth_client_server); - } - - _exception = null; - _sslAuthenticationOptions = sslAuthenticationOptions; - - try - { - _context = new SecureChannel(_sslAuthenticationOptions, this); - } - catch (Win32Exception e) - { - throw new AuthenticationException(SR.net_auth_SSPI, e); - } - } - - private bool RemoteCertRequired => _context == null || _context.RemoteCertRequired; - - private object? SyncLock => _context; - - private int MaxDataSize => _context!.MaxDataSize; - private void SetException(Exception e) { Debug.Assert(e != null, $"Expected non-null Exception to be passed to {nameof(SetException)}"); @@ -103,7 +41,7 @@ private void SetException(Exception e) _exception = ExceptionDispatchInfo.Capture(e); } - _context?.Close(); + CloseContext(); } // @@ -112,7 +50,7 @@ private void SetException(Exception e) private void CloseInternal() { _exception = s_disposedSentinel; - _context?.Close(); + CloseContext(); // Ensure a Read operation is not in progress, // block potential reads since SslStream is disposing. @@ -152,7 +90,7 @@ private SecurityStatusPal EncryptData(ReadOnlyMemory buffer, ref byte[] ou return new SecurityStatusPal(SecurityStatusPalErrorCode.TryAgain); } - return _context!.Encrypt(buffer, ref outBuffer, out outSize); + return Encrypt(buffer, ref outBuffer, out outSize); } } @@ -171,21 +109,21 @@ private Task ProcessAuthenticationAsync(bool isAsync = false, CancellationToken else { return isAsync ? - ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken) : - ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken); + ForceAuthenticationAsync(IsServer, null, cancellationToken) : + ForceAuthenticationAsync(IsServer, null, cancellationToken); } } private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, CancellationToken cancellationToken) { - NetSecurityTelemetry.Log.HandshakeStart(_context!.IsServer, _sslAuthenticationOptions!.TargetHost); + NetSecurityTelemetry.Log.HandshakeStart(IsServer, _sslAuthenticationOptions!.TargetHost); long startingTimestamp = Stopwatch.GetTimestamp(); try { Task task = isAsync? - ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken) : - ForceAuthenticationAsync(_context!.IsServer, null, cancellationToken); + ForceAuthenticationAsync(IsServer, null, cancellationToken) : + ForceAuthenticationAsync(IsServer, null, cancellationToken); await task.ConfigureAwait(false); @@ -197,7 +135,7 @@ private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, Cancell } catch (Exception ex) { - NetSecurityTelemetry.Log.HandshakeFailed(_context.IsServer, startingTimestamp, ex.Message); + NetSecurityTelemetry.Log.HandshakeFailed(IsServer, startingTimestamp, ex.Message); throw; } } @@ -250,7 +188,7 @@ private async Task RenegotiateAsync(CancellationToken cancellationTo _isRenego = true; - SecurityStatusPal status = _context!.Renegotiate(out byte[]? nextmsg); + SecurityStatusPal status = Renegotiate(out byte[]? nextmsg); if (nextmsg is { Length: > 0 }) { @@ -319,7 +257,7 @@ private async Task ForceAuthenticationAsync(bool receiveFirst, byte[ { if (!receiveFirst) { - message = _context!.NextMessage(reAuthenticationData); + message = NextMessage(reAuthenticationData); if (message.Size > 0) { await TIOAdapter.WriteAsync(InnerStream, message.Payload!, 0, message.Size, cancellationToken).ConfigureAwait(false); @@ -520,7 +458,7 @@ private ProtocolToken ProcessBlob(int frameSize) _buffer.DiscardEncrypted(frameSize); } - return _context!.NextMessage(availableData.Slice(0, chunkSize)); + return NextMessage(availableData.Slice(0, chunkSize)); } // @@ -554,7 +492,7 @@ private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo e // private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus) { - _context!.ProcessHandshakeSuccess(); + ProcessHandshakeSuccess(); if (_nestedAuth != 1) { @@ -565,7 +503,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError return true; } - if (!_context.VerifyRemoteCertificate(_sslAuthenticationOptions!.CertValidationDelegate, _sslAuthenticationOptions!.CertificateContext?.Trust, ref alertToken, out sslPolicyErrors, out chainStatus)) + if (!VerifyRemoteCertificate(_sslAuthenticationOptions!.CertValidationDelegate, _sslAuthenticationOptions!.CertificateContext?.Trust, ref alertToken, out sslPolicyErrors, out chainStatus)) { _handshakeCompleted = false; return false; @@ -790,7 +728,7 @@ private SecurityStatusPal DecryptData(int frameSize) ThrowIfExceptionalOrNotAuthenticated(); // Decrypt will decrypt in-place and modify these to point to the actual decrypted data, which may be smaller. - status = _context!.Decrypt(_buffer.EncryptedSpanSliced(frameSize), out int decryptedOffset, out int decryptedCount); + status = Decrypt(_buffer.EncryptedSpanSliced(frameSize), out int decryptedOffset, out int decryptedCount); _buffer.OnDecrypted(decryptedOffset, decryptedCount, frameSize); if (status.ErrorCode == SecurityStatusPalErrorCode.Renegotiate) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs similarity index 93% rename from src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs rename to src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs index 59f2820d014d8..b732dfe4d24a5 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs @@ -13,14 +13,14 @@ namespace System.Net.Security { - // SecureChannel - a wrapper on SSPI based functionality. - // Provides an additional abstraction layer over SSPI for SslStream. - internal sealed class SecureChannel + internal delegate X509Certificate2? SelectClientCertificate(out bool sessionRestartAttempt); + + public partial class SslStream { private SafeFreeCredentials? _credentialsHandle; private SafeDeleteSslContext? _securityContext; - private SslConnectionInfo? _connectionInfo; + private SslConnectionInfo _connectionInfo; private X509Certificate? _selectedClientCertificate; private X509Certificate2? _remoteCertificate; private bool _remoteCertificateExposed; @@ -30,30 +30,13 @@ internal sealed class SecureChannel private int _trailerSize = 16; private int _maxDataSize = 16354; - private bool _refreshCredentialNeeded; - - private readonly SslAuthenticationOptions _sslAuthenticationOptions; - private SslApplicationProtocol _negotiatedApplicationProtocol; + private bool _refreshCredentialNeeded = true; private static readonly Oid s_serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", "1.3.6.1.5.5.7.3.1"); private static readonly Oid s_clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", "1.3.6.1.5.5.7.3.2"); - private SslStream? _ssl; - - internal SecureChannel(SslAuthenticationOptions sslAuthenticationOptions, SslStream sslStream) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SecureChannelCtor(this, sslStream, sslAuthenticationOptions.TargetHost!, sslAuthenticationOptions.ClientCertificates, sslAuthenticationOptions.EncryptionPolicy); - - SslStreamPal.VerifyPackageInfo(); - Debug.Assert(sslAuthenticationOptions.TargetHost != null, "sslAuthenticationOptions.TargetHost == null"); - - _securityContext = null; - _refreshCredentialNeeded = true; - _sslAuthenticationOptions = sslAuthenticationOptions; - _ssl = sslStream; - } // - // SecureChannel properties + // Protocol properties // // LocalServerCertificate - local certificate for server mode channel // LocalClientCertificate - selected certificated used in the client channel mode otherwise null @@ -85,15 +68,6 @@ internal bool IsRemoteCertificateAvailable } } - internal X509Certificate? RemoteCertificate - { - get - { - _remoteCertificateExposed = true; - return _remoteCertificate; - } - } - internal ChannelBinding? GetChannelBinding(ChannelBindingKind kind) { ChannelBinding? result = null; @@ -105,14 +79,6 @@ internal bool IsRemoteCertificateAvailable return result; } - internal X509RevocationMode CheckCertRevocationStatus - { - get - { - return _sslAuthenticationOptions.CertificateRevocationCheckMode; - } - } - internal int MaxDataSize { get @@ -121,14 +87,6 @@ internal int MaxDataSize } } - internal SslConnectionInfo? ConnectionInfo - { - get - { - return _connectionInfo; - } - } - internal bool IsValidContext { [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -138,27 +96,11 @@ internal bool IsValidContext } } - internal bool IsServer - { - get - { - return _sslAuthenticationOptions.IsServer; - } - } - internal bool RemoteCertRequired { get { - return _sslAuthenticationOptions.RemoteCertRequired; - } - } - - internal SslApplicationProtocol NegotiatedApplicationProtocol - { - get - { - return _negotiatedApplicationProtocol; + return _sslAuthenticationOptions!.RemoteCertRequired; } } @@ -167,7 +109,7 @@ internal void SetRefreshCredentialNeeded() _refreshCredentialNeeded = true; } - internal void Close() + internal void CloseContext() { if (!_remoteCertificateExposed) { @@ -177,7 +119,6 @@ internal void Close() _securityContext?.Dispose(); _credentialsHandle?.Dispose(); - _ssl = null; GC.SuppressFinalize(this); } @@ -297,9 +238,9 @@ private string[] GetRequestCertificateAuthorities() sessionRestartAttempt = false; X509Certificate? clientCertificate = null; // candidate certificate that can come from the user callback or be guessed when targeting a session restart. - X509Certificate2? selectedCert = null; // final selected cert (ensured that it does have private key with it). + X509Certificate2? selectedCert = null; // final selected cert (ensured that it does have private key with it). List? filteredCerts = null; // This is an intermediate client certs collection that try to use if no selectedCert is available yet. - string[] issuers; // This is a list of issuers sent by the server, only valid if we do know what the server cert is. + string[] issuers; // This is a list of issuers sent by the server, only valid if we do know what the server cert is. if (_sslAuthenticationOptions.CertSelectionDelegate != null) { @@ -315,7 +256,7 @@ private string[] GetRequestCertificateAuthorities() { _sslAuthenticationOptions.ClientCertificates = new X509CertificateCollection(); } - clientCertificate = _sslAuthenticationOptions.CertSelectionDelegate(_sslAuthenticationOptions.TargetHost!, _sslAuthenticationOptions.ClientCertificates, remoteCert, issuers); + clientCertificate = _sslAuthenticationOptions.CertSelectionDelegate(this, _sslAuthenticationOptions.TargetHost, _sslAuthenticationOptions.ClientCertificates, remoteCert, issuers); } finally { @@ -639,7 +580,7 @@ private bool AcquireServerCredentials(ref byte[]? thumbPrint) // with .NET Framework), and if neither is set we fall back to using CertificateContext. if (_sslAuthenticationOptions.ServerCertSelectionDelegate != null) { - localCertificate = _sslAuthenticationOptions.ServerCertSelectionDelegate(_sslAuthenticationOptions.TargetHost); + localCertificate = _sslAuthenticationOptions.ServerCertSelectionDelegate(this, _sslAuthenticationOptions.TargetHost); if (localCertificate == null) { if (NetEventSource.Log.IsEnabled()) @@ -655,7 +596,7 @@ private bool AcquireServerCredentials(ref byte[]? thumbPrint) X509CertificateCollection tempCollection = new X509CertificateCollection(); tempCollection.Add(_sslAuthenticationOptions.CertificateContext!.Certificate!); // We pass string.Empty here to maintain strict compatibility with .NET Framework. - localCertificate = _sslAuthenticationOptions.CertSelectionDelegate(string.Empty, tempCollection, null, Array.Empty()); + localCertificate = _sslAuthenticationOptions.CertSelectionDelegate(this, string.Empty, tempCollection, null, Array.Empty()); if (localCertificate == null) { if (NetEventSource.Log.IsEnabled()) @@ -844,7 +785,6 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan inputBuffer, ref byte if (_sslAuthenticationOptions.IsServer) { status = SslStreamPal.AcceptSecurityContext( - this, ref _credentialsHandle!, ref _securityContext, inputBuffer, @@ -854,13 +794,14 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan inputBuffer, ref byte else { status = SslStreamPal.InitializeSecurityContext( - this, ref _credentialsHandle!, ref _securityContext, _sslAuthenticationOptions.TargetHost, inputBuffer, ref result, - _sslAuthenticationOptions); + _sslAuthenticationOptions, + SelectClientCertificate + ); } } while (cachedCreds && _credentialsHandle == null); } @@ -895,7 +836,6 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan inputBuffer, ref byte internal SecurityStatusPal Renegotiate(out byte[]? output) { return SslStreamPal.Renegotiate( - this, ref _credentialsHandle!, ref _securityContext, _sslAuthenticationOptions, @@ -911,13 +851,6 @@ internal SecurityStatusPal Renegotiate(out byte[]? output) --*/ internal void ProcessHandshakeSuccess() { - if (_negotiatedApplicationProtocol == default) - { - // try to get ALPN info unless we already have it. (renegotiation) - byte[]? alpnResult = SslStreamPal.GetNegotiatedApplicationProtocol(_securityContext!); - _negotiatedApplicationProtocol = alpnResult == null ? default : new SslApplicationProtocol(alpnResult, false); - } - SslStreamPal.QueryContextStreamSizes(_securityContext!, out StreamSizes streamSizes); _headerSize = streamSizes.Header; @@ -925,7 +858,7 @@ internal void ProcessHandshakeSuccess() _maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize)); Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0"); - SslStreamPal.QueryContextConnectionInfo(_securityContext!, out _connectionInfo); + SslStreamPal.QueryContextConnectionInfo(_securityContext!, ref _connectionInfo); } /*++ @@ -1049,13 +982,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot if (remoteCertValidationCallback != null) { - object? sender = _ssl; - if (sender == null) - { - throw new ObjectDisposedException(nameof(SslStream)); - } - - success = remoteCertValidationCallback(sender, _remoteCertificate, chain, sslPolicyErrors); + success = remoteCertValidationCallback(this, _remoteCertificate, chain, sslPolicyErrors); } else { @@ -1105,7 +1032,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot return success; } - public ProtocolToken? CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain) + private ProtocolToken? CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain) { TlsAlertMessage alertMessage; @@ -1145,7 +1072,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot return GenerateAlertToken(); } - public ProtocolToken? CreateShutdownToken() + private ProtocolToken? CreateShutdownToken() { SecurityStatusPal status; status = SslStreamPal.ApplyShutdownToken(ref _credentialsHandle, _securityContext!); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index f4816110f918d..44bbfee78f288 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -6,7 +6,6 @@ using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; using System.Security.Authentication; -using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -37,23 +36,11 @@ public enum EncryptionPolicy public delegate ValueTask ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken); - // Internal versions of the above delegates. - internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2? remoteCertificate, string[] acceptableIssuers); - internal delegate X509Certificate ServerCertSelectionCallback(string? hostName); - public partial class SslStream : AuthenticatedStream { /// Set as the _exception when the instance is disposed. private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string?)null)); - internal RemoteCertificateValidationCallback? _userCertificateValidationCallback; - internal LocalCertificateSelectionCallback? _userCertificateSelectionCallback; - internal ServerCertificateSelectionCallback? _userServerCertificateSelectionCallback; - internal LocalCertSelectionCallback? _certSelectionDelegate; - internal EncryptionPolicy _encryptionPolicy; - - private SecureChannel? _context; - private ExceptionDispatchInfo? _exception; private bool _shutdown; private bool _handshakeCompleted; @@ -183,8 +170,6 @@ public void ReturnBuffer() } } - - private int _nestedWrite; private int _nestedRead; @@ -220,86 +205,13 @@ public SslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificat } #pragma warning restore SYSLIB0040 - _userCertificateValidationCallback = userCertificateValidationCallback; - _userCertificateSelectionCallback = userCertificateSelectionCallback; - _encryptionPolicy = encryptionPolicy; - _certSelectionDelegate = userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); + _sslAuthenticationOptions.EncryptionPolicy = encryptionPolicy; + _sslAuthenticationOptions.CertValidationDelegate = userCertificateValidationCallback; + _sslAuthenticationOptions.CertSelectionDelegate = userCertificateSelectionCallback; if (NetEventSource.Log.IsEnabled()) NetEventSource.Log.SslStreamCtor(this, innerStream); } - public SslApplicationProtocol NegotiatedApplicationProtocol - { - get - { - if (_context == null) - return default; - - return _context.NegotiatedApplicationProtocol; - } - } - - private void SetAndVerifyValidationCallback(RemoteCertificateValidationCallback? callback) - { - if (_userCertificateValidationCallback == null) - { - _userCertificateValidationCallback = callback; - } - else if (callback != null && _userCertificateValidationCallback != callback) - { - throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback))); - } - } - - private void SetAndVerifySelectionCallback(LocalCertificateSelectionCallback? callback) - { - if (_userCertificateSelectionCallback == null) - { - _userCertificateSelectionCallback = callback; - _certSelectionDelegate = _userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper); - } - else if (callback != null && _userCertificateSelectionCallback != callback) - { - throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(LocalCertificateSelectionCallback))); - } - } - - private X509Certificate UserCertSelectionCallbackWrapper(string targetHost, X509CertificateCollection localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) - { - return _userCertificateSelectionCallback!(this, targetHost, localCertificates, remoteCertificate, acceptableIssuers); - } - - private X509Certificate ServerCertSelectionCallbackWrapper(string? targetHost) => _userServerCertificateSelectionCallback!(this, targetHost); - - private SslAuthenticationOptions CreateAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) - { - if (sslServerAuthenticationOptions.ServerCertificate == null && - sslServerAuthenticationOptions.ServerCertificateContext == null && - sslServerAuthenticationOptions.ServerCertificateSelectionCallback == null && - _certSelectionDelegate == null) - { - throw new ArgumentNullException(nameof(sslServerAuthenticationOptions.ServerCertificate)); - } - - if ((sslServerAuthenticationOptions.ServerCertificate != null || - sslServerAuthenticationOptions.ServerCertificateContext != null || - _certSelectionDelegate != null) && - sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null) - { - throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(ServerCertificateSelectionCallback))); - } - - var authOptions = new SslAuthenticationOptions(sslServerAuthenticationOptions); - - _userServerCertificateSelectionCallback = sslServerAuthenticationOptions.ServerCertificateSelectionCallback; - authOptions.ServerCertSelectionDelegate = _userServerCertificateSelectionCallback == null ? null : new ServerCertSelectionCallback(ServerCertSelectionCallbackWrapper); - - authOptions.CertValidationDelegate = _userCertificateValidationCallback; - authOptions.CertSelectionDelegate = _certSelectionDelegate; - - return authOptions; - } - // // Client side auth. // @@ -325,14 +237,14 @@ public virtual IAsyncResult BeginAuthenticateAsClient(string targetHost, AsyncCa ClientCertificates = clientCertificates, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; return BeginAuthenticateAsClient(options, CancellationToken.None, asyncCallback, asyncState); } internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback? asyncCallback, object? asyncState) => - TaskToApm.Begin(AuthenticateAsClientApm(sslClientAuthenticationOptions, cancellationToken)!, asyncCallback, asyncState); + TaskToApm.Begin(AuthenticateAsClientAsync(sslClientAuthenticationOptions, cancellationToken)!, asyncCallback, asyncState); public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); @@ -364,14 +276,14 @@ public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCert ClientCertificateRequired = clientCertificateRequired, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; return BeginAuthenticateAsServer(options, CancellationToken.None, asyncCallback, asyncState); } private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback? asyncCallback, object? asyncState) => - TaskToApm.Begin(AuthenticateAsServerApm(sslServerAuthenticationOptions, cancellationToken)!, asyncCallback, asyncState); + TaskToApm.Begin(AuthenticateAsServerAsync(sslServerAuthenticationOptions, cancellationToken)!, asyncCallback, asyncState); public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); @@ -381,8 +293,6 @@ public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCert public TransportContext TransportContext => new SslStreamContext(this); - internal ChannelBinding? GetChannelBinding(ChannelBindingKind kind) => _context?.GetChannelBinding(kind); - #region Synchronous methods public virtual void AuthenticateAsClient(string targetHost) { @@ -402,7 +312,7 @@ public virtual void AuthenticateAsClient(string targetHost, X509CertificateColle ClientCertificates = clientCertificates, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; AuthenticateAsClient(options); @@ -411,11 +321,11 @@ public virtual void AuthenticateAsClient(string targetHost, X509CertificateColle public void AuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions) { ArgumentNullException.ThrowIfNull(sslClientAuthenticationOptions); + ArgumentNullException.ThrowIfNull(sslClientAuthenticationOptions.TargetHost, nameof(sslClientAuthenticationOptions.TargetHost)); - SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); - SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); + ThrowIfExceptional(); - ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); + _sslAuthenticationOptions.UpdateOptions(sslClientAuthenticationOptions); ProcessAuthenticationAsync().GetAwaiter().GetResult(); } @@ -437,7 +347,7 @@ public virtual void AuthenticateAsServer(X509Certificate serverCertificate, bool ClientCertificateRequired = clientCertificateRequired, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; AuthenticateAsServer(options); @@ -447,9 +357,7 @@ public void AuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthent { ArgumentNullException.ThrowIfNull(sslServerAuthenticationOptions); - SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); - - ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); + _sslAuthenticationOptions.UpdateOptions(sslServerAuthenticationOptions); ProcessAuthenticationAsync().GetAwaiter().GetResult(); } #endregion @@ -467,7 +375,7 @@ public virtual Task AuthenticateAsClientAsync(string targetHost, X509Certificate ClientCertificates = clientCertificates, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; return AuthenticateAsClientAsync(options); @@ -476,22 +384,10 @@ public virtual Task AuthenticateAsClientAsync(string targetHost, X509Certificate public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(sslClientAuthenticationOptions); + ArgumentNullException.ThrowIfNull(sslClientAuthenticationOptions.TargetHost, nameof(sslClientAuthenticationOptions.TargetHost)); - SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); - SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); - - ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - - return ProcessAuthenticationAsync(isAsync: true, cancellationToken); - } - - private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) - { - SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); - SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); - - ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - + ThrowIfExceptional(); + _sslAuthenticationOptions.UpdateOptions(sslClientAuthenticationOptions); return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } @@ -505,7 +401,7 @@ public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, ServerCertificate = serverCertificate, ClientCertificateRequired = clientCertificateRequired, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; return AuthenticateAsServerAsync(options); @@ -519,7 +415,7 @@ public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, ClientCertificateRequired = clientCertificateRequired, EnabledSslProtocols = enabledSslProtocols, CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - EncryptionPolicy = _encryptionPolicy, + EncryptionPolicy = _sslAuthenticationOptions.EncryptionPolicy, }; return AuthenticateAsServerAsync(options); @@ -528,24 +424,14 @@ public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(sslServerAuthenticationOptions); - - SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); - ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - - return ProcessAuthenticationAsync(isAsync: true, cancellationToken); - } - - private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) - { - SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); - ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - + _sslAuthenticationOptions.UpdateOptions(sslServerAuthenticationOptions); return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default) { - ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state, _userCertificateValidationCallback)); + _sslAuthenticationOptions.UpdateOptions(optionsCallback, state); + return ProcessAuthenticationAsync(isAsync: true, cancellationToken); } @@ -553,13 +439,13 @@ public virtual Task ShutdownAsync() { ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); - ProtocolToken message = _context!.CreateShutdownToken()!; + ProtocolToken message = CreateShutdownToken()!; _shutdown = true; return InnerStream.WriteAsync(message.Payload, default).AsTask(); } #endregion - public override bool IsAuthenticated => _context != null && _context.IsValidContext && _exception == null && _handshakeCompleted; + public override bool IsAuthenticated => IsValidContext && _exception == null && _handshakeCompleted; public override bool IsMutuallyAuthenticated { @@ -567,8 +453,8 @@ public override bool IsMutuallyAuthenticated { return IsAuthenticated && - (_context!.IsServer ? _context.LocalServerCertificate : _context.LocalClientCertificate) != null && - _context.IsRemoteCertificateAvailable; /* does not work: Context.IsMutualAuthFlag;*/ + (IsServer ? LocalServerCertificate : LocalClientCertificate) != null && + IsRemoteCertificateAvailable; /* does not work: Context.IsMutualAuthFlag;*/ } } @@ -576,7 +462,7 @@ public override bool IsMutuallyAuthenticated public override bool IsSigned => IsAuthenticated; - public override bool IsServer => _context != null && _context.IsServer; + public override bool IsServer => _sslAuthenticationOptions.IsServer; public virtual SslProtocols SslProtocol { @@ -590,13 +476,12 @@ public virtual SslProtocols SslProtocol // Skips the ThrowIfExceptionalOrNotHandshake() check private SslProtocols GetSslProtocolInternal() { - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) + if (_connectionInfo.Protocol == 0) { return SslProtocols.None; } - SslProtocols proto = (SslProtocols)info.Protocol; + SslProtocols proto = (SslProtocols)_connectionInfo.Protocol; SslProtocols ret = SslProtocols.None; #pragma warning disable 0618 // Ssl2, Ssl3 are deprecated. @@ -637,7 +522,7 @@ private SslProtocols GetSslProtocolInternal() return ret; } - public virtual bool CheckCertRevocationStatus => _context != null && _context.CheckCertRevocationStatus != X509RevocationMode.NoCheck; + public virtual bool CheckCertRevocationStatus => _sslAuthenticationOptions.CertificateRevocationCheckMode != X509RevocationMode.NoCheck; // // This will return selected local cert for both client/server streams @@ -647,16 +532,27 @@ private SslProtocols GetSslProtocolInternal() get { ThrowIfExceptionalOrNotAuthenticated(); - return _context!.IsServer ? _context.LocalServerCertificate : _context.LocalClientCertificate; + return IsServer ? LocalServerCertificate : LocalClientCertificate; } } + public virtual X509Certificate? RemoteCertificate { get { ThrowIfExceptionalOrNotAuthenticated(); - return _context?.RemoteCertificate; + _remoteCertificateExposed = true; + return _remoteCertificate; + } + } + + public SslApplicationProtocol NegotiatedApplicationProtocol + { + get + { + ThrowIfExceptionalOrNotHandshake(); + return _connectionInfo.ApplicationProtocol != null ? new SslApplicationProtocol(_connectionInfo.ApplicationProtocol, false) : default; } } @@ -666,7 +562,7 @@ public virtual TlsCipherSuite NegotiatedCipherSuite get { ThrowIfExceptionalOrNotHandshake(); - return _context!.ConnectionInfo?.TlsCipherSuite ?? default(TlsCipherSuite); + return _connectionInfo.TlsCipherSuite; } } @@ -675,12 +571,7 @@ public virtual CipherAlgorithmType CipherAlgorithm get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return CipherAlgorithmType.None; - } - return (CipherAlgorithmType)info.DataCipherAlg; + return (CipherAlgorithmType)_connectionInfo.DataCipherAlg; } } @@ -689,13 +580,7 @@ public virtual int CipherStrength get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return 0; - } - - return info.DataKeySize; + return _connectionInfo.DataKeySize; } } @@ -704,12 +589,7 @@ public virtual HashAlgorithmType HashAlgorithm get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return (HashAlgorithmType)0; - } - return (HashAlgorithmType)info.DataHashAlg; + return (HashAlgorithmType)_connectionInfo.DataHashAlg; } } @@ -718,13 +598,7 @@ public virtual int HashStrength get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return 0; - } - - return info.DataHashKeySize; + return _connectionInfo.DataHashKeySize; } } @@ -733,13 +607,7 @@ public virtual ExchangeAlgorithmType KeyExchangeAlgorithm get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return (ExchangeAlgorithmType)0; - } - - return (ExchangeAlgorithmType)info.KeyExchangeAlg; + return (ExchangeAlgorithmType)_connectionInfo.KeyExchangeAlg; } } @@ -748,13 +616,7 @@ public virtual int KeyExchangeStrength get { ThrowIfExceptionalOrNotHandshake(); - SslConnectionInfo? info = _context!.ConnectionInfo; - if (info == null) - { - return 0; - } - - return info.KeyExchKeySize; + return _connectionInfo.KeyExchKeySize; } } @@ -986,7 +848,7 @@ private void ThrowIfExceptionalOrNotHandshake() { ThrowIfExceptional(); - if (!IsAuthenticated && _context?.ConnectionInfo == null) + if (!IsAuthenticated) { ThrowNotAuthenticated(); } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs index 54187a60c6c29..bbe3f1a43557b 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs @@ -25,7 +25,6 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials credential, ref SafeDeleteSslContext? context, ReadOnlySpan inputBuffer, @@ -36,19 +35,18 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal InitializeSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials credential, ref SafeDeleteSslContext? context, string? targetName, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, - SslAuthenticationOptions sslAuthenticationOptions) + SslAuthenticationOptions sslAuthenticationOptions, + SelectClientCertificate? clientCertificateSelectionCallback) { return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal Renegotiate( - SecureChannel secureChannel, ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, @@ -66,14 +64,6 @@ public static void VerifyPackageInfo() return new SafeFreeSslCredentials(certificateContext, protocols, policy); } - internal static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteSslContext? context) - { - if (context == null) - return null; - - return Interop.AndroidCrypto.SSLStreamGetApplicationProtocol(context.SslContext); - } - public static SecurityStatusPal EncryptMessage( SafeDeleteSslContext securityContext, ReadOnlyMemory input, @@ -176,9 +166,9 @@ public static void VerifyPackageInfo() public static void QueryContextConnectionInfo( SafeDeleteSslContext securityContext, - out SslConnectionInfo connectionInfo) + ref SslConnectionInfo connectionInfo) { - connectionInfo = new SslConnectionInfo(securityContext.SslContext); + connectionInfo.UpdateSslConnectionInfo(securityContext.SslContext); } private static SecurityStatusPal HandshakeInternal( diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs index 7fbea273f95ef..73c80b20dab48 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs @@ -32,30 +32,28 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials credential, ref SafeDeleteSslContext? context, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(secureChannel, credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions, null); } public static SecurityStatusPal InitializeSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials credential, ref SafeDeleteSslContext? context, string? targetName, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, - SslAuthenticationOptions sslAuthenticationOptions) + SslAuthenticationOptions sslAuthenticationOptions, + SelectClientCertificate clientCertificateSelectionCallback) { - return HandshakeInternal(secureChannel, credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions, clientCertificateSelectionCallback); } public static SecurityStatusPal Renegotiate( - SecureChannel secureChannel, ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, @@ -73,14 +71,6 @@ public static void VerifyPackageInfo() return new SafeFreeSslCredentials(certificateContext, protocols, policy); } - internal static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteSslContext? context) - { - if (context == null) - return null; - - return Interop.AppleCrypto.SslGetAlpnSelected(context.SslContext); - } - public static SecurityStatusPal EncryptMessage( SafeDeleteSslContext securityContext, ReadOnlyMemory input, @@ -225,18 +215,18 @@ public static void VerifyPackageInfo() public static void QueryContextConnectionInfo( SafeDeleteSslContext securityContext, - out SslConnectionInfo connectionInfo) + ref SslConnectionInfo connectionInfo) { - connectionInfo = new SslConnectionInfo(securityContext.SslContext); + connectionInfo.UpdateSslConnectionInfo(securityContext.SslContext); } private static SecurityStatusPal HandshakeInternal( - SecureChannel secureChannel, SafeFreeCredentials credential, ref SafeDeleteSslContext? context, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, - SslAuthenticationOptions sslAuthenticationOptions) + SslAuthenticationOptions sslAuthenticationOptions, + SelectClientCertificate? clientCertificateSelectionCallback) { Debug.Assert(!credential.IsInvalid); @@ -257,9 +247,9 @@ public static void VerifyPackageInfo() SafeSslHandle sslHandle = sslContext!.SslContext; SecurityStatusPal status = PerformHandshake(sslHandle); - if (status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded) + if (status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded && clientCertificateSelectionCallback != null) { - X509Certificate2? clientCertificate = secureChannel.SelectClientCertificate(out _); + X509Certificate2? clientCertificate = clientCertificateSelectionCallback(out bool _); if (clientCertificate != null) { sslAuthenticationOptions.CertificateContext = SslStreamCertificateContext.Create(clientCertificate); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index a9192681444be..32ac2bfdbeba5 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -25,26 +25,25 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(secureChannel, credential!, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential!, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions, null); } public static SecurityStatusPal InitializeSecurityContext( - SecureChannel secureChannel, ref SafeFreeCredentials? credential, ref SafeDeleteSslContext? context, string? targetName, ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, - SslAuthenticationOptions sslAuthenticationOptions) + SslAuthenticationOptions sslAuthenticationOptions, + SelectClientCertificate? clientCertificateSelectionCallback) { - return HandshakeInternal(secureChannel, credential!, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential!, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions, clientCertificateSelectionCallback); } public static SafeFreeCredentials AcquireCredentialsHandle(SslStreamCertificateContext? certificateContext, @@ -128,7 +127,6 @@ Interop.Ssl.SslErrorCode.SSL_ERROR_NONE or } public static SecurityStatusPal Renegotiate( - SecureChannel secureChannel, ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? securityContext, SslAuthenticationOptions sslAuthenticationOptions, @@ -142,7 +140,7 @@ Interop.Ssl.SslErrorCode.SSL_ERROR_NONE or { return status; } - return HandshakeInternal(secureChannel, credentialsHandle!, ref securityContext, null, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credentialsHandle!, ref securityContext, null, ref outputBuffer, sslAuthenticationOptions, null); } public static void QueryContextStreamSizes(SafeDeleteContext? securityContext, out StreamSizes streamSizes) @@ -150,13 +148,13 @@ public static void QueryContextStreamSizes(SafeDeleteContext? securityContext, o streamSizes = StreamSizes.Default; } - public static void QueryContextConnectionInfo(SafeDeleteSslContext securityContext, out SslConnectionInfo connectionInfo) + public static void QueryContextConnectionInfo(SafeDeleteSslContext securityContext, ref SslConnectionInfo connectionInfo) { - connectionInfo = new SslConnectionInfo(securityContext.SslContext); + connectionInfo.UpdateSslConnectionInfo(securityContext.SslContext); } - private static SecurityStatusPal HandshakeInternal(SecureChannel secureChannel, SafeFreeCredentials credential, ref SafeDeleteSslContext? context, - ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credential, ref SafeDeleteSslContext? context, + ReadOnlySpan inputBuffer, ref byte[]? outputBuffer, SslAuthenticationOptions sslAuthenticationOptions, SelectClientCertificate? clientCertificateSelectionCallback) { Debug.Assert(!credential.IsInvalid); @@ -172,9 +170,9 @@ public static void QueryContextConnectionInfo(SafeDeleteSslContext securityConte SecurityStatusPalErrorCode errorCode = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer, out output, out outputSize); - if (errorCode == SecurityStatusPalErrorCode.CredentialsNeeded) + if (errorCode == SecurityStatusPalErrorCode.CredentialsNeeded && clientCertificateSelectionCallback != null) { - X509Certificate2? clientCertificate = secureChannel.SelectClientCertificate(out _); + X509Certificate2? clientCertificate = clientCertificateSelectionCallback(out _); if (clientCertificate != null) { sslAuthenticationOptions.CertificateContext = SslStreamCertificateContext.Create(clientCertificate); @@ -222,14 +220,6 @@ public static void QueryContextConnectionInfo(SafeDeleteSslContext securityConte } } - internal static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteSslContext? context) - { - if (context == null) - return null; - - return Interop.Ssl.SslGetAlpnSelected(context.SslContext); - } - public static SecurityStatusPal ApplyAlertToken(ref SafeFreeCredentials? credentialsHandle, SafeDeleteContext? securityContext, TlsAlertType alertType, TlsAlertMessage alertMessage) { // There doesn't seem to be an exposed API for writing an alert, diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs index d9c28a4692e7b..6c9286cf5b121 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs @@ -53,7 +53,6 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List inputBuffer, @@ -89,13 +88,13 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List inputBuffer, ref byte[]? outputBuffer, - SslAuthenticationOptions sslAuthenticationOptions) + SslAuthenticationOptions sslAuthenticationOptions, + SelectClientCertificate? clientCertificateSelectionCallback) { Interop.SspiCli.ContextFlags unusedAttributes = default; @@ -126,14 +125,13 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List(); - SecurityStatusPal status = AcceptSecurityContext(secureChannel, ref credentialsHandle, ref context, Span.Empty, ref output, sslAuthenticationOptions); + SecurityStatusPal status = AcceptSecurityContext(ref credentialsHandle, ref context, Span.Empty, ref output, sslAuthenticationOptions); outputBuffer = output; return status; } @@ -305,22 +303,6 @@ public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchCredentials( return AcquireCredentialsHandle(direction, &credential); } - internal static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteContext context) - { - Interop.SecPkgContext_ApplicationProtocol alpnContext = default; - bool success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, context, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL, ref alpnContext); - - // Check if the context returned is alpn data, with successful negotiation. - if (success && - alpnContext.ProtoNegoExt == Interop.ApplicationProtocolNegotiationExt.ALPN && - alpnContext.ProtoNegoStatus == Interop.ApplicationProtocolNegotiationStatus.Success) - { - return alpnContext.Protocol; - } - - return null; - } - public static unsafe SecurityStatusPal EncryptMessage(SafeDeleteSslContext securityContext, ReadOnlyMemory input, int headerSize, int trailerSize, ref byte[] output, out int resultSize) { // Ensure that there is sufficient space for the message output. @@ -476,26 +458,9 @@ public static void QueryContextStreamSizes(SafeDeleteContext securityContext, ou streamSizes = new StreamSizes(interopStreamSizes); } - public static void QueryContextConnectionInfo(SafeDeleteContext securityContext, out SslConnectionInfo connectionInfo) + public static void QueryContextConnectionInfo(SafeDeleteContext securityContext, ref SslConnectionInfo connectionInfo) { - SecPkgContext_ConnectionInfo interopConnectionInfo = default; - bool success = SSPIWrapper.QueryBlittableContextAttributes( - GlobalSSPI.SSPISecureChannel, - securityContext, - Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CONNECTION_INFO, - ref interopConnectionInfo); - Debug.Assert(success); - - TlsCipherSuite cipherSuite = default; - SecPkgContext_CipherInfo cipherInfo = default; - - success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CIPHER_INFO, ref cipherInfo); - if (success) - { - cipherSuite = (TlsCipherSuite)cipherInfo.dwCipherSuite; - } - - connectionInfo = new SslConnectionInfo(interopConnectionInfo, cipherSuite); + connectionInfo.UpdateSslConnectionInfo(securityContext); } private static int GetProtocolFlagsFromSslProtocols(SslProtocols protocols, bool isServer) diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs index 140792db4ec77..a34fcab840369 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs @@ -6,6 +6,10 @@ using System.Threading; using System.Threading.Tasks; +// Disable warning about unused or unasiggned variables +#pragma warning disable CS0649 +#pragma warning disable CS0414 + namespace System.Net.Security { public partial class SslStream @@ -13,11 +17,37 @@ public partial class SslStream private class FakeOptions { public string TargetHost; + public EncryptionPolicy EncryptionPolicy; + public bool IsServer; + public RemoteCertificateValidationCallback? CertValidationDelegate; + public LocalCertificateSelectionCallback? CertSelectionDelegate; + public X509RevocationMode CertificateRevocationCheckMode; + + public void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + } + + public void UpdateOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions) + { + } + + internal void UpdateOptions(ServerOptionsSelectionCallback optionCallback, object? state) + { + } } - private FakeOptions? _sslAuthenticationOptions; + private FakeOptions _sslAuthenticationOptions = new FakeOptions(); + private SslConnectionInfo _connectionInfo; + internal ChannelBinding? GetChannelBinding(ChannelBindingKind kind) => null; + private bool _remoteCertificateExposed; + private X509Certificate2? LocalClientCertificate; + private X509Certificate2? LocalServerCertificate; + private bool IsRemoteCertificateAvailable; + private bool IsValidContext; + private X509Certificate2? _remoteCertificate; - private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback) + + private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertificateSelectionCallback? localCallback) { // Without setting (or using) these members you will get a build exception in the unit test project. // The code that normally uses these in the main solution is in the implementation of SslStream. @@ -26,7 +56,6 @@ private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthe { } - _context = null; _exception = null; _nestedWrite = 0; _handshakeCompleted = false; @@ -62,25 +91,15 @@ private Task ProcessAuthenticationAsync(bool isAsync = false, CancellationToken private void ReturnReadBufferIfEmpty() { } - } - internal class SecureChannel - { - internal bool IsValidContext => default; - internal bool IsServer => default; - internal SslConnectionInfo ConnectionInfo => default; - internal ChannelBinding GetChannelBinding(ChannelBindingKind kind) => default; - internal X509Certificate LocalServerCertificate => default; - internal X509Certificate RemoteCertificate => default; - internal bool IsRemoteCertificateAvailable => default; - internal SslApplicationProtocol NegotiatedApplicationProtocol => default; - internal X509Certificate LocalClientCertificate => default; - internal X509RevocationMode CheckCertRevocationStatus => default; - internal ProtocolToken CreateShutdownToken() => default; + private ProtocolToken? CreateShutdownToken() + { + return null; + } internal static X509Certificate2? FindCertificateWithPrivateKey(object instance, bool isServer, X509Certificate certificate) { - return certificate as X509Certificate2; + return null; } }