diff --git a/src/Authentication/Authentication/Handlers/RequestHeaderHandler.cs b/src/Authentication/Authentication/Handlers/RequestHeaderHandler.cs index 3e4f73fcb1c..7fcdbe2869a 100644 --- a/src/Authentication/Authentication/Handlers/RequestHeaderHandler.cs +++ b/src/Authentication/Authentication/Handlers/RequestHeaderHandler.cs @@ -2,12 +2,10 @@ // Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. // ------------------------------------------------------------------------------ -using Microsoft.Graph.PowerShell.Authentication.Cmdlets; using System; using System.Collections.Generic; using System.Linq; using System.Net.Http; -using System.Net.Http.Headers; using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -21,37 +19,30 @@ namespace Microsoft.Graph.PowerShell.Authentication.Handlers internal class RequestHeaderHandler : DelegatingHandler { /// The version for current assembly. - private static readonly AssemblyName _assemblyInfo = typeof(ConnectMgGraph).GetTypeInfo().Assembly.GetName(); + private static readonly AssemblyName _assemblyInfo = typeof(RequestHeaderHandler).GetTypeInfo().Assembly.GetName(); public RequestHeaderHandler() { } - public RequestHeaderHandler(HttpRequestHeaders requestHeaders, HttpMessageHandler innerHandler) : base(innerHandler) { } - protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - SetRequestHeaders(request); - return base.SendAsync(request, cancellationToken); - } - - private static void SetRequestHeaders(HttpRequestMessage request) - { - string sdkVersionHeaderValue = string.Format(request.RequestUri.AbsolutePath.StartsWith("/beta") ? Constants.PSSDKHeaderValueBeta : Constants.PSSDKHeaderValueV1, _assemblyInfo.Version.Major, _assemblyInfo.Version.Minor, _assemblyInfo.Version.Build); - PrependHeader(request, CoreConstants.Headers.SdkVersionHeaderName, sdkVersionHeaderValue); - } - - private static void PrependHeader(HttpRequestMessage request, string headerName, string headerValue) - { - if (request.Headers.TryGetValues(headerName, out IEnumerable previousSDKHeaders)) + string psSdkVersionHeader = string.Format(request.RequestUri.AbsolutePath.StartsWith("/beta") ? Constants.PSSDKHeaderValueBeta + : Constants.PSSDKHeaderValueV1, _assemblyInfo.Version.Major, _assemblyInfo.Version.Minor, _assemblyInfo.Version.Build); + if (request.Headers.TryGetValues(CoreConstants.Headers.SdkVersionHeaderName, out IEnumerable previousSDKHeaders)) { - request.Headers.Remove(headerName); - request.Headers.Add(headerName, new[] { - headerValue, previousSDKHeaders.Where(h => h.StartsWith(Constants.DotNetSDKHeaderValue, StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault() - }); + var dotNetSdkHeader = previousSDKHeaders.Where(h => h.StartsWith(Constants.DotNetSDKHeaderValue, StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault(); + request.Headers.Remove(CoreConstants.Headers.SdkVersionHeaderName); + request.Headers.Add(CoreConstants.Headers.SdkVersionHeaderName, new[] { psSdkVersionHeader, dotNetSdkHeader }); } else { - request.Headers.Add(headerName, headerValue); + request.Headers.Add(CoreConstants.Headers.SdkVersionHeaderName, psSdkVersionHeader); } + + if (request.Headers.Contains(CoreConstants.Headers.ClientRequestId)) + request.Headers.Remove(CoreConstants.Headers.ClientRequestId); + request.Headers.Add(CoreConstants.Headers.ClientRequestId, Guid.NewGuid().ToString()); + + return base.SendAsync(request, cancellationToken); } } } diff --git a/src/Authentication/Authentication/Helpers/HttpHelpers.cs b/src/Authentication/Authentication/Helpers/HttpHelpers.cs index 699e808a3b5..fa050385c2b 100644 --- a/src/Authentication/Authentication/Helpers/HttpHelpers.cs +++ b/src/Authentication/Authentication/Helpers/HttpHelpers.cs @@ -46,7 +46,6 @@ private static HttpClient GetGraphHttpClient(IAuthenticationProvider authProvide throw new AuthenticationException(string.Format(CultureInfo.InvariantCulture, Core.ErrorConstants.Message.MissingSessionProperty, nameof(requestContext))); IList delegatingHandlers = new List { - new RequestHeaderHandler(), new AuthenticationHandler(authProvider), new NationalCloudHandler(), new ODataQueryOptionsHandler(), @@ -57,7 +56,8 @@ private static HttpClient GetGraphHttpClient(IAuthenticationProvider authProvide MaxRetry = requestContext.MaxRetry, RetriesTimeLimit= requestContext.RetriesTimeLimit }), - new RedirectHandler() + new RedirectHandler(), + new RequestHeaderHandler() // Should always be last. }; HttpClient httpClient = GraphClientFactory.Create(delegatingHandlers); diff --git a/src/Authentication/Authentication/Helpers/RuntimeUtils.cs b/src/Authentication/Authentication/Helpers/RuntimeUtils.cs new file mode 100644 index 00000000000..8b240e32b6b --- /dev/null +++ b/src/Authentication/Authentication/Helpers/RuntimeUtils.cs @@ -0,0 +1,24 @@ +// ------------------------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. +// ------------------------------------------------------------------------------ + +using System; + +namespace Microsoft.Graph.PowerShell.Authentication.Helpers +{ + /// + /// Utility class containing runtime utility methods. + /// + internal static class RuntimeUtils + { + /// + /// Determines if the PSEdition of the current process is Core. + /// + /// when PSEdition is core, else . + internal static bool IsPsCore() + { + var psCoreVersion = new Version(6, 0, 0); + return GraphSession.Instance.AuthContext.PSHostVersion >= psCoreVersion; + } + } +} diff --git a/tools/Custom/HttpMessageLogFormatter.cs b/tools/Custom/HttpMessageLogFormatter.cs index 5ae7ad48efb..6b1b5a4cdb3 100644 --- a/tools/Custom/HttpMessageLogFormatter.cs +++ b/tools/Custom/HttpMessageLogFormatter.cs @@ -7,6 +7,7 @@ namespace NamespacePrefixPlaceholder.PowerShell using Newtonsoft.Json; using System; using System.Collections.Generic; + using System.IO; using System.Linq; using System.Net.Http; using System.Net.Http.Headers; @@ -17,22 +18,54 @@ namespace NamespacePrefixPlaceholder.PowerShell public static class HttpMessageLogFormatter { + internal static async Task CloneAsync(this HttpRequestMessage originalRequest) + { + var newRequest = new HttpRequestMessage(originalRequest.Method, originalRequest.RequestUri); + + // Copy requestClone headers. + foreach (var header in originalRequest.Headers) + newRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + + // Copy requestClone properties. + foreach (var property in originalRequest.Properties) + newRequest.Properties.Add(property); + + // Set Content if previous requestClone had one. + if (originalRequest.Content != null) + { + // HttpClient doesn't rewind streams and we have to explicitly do so. + await originalRequest.Content.ReadAsStreamAsync().ContinueWith(t => + { + if (t.Result.CanSeek) + t.Result.Seek(0, SeekOrigin.Begin); + + newRequest.Content = new StreamContent(t.Result); + }).ConfigureAwait(false); + + // Copy content headers. + if (originalRequest.Content.Headers != null) + foreach (var contentHeader in originalRequest.Content.Headers) + newRequest.Content.Headers.TryAddWithoutValidation(contentHeader.Key, contentHeader.Value); + } + return newRequest; + } + public static async Task GetHttpRequestLogAsync(HttpRequestMessage request) { if (request == null) return string.Empty; - + var requestClone = await request.CloneAsync().ConfigureAwait(false); string body = string.Empty; try { - body = (request.Content == null) ? string.Empty : FormatString(await request.Content.ReadAsStringAsync()); + body = (requestClone.Content == null) ? string.Empty : FormatString(await requestClone.Content.ReadAsStringAsync()); } catch { } StringBuilder stringBuilder = new StringBuilder(); stringBuilder.AppendLine($"============================ HTTP REQUEST ============================{Environment.NewLine}"); - stringBuilder.AppendLine($"HTTP Method:{Environment.NewLine}{request.Method.ToString()}{Environment.NewLine}"); - stringBuilder.AppendLine($"Absolute Uri:{Environment.NewLine}{request.RequestUri.ToString()}{Environment.NewLine}"); - stringBuilder.AppendLine($"Headers:{Environment.NewLine}{HeadersToString(ConvertHttpHeadersToCollection(request.Headers))}{Environment.NewLine}"); + stringBuilder.AppendLine($"HTTP Method:{Environment.NewLine}{requestClone.Method.ToString()}{Environment.NewLine}"); + stringBuilder.AppendLine($"Absolute Uri:{Environment.NewLine}{requestClone.RequestUri.ToString()}{Environment.NewLine}"); + stringBuilder.AppendLine($"Headers:{Environment.NewLine}{HeadersToString(ConvertHttpHeadersToCollection(requestClone.Headers))}{Environment.NewLine}"); stringBuilder.AppendLine($"Body:{Environment.NewLine}{SanitizeBody(body)}{Environment.NewLine}"); return stringBuilder.ToString(); }