/
AuthorizationHandler.cs
189 lines (161 loc) · 6.36 KB
/
AuthorizationHandler.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license. See LICENSE file in the project root for full license information.
--*/
using Azure.Core;
using Azure.Identity;
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Devices.HardwareDevCenterManager.Utility;
/// <summary>
/// Signals that every send attempt made by the authorization handler ended without
/// producing an HTTP response.
/// </summary>
internal class HttpRetriesExhaustedException : Exception
{
    /// <summary>
    /// Creates the exception with the supplied description.
    /// </summary>
    /// <param name="msg">Human-readable explanation, forwarded to <see cref="Exception.Message"/>.</param>
    public HttpRetriesExhaustedException(string msg)
        : base(msg)
    {
    }
}
/// <summary>
/// Delegating handler that attaches an OAuth bearer token to every request sent to Microsoft
/// Hardware Dev Center (HDC) and retries requests that fail transiently.
/// </summary>
internal class AuthorizationHandler : DelegatingHandler
{
    // Maximum number of send attempts made for a single request.
    private const int MaxRetries = 10;

    // OAuth scope requested when acquiring an access token.
    private const string DevCenterScope = "https://manage.devcenter.microsoft.com/.default";

    // Pause before the next attempt after a transient failure (exception or HTTP 500).
    private static readonly TimeSpan RetryDelay = TimeSpan.FromSeconds(2);

    private string _accessToken;
    private readonly AuthorizationHandlerCredentials _authCredentials;
    private readonly TimeSpan _httpTimeout;

    /// <summary>
    /// Handles OAuth Tokens for HTTP request to Microsoft Hardware Dev Center
    /// </summary>
    /// <param name="credentials">The set of credentials to use for the token acquisition</param>
    /// <param name="httpTimeoutSeconds">
    /// Timeout, in seconds, applied to access-token acquisition. 0 means no additional timeout.
    /// </param>
    public AuthorizationHandler(AuthorizationHandlerCredentials credentials, uint httpTimeoutSeconds)
        : base(new HttpClientHandler())
    {
        _accessToken = null;
        _authCredentials = credentials;
        _httpTimeout = TimeSpan.FromSeconds(httpTimeoutSeconds);
    }

    /// <summary>
    /// Inserts a Bearer token into the HTTP request and retries failed attempts, since HDC
    /// sometimes fails transiently.
    /// </summary>
    /// <remarks>
    /// Retried conditions:
    /// <list type="bullet">
    /// <item><description><see cref="HttpRequestException"/> or <see cref="SocketException"/>: retried after a delay.</description></item>
    /// <item><description><see cref="TaskCanceledException"/> not caused by <paramref name="cancellationToken"/>: retried after a delay.</description></item>
    /// <item><description>401: the token is refreshed and the request is retried immediately.</description></item>
    /// <item><description>500: retried after a delay.</description></item>
    /// </list>
    /// At most <see cref="MaxRetries"/> attempts are made. When attempts run out, the most recent
    /// response received, if any, is returned.
    /// </remarks>
    /// <param name="request">HTTP Request to send</param>
    /// <param name="cancellationToken">CancellationToken in case the request is cancelled</param>
    /// <returns>Returns the HttpResponseMessage from the request</returns>
    /// <exception cref="HttpRetriesExhaustedException">No attempt produced a response.</exception>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        // Waits before the next attempt. Skipped after the final attempt because no retry follows.
        async Task WaitBeforeRetryAsync(int attempt)
        {
            if (attempt < MaxRetries)
            {
                await Task.Delay(RetryDelay, cancellationToken).ConfigureAwait(false);
            }
        }

        // Most recent response received. It is what gets returned once attempts stop.
        HttpResponseMessage response = null;

        // If there is no access token for HDC yet, get one before the first attempt.
        if (_accessToken == null)
        {
            await ObtainAccessToken(cancellationToken).ConfigureAwait(false);
        }

        for (int attempt = 1; attempt <= MaxRetries; attempt++)
        {
            // A request message can only be sent once, so each attempt sends a fresh copy.
            HttpRequestMessage clonedRequest = await CloneHttpRequestMessageAsync(request).ConfigureAwait(false);

            // Assign rather than Headers.Add. Authorization is single-valued, so Add would throw
            // if the caller's request (and therefore the clone) already carried one.
            clonedRequest.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _accessToken);

