diff --git a/.appveyor.yml b/.appveyor.yml index d86285f..2ee3b53 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -4,8 +4,8 @@ image: Visual Studio 2017 install: - ps: dotnet tool install coveralls.net --tool-path tools build_script: - - cmd: dotnet build src/Castle.Sdk -c Release -f netstandard2.0 - - cmd: dotnet build src/Castle.Sdk -c Release -f net461 + - cmd: dotnet build src/Castle.Sdk + - cmd: dotnet build src/Tests -c Debug test_script: - - ps: dotnet test /p:CollectCoverage=true /p:CoverletOutputFormat=opencover /p:Include="[Castle.Sdk*]*" src/Tests/Tests.csproj + - ps: dotnet test --no-build /p:CollectCoverage=true /p:CoverletOutputFormat=opencover /p:Include="[Castle.Sdk*]*" src/Tests/Tests.csproj - ps: .\tools\csmacnz.coveralls.exe --opencover -i src/Tests/coverage.opencover.xml --useRelativePaths --repoToken $env:COVERALLS_REPO_TOKEN --commitId $env:APPVEYOR_REPO_COMMIT --commitBranch $env:APPVEYOR_REPO_BRANCH --commitAuthor $env:APPVEYOR_REPO_COMMIT_AUTHOR --commitEmail $env:APPVEYOR_REPO_COMMIT_AUTHOR_EMAIL --commitMessage $env:APPVEYOR_REPO_COMMIT_MESSAGE --jobId $env:APPVEYOR_JOB_ID diff --git a/src/Castle.Sdk/Context.cs b/src/Castle.Sdk/Context.cs index 7824560..62bf888 100644 --- a/src/Castle.Sdk/Context.cs +++ b/src/Castle.Sdk/Context.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.Specialized; using System.Linq; using Castle.Messages.Requests; @@ -8,7 +9,7 @@ namespace Castle public static class Context { #if NET461 - public static RequestContext FromHttpRequest(System.Web.HttpRequestBase request) + public static RequestContext FromHttpRequest(System.Web.HttpRequestBase request, string[] ipHeaders = null) { var headers = new Dictionary(); foreach (string key in request.Headers.Keys) @@ -16,34 +17,70 @@ public static RequestContext FromHttpRequest(System.Web.HttpRequestBase request) headers.Add(key, request.Headers[key]); } - var clientId = request.Headers.AllKeys.Contains("X-Castle-Client-ID", StringComparer.OrdinalIgnoreCase) - ? request.Headers["X-Castle-Client-ID"] - : request.Cookies["__cid"]?.Value; + var clientId = GetClientIdForFramework(request.Headers, name => request.Cookies[name]?.Value); + + var ip = GetIpForFramework(request.Headers, ipHeaders, () => request.UserHostAddress); return new RequestContext() { ClientId = clientId, Headers = headers, - Ip = request.UserHostAddress + Ip = ip }; } #endif -#if NETSTANDARD2_0 - public static RequestContext FromHttpRequest(Microsoft.AspNetCore.Http.HttpRequest request) + internal static string GetClientIdForFramework(NameValueCollection headers, Func getCookieValue) { - var clientId = request.Headers.TryGetValue("X-Castle-Client-ID", out var headerId) - ? headerId.First() - : request.Cookies["__cid"]; + return headers.AllKeys.Contains("X-Castle-Client-ID", StringComparer.OrdinalIgnoreCase) + ? headers["X-Castle-Client-ID"] + : getCookieValue("__cid") ?? ""; + } + internal static string GetIpForFramework(NameValueCollection headers, string[] ipHeaders, Func getIpFromHttpContext) + { + foreach (var header in ipHeaders ?? new string[] { }) + { + if (headers.AllKeys.Contains(header, StringComparer.OrdinalIgnoreCase)) + return headers[header]; + } + + return getIpFromHttpContext(); + } + +#if NETSTANDARD2_0 + public static RequestContext FromHttpRequest(Microsoft.AspNetCore.Http.HttpRequest request, string[] ipHeaders = null) + { return new RequestContext() { - ClientId = clientId, + ClientId = GetClientIdForCore(request.Headers, request.Cookies), Headers = request.Headers.ToDictionary(x => x.Key, y => y.Value.FirstOrDefault()), - Ip = request.HttpContext.Connection.RemoteIpAddress.ToString(), + Ip = GetIpForCore(request.Headers, ipHeaders, () => request.HttpContext.Connection.RemoteIpAddress.ToString()) }; } + internal static string GetClientIdForCore( + IDictionary headers, + Microsoft.AspNetCore.Http.IRequestCookieCollection cookies) + { + return headers.TryGetValue("X-Castle-Client-ID", out var headerId) + ? headerId.First() + : cookies["__cid"] ?? ""; + } + + internal static string GetIpForCore( + IDictionary headers, + string[] ipHeaders, + Func getIpFromHttpContext) + { + foreach (var header in ipHeaders ?? new string[] {}) + { + if (headers.TryGetValue(header, out var headerValues)) + return headerValues.First(); + } + + return getIpFromHttpContext(); + } #endif } } diff --git a/src/Tests/Messages/When_creating_request_context_for_Core.cs b/src/Tests/Messages/When_creating_request_context_for_Core.cs new file mode 100644 index 0000000..7285bb0 --- /dev/null +++ b/src/Tests/Messages/When_creating_request_context_for_Core.cs @@ -0,0 +1,137 @@ +using System.Collections.Generic; +using AutoFixture.Xunit2; +using Castle; +using FluentAssertions; +using Microsoft.AspNetCore.Http.Internal; +using Microsoft.Extensions.Primitives; +using Xunit; + +namespace Tests.Messages +{ + public class When_creating_request_context_for_Core + { + [Theory, AutoData] + public void Should_get_client_id_from_castle_header_if_present( + string castleHeaderValue, + string cookieValue) + { + var headers = new Dictionary() + { + ["X-Castle-Client-ID"] = castleHeaderValue, + }; + + var cookies = new RequestCookieCollection(new Dictionary() + { + ["__cid"] = cookieValue + }); + + var result = Context.GetClientIdForCore(headers, cookies); + + result.Should().Be(castleHeaderValue); + } + + [Theory, AutoData] + public void Should_get_client_id_from_cookie_if_castle_header_not_present( + string otherHeader, + string otherHeaderValue, + string cookieValue) + { + var headers = new Dictionary() + { + [otherHeader] = otherHeaderValue + }; + + var cookies = new RequestCookieCollection(new Dictionary() + { + ["__cid"] = cookieValue + }); + + var result = Context.GetClientIdForCore(headers, cookies); + + result.Should().Be(cookieValue); + } + + [Theory, AutoData] + public void Should_use_empty_string_if_unable_to_get_client_id( + string otherHeader, + string otherHeaderValue, + string otherCookie, + string otherCookieValue) + { + var headers = new Dictionary() + { + [otherHeader] = otherHeaderValue + }; + + var cookies = new RequestCookieCollection(new Dictionary() + { + [otherCookie] = otherCookieValue + }); + + var result = Context.GetClientIdForCore(headers, cookies); + + result.Should().Be(""); + } + + [Theory, AutoData] + public void Should_get_ip_from_supplied_headers_in_order( + string ipHeader, + string ip, + string secondaryIpHeader, + string secondaryIp, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new Dictionary() + { + [ipHeader] = ip, + [secondaryIpHeader] = secondaryIp, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForCore(headers, new [] { ipHeader, secondaryIpHeader }, () => httpContextIp); + + result.Should().Be(ip); + } + + [Theory, AutoData] + public void Should_get_ip_from_second_header_if_first_is_not_found( + string ipHeader, + string secondaryIpHeader, + string secondaryIp, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new Dictionary() + { + [secondaryIpHeader] = secondaryIp, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForCore(headers, new[] { ipHeader, secondaryIpHeader }, () => httpContextIp); + + result.Should().Be(secondaryIp); + } + + [Theory, AutoData] + public void Should_get_ip_from_httpcontext_if_no_header_supplied( + string ipHeader, + string ip, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new Dictionary() + { + [ipHeader] = ip, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForCore(headers, null, () => httpContextIp); + + result.Should().Be(httpContextIp); + } + } +} diff --git a/src/Tests/Messages/When_creating_request_context_for_Framework.cs b/src/Tests/Messages/When_creating_request_context_for_Framework.cs new file mode 100644 index 0000000..9502a00 --- /dev/null +++ b/src/Tests/Messages/When_creating_request_context_for_Framework.cs @@ -0,0 +1,126 @@ +using System.Collections.Specialized; +using AutoFixture.Xunit2; +using Castle; +using FluentAssertions; +using Xunit; + +namespace Tests.Messages +{ + public class When_creating_request_context_for_Framework + { + [Theory, AutoData] + public void Should_get_client_id_from_castle_header_if_present( + string castleHeaderValue, + string cookieValue) + { + var headers = new NameValueCollection + { + ["X-Castle-Client-ID"] = castleHeaderValue + }; + + string GetCookie(string name) => name == "__cid" ? cookieValue : null; + + var result = Context.GetClientIdForFramework(headers, GetCookie); + + result.Should().Be(castleHeaderValue); + } + + [Theory, AutoData] + public void Should_get_client_id_from_cookie_if_castle_header_not_present( + string otherHeader, + string otherHeaderValue, + string cookieValue) + { + var headers = new NameValueCollection + { + [otherHeader] = otherHeaderValue + }; + + string GetCookie(string name) => name == "__cid" ? cookieValue : null; + + var result = Context.GetClientIdForFramework(headers, GetCookie); + + result.Should().Be(cookieValue); + } + + [Theory, AutoData] + public void Should_use_empty_string_if_unable_to_get_client_id( + string otherHeader, + string otherHeaderValue, + string otherCookie, + string otherCookieValue) + { + var headers = new NameValueCollection + { + [otherHeader] = otherHeaderValue + }; + + string GetCookie(string name) => name == otherCookie ? otherCookieValue : null; + + var result = Context.GetClientIdForFramework(headers, GetCookie); + + result.Should().Be(""); + } + + [Theory, AutoData] + public void Should_get_ip_from_supplied_headers_in_order( + string ipHeader, + string ip, + string secondaryIpHeader, + string secondaryIp, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new NameValueCollection + { + [ipHeader] = ip, + [secondaryIpHeader] = secondaryIp, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForFramework(headers, new [] { ipHeader, secondaryIpHeader }, () => httpContextIp); + + result.Should().Be(ip); + } + + [Theory, AutoData] + public void Should_get_ip_from_second_header_if_first_is_not_found( + string ipHeader, + string secondaryIpHeader, + string secondaryIp, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new NameValueCollection + { + [secondaryIpHeader] = secondaryIp, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForFramework(headers, new[] { ipHeader, secondaryIpHeader }, () => httpContextIp); + + result.Should().Be(secondaryIp); + } + + [Theory, AutoData] + public void Should_get_ip_from_httpcontext_if_no_header_supplied( + string ipHeader, + string ip, + string otherHeader, + string otherHeaderValue, + string httpContextIp) + { + var headers = new NameValueCollection + { + [ipHeader] = ip, + [otherHeader] = otherHeaderValue + }; + + var result = Context.GetIpForFramework(headers, null, () => httpContextIp); + + result.Should().Be(httpContextIp); + } + } +}