diff --git a/src/shared/Atlassian.Bitbucket/BitbucketOAuth2Client.cs b/src/shared/Atlassian.Bitbucket/BitbucketOAuth2Client.cs index 9d4f0043e..22c793239 100644 --- a/src/shared/Atlassian.Bitbucket/BitbucketOAuth2Client.cs +++ b/src/shared/Atlassian.Bitbucket/BitbucketOAuth2Client.cs @@ -10,7 +10,7 @@ namespace Atlassian.Bitbucket { public abstract class BitbucketOAuth2Client : OAuth2Client { - public BitbucketOAuth2Client(HttpClient httpClient, OAuth2ServerEndpoints endpoints, string clientId, Uri redirectUri, string clientSecret, ITrace trace) : base(httpClient, endpoints, clientId, redirectUri, clientSecret, trace, false) + public BitbucketOAuth2Client(HttpClient httpClient, OAuth2ServerEndpoints endpoints, string clientId, Uri redirectUri, string clientSecret, ITrace trace) : base(httpClient, endpoints, clientId, redirectUri, clientSecret, false) { } @@ -27,9 +27,9 @@ public string GetRefreshTokenServiceName(InputArguments input) return uri.AbsoluteUri.TrimEnd('/'); } - public Task GetAuthorizationCodeAsync(IOAuth2WebBrowser browser, CancellationToken ct) + public Task GetAuthorizationCodeAsync(IOAuth2WebBrowser browser, CancellationToken ct) { - return GetAuthorizationCodeAsync(Scopes, browser, ct); + return this.GetAuthorizationCodeAsync(Scopes, browser, ct); } protected override bool TryCreateTokenEndpointResult(string json, out OAuth2TokenResult result) diff --git a/src/shared/Core.Tests/Authentication/OAuth2ClientTests.cs b/src/shared/Core.Tests/Authentication/OAuth2ClientTests.cs index bed64f033..ffd4ca730 100644 --- a/src/shared/Core.Tests/Authentication/OAuth2ClientTests.cs +++ b/src/shared/Core.Tests/Authentication/OAuth2ClientTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -37,11 +38,90 @@ public async Task OAuth2Client_GetAuthorizationCodeAsync() OAuth2Client client = CreateClient(httpHandler, endpoints); - OAuth2AuthorizationCodeResult result = await client.GetAuthorizationCodeAsync(expectedScopes, browser, CancellationToken.None); + OAuth2AuthorizationCodeResult result = await client.GetAuthorizationCodeAsync(expectedScopes, browser, null, CancellationToken.None); Assert.Equal(expectedAuthCode, result.Code); } + [Fact] + public async Task OAuth2Client_GetAuthorizationCodeAsync_ExtraQueryParams() + { + const string expectedAuthCode = "68c39cbd8d"; + + var baseUri = new Uri("https://example.com"); + OAuth2ServerEndpoints endpoints = CreateEndpoints(baseUri); + + var httpHandler = new TestHttpMessageHandler {ThrowOnUnexpectedRequest = true}; + + string[] expectedScopes = {"read", "write", "delete"}; + + var extraParams = new Dictionary + { + ["param1"] = "value1", + ["param2"] = "value2", + ["param3"] = "value3" + }; + + OAuth2Application app = CreateTestApplication(); + + var server = new TestOAuth2Server(endpoints); + server.RegisterApplication(app); + server.Bind(httpHandler); + server.TokenGenerator.AuthCodes.Add(expectedAuthCode); + + server.AuthorizationEndpointInvoked += (_, request) => + { + IDictionary actualParams = request.RequestUri.GetQueryParameters(); + foreach (var expected in extraParams) + { + Assert.True(actualParams.TryGetValue(expected.Key, out string actualValue)); + Assert.Equal(expected.Value, actualValue); + } + }; + + IOAuth2WebBrowser browser = new TestOAuth2WebBrowser(httpHandler); + + OAuth2Client client = CreateClient(httpHandler, endpoints); + + OAuth2AuthorizationCodeResult result = await client.GetAuthorizationCodeAsync(expectedScopes, browser, extraParams, CancellationToken.None); + + Assert.Equal(expectedAuthCode, result.Code); + } + + [Fact] + public async Task OAuth2Client_GetAuthorizationCodeAsync_ExtraQueryParams_OverrideStandardArgs_ThrowsException() + { + const string expectedAuthCode = "68c39cbd8d"; + + var baseUri = new Uri("https://example.com"); + OAuth2ServerEndpoints endpoints = CreateEndpoints(baseUri); + + var httpHandler = new TestHttpMessageHandler {ThrowOnUnexpectedRequest = true}; + + string[] expectedScopes = {"read", "write", "delete"}; + + var extraParams = new Dictionary + { + ["param1"] = "value1", + [OAuth2Constants.ClientIdParameter] = "value2", + ["param3"] = "value3" + }; + + OAuth2Application app = CreateTestApplication(); + + var server = new TestOAuth2Server(endpoints); + server.RegisterApplication(app); + server.Bind(httpHandler); + server.TokenGenerator.AuthCodes.Add(expectedAuthCode); + + IOAuth2WebBrowser browser = new TestOAuth2WebBrowser(httpHandler); + + OAuth2Client client = CreateClient(httpHandler, endpoints); + + await Assert.ThrowsAsync(() => + client.GetAuthorizationCodeAsync(expectedScopes, browser, extraParams, CancellationToken.None)); + } + [Fact] public async Task OAuth2Client_GetDeviceCodeAsync() { @@ -217,7 +297,7 @@ public async Task OAuth2Client_E2E_InteractiveWebFlowAndRefresh() OAuth2Client client = CreateClient(httpHandler, endpoints); OAuth2AuthorizationCodeResult authCodeResult = await client.GetAuthorizationCodeAsync( - expectedScopes, browser, CancellationToken.None); + expectedScopes, browser, null, CancellationToken.None); OAuth2TokenResult result1 = await client.GetTokenByAuthorizationCodeAsync(authCodeResult, CancellationToken.None); diff --git a/src/shared/Core/Authentication/OAuth/OAuth2Client.cs b/src/shared/Core/Authentication/OAuth/OAuth2Client.cs index 9558a1f66..949054282 100644 --- a/src/shared/Core/Authentication/OAuth/OAuth2Client.cs +++ b/src/shared/Core/Authentication/OAuth/OAuth2Client.cs @@ -20,9 +20,15 @@ public interface IOAuth2Client /// /// Scopes to request. /// User agent to use to start the authorization code grant flow. + /// Extra parameters to add to the URL query component. /// Token to cancel the operation. /// Authorization code. - Task GetAuthorizationCodeAsync(IEnumerable scopes, IOAuth2WebBrowser browser, CancellationToken ct); + Task GetAuthorizationCodeAsync( + IEnumerable scopes, + IOAuth2WebBrowser browser, + IDictionary extraQueryParams, + CancellationToken ct + ); /// /// Retrieve a device code grant. @@ -65,19 +71,17 @@ public class OAuth2Client : IOAuth2Client private readonly Uri _redirectUri; private readonly string _clientId; private readonly string _clientSecret; - private readonly ITrace _trace; private readonly bool _addAuthHeader; private IOAuth2CodeGenerator _codeGenerator; - public OAuth2Client(HttpClient httpClient, OAuth2ServerEndpoints endpoints, string clientId, Uri redirectUri = null, string clientSecret = null, ITrace trace = null, bool addAuthHeader = true) + public OAuth2Client(HttpClient httpClient, OAuth2ServerEndpoints endpoints, string clientId, Uri redirectUri = null, string clientSecret = null, bool addAuthHeader = true) { _httpClient = httpClient; _endpoints = endpoints; _clientId = clientId; _redirectUri = redirectUri; _clientSecret = clientSecret; - _trace = trace; _addAuthHeader = addAuthHeader; } @@ -87,21 +91,10 @@ public IOAuth2CodeGenerator CodeGenerator set => _codeGenerator = value; } - protected string ClientId => _clientId; - - protected string ClientSecret => _clientSecret; - - protected ITrace Trace => _trace; - - protected OAuth2ServerEndpoints Endpoints => _endpoints; - - protected HttpClient HttpClient => _httpClient; - - protected Uri RedirectUri => _redirectUri; - #region IOAuth2Client - public async Task GetAuthorizationCodeAsync(IEnumerable scopes, IOAuth2WebBrowser browser, CancellationToken ct) + public async Task GetAuthorizationCodeAsync(IEnumerable scopes, + IOAuth2WebBrowser browser, IDictionary extraQueryParams, CancellationToken ct) { string state = CodeGenerator.CreateNonce(); string codeVerifier = CodeGenerator.CreatePkceCodeVerifier(); @@ -118,6 +111,21 @@ public async Task GetAuthorizationCodeAsync(IEnum [OAuth2Constants.AuthorizationEndpoint.PkceChallengeParameter] = codeChallenge }; + if (extraQueryParams?.Count > 0) + { + foreach (var kvp in extraQueryParams) + { + if (queryParams.ContainsKey(kvp.Key)) + { + throw new ArgumentException( + $"Extra query parameter '{kvp.Key}' would override required standard OAuth parameters.", + nameof(extraQueryParams)); + } + + queryParams[kvp.Key] = kvp.Value; + } + } + Uri redirectUri = null; if (_redirectUri != null) { @@ -389,4 +397,13 @@ protected static bool TryDeserializeJson(string json, out T obj) #endregion } + + public static class OAuth2ClientExtensions + { + public static Task GetAuthorizationCodeAsync( + this IOAuth2Client client, IEnumerable scopes, IOAuth2WebBrowser browser, CancellationToken ct) + { + return client.GetAuthorizationCodeAsync(scopes, browser, null, ct); + } + } } diff --git a/src/shared/Core/GenericHostProvider.cs b/src/shared/Core/GenericHostProvider.cs index 2b05537e1..581c78e7f 100644 --- a/src/shared/Core/GenericHostProvider.cs +++ b/src/shared/Core/GenericHostProvider.cs @@ -125,7 +125,6 @@ private async Task GetOAuthAccessToken(Uri remoteUri, string userNa config.ClientId, config.RedirectUri, config.ClientSecret, - Context.Trace, config.UseAuthHeader); // diff --git a/src/shared/GitHub.Tests/GitHubHostProviderTests.cs b/src/shared/GitHub.Tests/GitHubHostProviderTests.cs index 4b4533030..09c42bb0b 100644 --- a/src/shared/GitHub.Tests/GitHubHostProviderTests.cs +++ b/src/shared/GitHub.Tests/GitHubHostProviderTests.cs @@ -196,7 +196,7 @@ public async Task GitHubHostProvider_GenerateCredentialAsync_Browser_ReturnsCred ghAuthMock.Setup(x => x.GetAuthenticationAsync(expectedTargetUri, null, It.IsAny())) .ReturnsAsync(new AuthenticationPromptResult(AuthenticationModes.Browser)); - ghAuthMock.Setup(x => x.GetOAuthTokenViaBrowserAsync(expectedTargetUri, It.IsAny>())) + ghAuthMock.Setup(x => x.GetOAuthTokenViaBrowserAsync(expectedTargetUri, It.IsAny>(), It.IsAny())) .ReturnsAsync(response); var ghApiMock = new Mock(MockBehavior.Strict); @@ -213,7 +213,56 @@ public async Task GitHubHostProvider_GenerateCredentialAsync_Browser_ReturnsCred ghAuthMock.Verify( x => x.GetOAuthTokenViaBrowserAsync( - expectedTargetUri, expectedOAuthScopes), + expectedTargetUri, expectedOAuthScopes, null), + Times.Once); + } + + [Fact] + public async Task GitHubHostProvider_GenerateCredentialAsync_Browser_LoginHint_IncludesHintAndReturnsCredential() + { + var input = new InputArguments(new Dictionary + { + ["protocol"] = "https", + ["host"] = "github.com", + ["username"] = "john.doe" + }); + + var expectedTargetUri = new Uri("https://github.com/"); + IEnumerable expectedOAuthScopes = new[] + { + GitHubConstants.OAuthScopes.Repo, + GitHubConstants.OAuthScopes.Gist, + GitHubConstants.OAuthScopes.Workflow, + }; + + var expectedUserName = "john.doe"; + var tokenValue = "OAUTH-TOKEN"; + var response = new OAuth2TokenResult(tokenValue, "bearer"); + + var context = new TestCommandContext(); + + var ghAuthMock = new Mock(MockBehavior.Strict); + ghAuthMock.Setup(x => x.GetAuthenticationAsync(expectedTargetUri, expectedUserName, It.IsAny())) + .ReturnsAsync(new AuthenticationPromptResult(AuthenticationModes.Browser)); + + ghAuthMock.Setup(x => x.GetOAuthTokenViaBrowserAsync(expectedTargetUri, It.IsAny>(), It.IsAny())) + .ReturnsAsync(response); + + var ghApiMock = new Mock(MockBehavior.Strict); + ghApiMock.Setup(x => x.GetUserInfoAsync(expectedTargetUri, tokenValue)) + .ReturnsAsync(new GitHubUserInfo{Login = expectedUserName}); + + var provider = new GitHubHostProvider(context, ghApiMock.Object, ghAuthMock.Object); + + ICredential credential = await provider.GenerateCredentialAsync(input); + + Assert.NotNull(credential); + Assert.Equal(expectedUserName, credential.Account); + Assert.Equal(tokenValue, credential.Password); + + ghAuthMock.Verify( + x => x.GetOAuthTokenViaBrowserAsync( + expectedTargetUri, expectedOAuthScopes, expectedUserName), Times.Once); } diff --git a/src/shared/GitHub/GitHubAuthentication.cs b/src/shared/GitHub/GitHubAuthentication.cs index 56a696241..bb6c6593e 100644 --- a/src/shared/GitHub/GitHubAuthentication.cs +++ b/src/shared/GitHub/GitHubAuthentication.cs @@ -16,7 +16,7 @@ public interface IGitHubAuthentication : IDisposable Task GetTwoFactorCodeAsync(Uri targetUri, bool isSms); - Task GetOAuthTokenViaBrowserAsync(Uri targetUri, IEnumerable scopes); + Task GetOAuthTokenViaBrowserAsync(Uri targetUri, IEnumerable scopes, string loginHint); Task GetOAuthTokenViaDeviceCodeAsync(Uri targetUri, IEnumerable scopes); } @@ -251,7 +251,7 @@ public async Task GetTwoFactorCodeAsync(Uri targetUri, bool isSms) } } - public async Task GetOAuthTokenViaBrowserAsync(Uri targetUri, IEnumerable scopes) + public async Task GetOAuthTokenViaBrowserAsync(Uri targetUri, IEnumerable scopes, string loginHint) { ThrowIfUserInteractionDisabled(); @@ -270,11 +270,21 @@ public async Task GetOAuthTokenViaBrowserAsync(Uri targetUri, }; var browser = new OAuth2SystemWebBrowser(Context.Environment, browserOptions); + // If we have a login hint we should pass this to GitHub as an extra query parameter + IDictionary queryParams = null; + if (loginHint != null) + { + queryParams = new Dictionary + { + ["login"] = loginHint + }; + } + // Write message to the terminal (if any is attached) for some feedback that we're waiting for a web response Context.Terminal.WriteLine("info: please complete authentication in your browser..."); OAuth2AuthorizationCodeResult authCodeResult = - await oauthClient.GetAuthorizationCodeAsync(scopes, browser, CancellationToken.None); + await oauthClient.GetAuthorizationCodeAsync(scopes, browser, queryParams, CancellationToken.None); return await oauthClient.GetTokenByAuthorizationCodeAsync(authCodeResult, CancellationToken.None); } diff --git a/src/shared/GitHub/GitHubHostProvider.cs b/src/shared/GitHub/GitHubHostProvider.cs index b06c71471..6ee0e410c 100644 --- a/src/shared/GitHub/GitHubHostProvider.cs +++ b/src/shared/GitHub/GitHubHostProvider.cs @@ -150,10 +150,10 @@ public override async Task GenerateCredentialAsync(InputArguments i return patCredential; case AuthenticationModes.Browser: - return await GenerateOAuthCredentialAsync(remoteUri, useBrowser: true); + return await GenerateOAuthCredentialAsync(remoteUri, loginHint: input.UserName, useBrowser: true); case AuthenticationModes.Device: - return await GenerateOAuthCredentialAsync(remoteUri, useBrowser: false); + return await GenerateOAuthCredentialAsync(remoteUri, loginHint: input.UserName, useBrowser: false); case AuthenticationModes.Pat: // The token returned by the user should be good to use directly as the password for Git @@ -176,10 +176,10 @@ public override async Task GenerateCredentialAsync(InputArguments i } } - private async Task GenerateOAuthCredentialAsync(Uri targetUri, bool useBrowser) + private async Task GenerateOAuthCredentialAsync(Uri targetUri, string loginHint, bool useBrowser) { OAuth2TokenResult result = useBrowser - ? await _gitHubAuth.GetOAuthTokenViaBrowserAsync(targetUri, GitHubOAuthScopes) + ? await _gitHubAuth.GetOAuthTokenViaBrowserAsync(targetUri, GitHubOAuthScopes, loginHint) : await _gitHubAuth.GetOAuthTokenViaDeviceCodeAsync(targetUri, GitHubOAuthScopes); // Resolve the GitHub user handle diff --git a/src/shared/TestInfrastructure/Objects/TestOAuth2Server.cs b/src/shared/TestInfrastructure/Objects/TestOAuth2Server.cs index 242ba533c..09ab9fcc6 100644 --- a/src/shared/TestInfrastructure/Objects/TestOAuth2Server.cs +++ b/src/shared/TestInfrastructure/Objects/TestOAuth2Server.cs @@ -27,6 +27,10 @@ public TestOAuth2Server(OAuth2ServerEndpoints endpoints) public TestOAuth2ServerTokenGenerator TokenGenerator = new TestOAuth2ServerTokenGenerator(); + public event EventHandler AuthorizationEndpointInvoked; + public event EventHandler DeviceAuthorizationEndpointInvoked; + public event EventHandler TokenEndpointInvoked; + public void RegisterApplication(OAuth2Application application) { _apps[application.Id] = application; @@ -52,6 +56,8 @@ public void SignInDeviceWithUserCode(string userCode) private Task OnAuthorizationEndpointAsync(HttpRequestMessage request) { + AuthorizationEndpointInvoked?.Invoke(this, request); + IDictionary reqQuery = request.RequestUri.GetQueryParameters(); // The only support response type so far is 'code' @@ -128,6 +134,8 @@ private Task OnAuthorizationEndpointAsync(HttpRequestMessag private async Task OnDeviceAuthorizationEndpointAsync(HttpRequestMessage request) { + DeviceAuthorizationEndpointInvoked?.Invoke(this, request); + IDictionary formData = await request.Content.ReadAsFormContentAsync(); // The client/app ID must be specified and must match a known application @@ -162,6 +170,8 @@ private async Task OnDeviceAuthorizationEndpointAsync(HttpR private async Task OnTokenEndpointAsync(HttpRequestMessage request) { + TokenEndpointInvoked?.Invoke(this, request); + IDictionary formData = await request.Content.ReadAsFormContentAsync(); if (!formData.TryGetValue(OAuth2Constants.TokenEndpoint.GrantTypeParameter, out string grantType))