            HttpResponseMessage attemptResponse;
            try
            {
                attemptResponse = await base.SendAsync(clonedRequest, cancellationToken).ConfigureAwait(false);
            }
            catch (HttpRequestException)
            {
                // HDC request error, wait a bit and try again.
                await WaitBeforeRetryAsync(attempt).ConfigureAwait(false);
                continue;
            }
            catch (SocketException)
            {
                // HDC connection failure, wait a bit and try again.
                await WaitBeforeRetryAsync(attempt).ConfigureAwait(false);
                continue;
            }
            catch (TaskCanceledException) when (!cancellationToken.IsCancellationRequested)
            {
                // Cancelled by something other than the caller (e.g. a timeout): treat as transient.
                // A caller-requested cancellation is not caught here and propagates unchanged.
                await WaitBeforeRetryAsync(attempt).ConfigureAwait(false);
                continue;
            }

            // Keep only the newest response. Dispose the one it supersedes so its connection is released.
            response?.Dispose();
            response = attemptResponse;

            if (response.StatusCode == HttpStatusCode.Unauthorized)
            {
                // The token likely expired, so get a new one and retry.
                await ObtainAccessToken(cancellationToken).ConfigureAwait(false);
                continue;
            }

            if (response.StatusCode == HttpStatusCode.InternalServerError)
            {
                // HDC sometimes returns 500 errors, so wait a bit and retry instead of failing the call.
                await WaitBeforeRetryAsync(attempt).ConfigureAwait(false);
                continue;
            }

            break;
        }

        if (response == null)
        {
            throw new HttpRetriesExhaustedException("AuthorizationHandler: NULL response, unable to communicate with Hardware Dev Center");
        }

        return response;
    }

    /// <summary>
    /// Acquires an access token for the HDC management scope using the client-secret credentials,
    /// and stores it for subsequent requests.
    /// </summary>
    /// <param name="cancellationToken">Cancels the token request.</param>
    /// <returns>True if a non-empty token was obtained and stored; otherwise false.</returns>
    private async Task<bool> ObtainAccessToken(CancellationToken cancellationToken = default)
    {
        // A new credential is built on every call, as before.
        // NOTE(review): reusing one instance might hand back a cached token that HDC already
        // rejected with 401. Confirm before caching the credential.
        ClientSecretCredential credential = new(_authCredentials.TenantId, _authCredentials.ClientId, _authCredentials.Key);

        // Bound the token request by the configured timeout. Zero, or a value too large for
        // CancelAfter, means no extra timeout.
        using CancellationTokenSource timeoutSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        if (_httpTimeout > TimeSpan.Zero && _httpTimeout.TotalMilliseconds <= int.MaxValue)
        {
            timeoutSource.CancelAfter(_httpTimeout);
        }

        AccessToken token = await credential
            .GetTokenAsync(new TokenRequestContext(scopes: new[] { DevCenterScope }), timeoutSource.Token)
            .ConfigureAwait(false);

        if (string.IsNullOrEmpty(token.Token))
        {
            return false;
        }

        _accessToken = token.Token;
        return true;
    }

    /// <summary>
    /// Creates an unsent copy of <paramref name="request"/>. The copy carries the method, URI,
    /// version, properties and headers, plus a buffered copy of the content and its headers.
    /// </summary>
    /// <remarks>
    /// Adapted from https://stackoverflow.com/questions/21467018/how-to-forward-an-httprequestmessage-to-another-server
    /// </remarks>
    /// <param name="request">The request to copy.</param>
    /// <returns>A new <see cref="HttpRequestMessage"/> that can be sent independently of the original.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    public static async Task<HttpRequestMessage> CloneHttpRequestMessageAsync(HttpRequestMessage request)
    {
        if (request == null)
        {
            throw new ArgumentNullException(nameof(request));
        }

        HttpRequestMessage clone = new(request.Method, request.RequestUri);

        if (request.Content != null)
        {
            // Copy the body into a MemoryStream. The stream is owned by the clone's StreamContent,
            // so it is disposed together with that content.
            MemoryStream ms = new();
            await request.Content.CopyToAsync(ms).ConfigureAwait(false);
            ms.Position = 0;
            clone.Content = new StreamContent(ms);

            // Copy the content headers (Content-Type, etc.).
            foreach (KeyValuePair<string, IEnumerable<string>> h in request.Content.Headers)
            {
                clone.Content.Headers.Add(h.Key, h.Value);
            }
        }

        clone.Version = request.Version;

        // NOTE(review): Properties is marked obsolete on .NET 5+ in favour of Options.
        // It is kept here for compatibility with older target frameworks.
        foreach (KeyValuePair<string, object> prop in request.Properties)
        {
            clone.Properties.Add(prop);
        }

        foreach (KeyValuePair<string, IEnumerable<string>> header in request.Headers)
        {
            clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
        }

        return clone;
    }
}