/
HostMatcherPolicy.cs
474 lines (407 loc) · 15 KB
/
HostMatcherPolicy.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Globalization;
using System.Linq;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Routing.Matching;
/// <summary>
/// A <see cref="MatcherPolicy"/> that implements filtering and selection by
/// the host header of a request.
/// </summary>
/// <remarks>
/// Host metadata entries have the form <c>host</c>, <c>host:port</c>, <c>*.suffix</c>,
/// <c>*.suffix:port</c>, <c>*:port</c> or <c>host:*</c>. A host of <c>*</c> matches any host;
/// a <c>*.</c> prefix matches any request host ending with the remainder (including the
/// leading dot). A port of <c>*</c>, or no port at all, matches any port. Host comparisons
/// are case-insensitive.
/// </remarks>
public sealed class HostMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy, IEndpointSelectorPolicy
{
    // Host (and port) value that matches anything.
    private const string WildcardHost = "*";

    // Prefix that marks a suffix-wildcard host, e.g. "*.contoso.com".
    private const string WildcardPrefix = "*.";

    // Run after HTTP methods, but before 'default'.
    /// <inheritdoc />
    public override int Order { get; } = -100;

    /// <inheritdoc />
    public IComparer<Endpoint> Comparer { get; } = new HostMetadataEndpointComparer();

    bool INodeBuilderPolicy.AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // Nodes containing dynamic endpoints are left to the IEndpointSelectorPolicy path,
        // since no assumptions can be made about them when building the DFA.
        return !ContainsDynamicEndpoints(endpoints) && AppliesToEndpointsCore(endpoints);
    }

    bool IEndpointSelectorPolicy.AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // When the node contains dynamic endpoints we can't make any assumptions,
        // so candidates are filtered per request by ApplyAsync instead.
        var applies = ContainsDynamicEndpoints(endpoints);
        if (applies)
        {
            // Run for the side-effect of validating metadata (CreateEdgeKey throws on malformed hosts).
            AppliesToEndpointsCore(endpoints);
        }

        return applies;
    }

    /// <summary>
    /// Returns <c>true</c> when some endpoint has host metadata that restricts the host or port.
    /// </summary>
    /// <remarks>
    /// Host entries are parsed with <see cref="CreateEdgeKey(string)"/>, which throws
    /// <see cref="InvalidOperationException"/> for malformed entries. Because evaluation stops at
    /// the first restricting entry, entries after that point are not validated here.
    /// </remarks>
    private static bool AppliesToEndpointsCore(IReadOnlyList<Endpoint> endpoints)
    {
        return endpoints.Any(e =>
        {
            var hosts = e.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
            if (hosts == null || hosts.Count == 0)
            {
                return false;
            }

            foreach (var host in hosts)
            {
                // Don't run policy on endpoints that match everything
                var key = CreateEdgeKey(host);
                if (!key.MatchesAll)
                {
                    return true;
                }
            }

            return false;
        });
    }

    /// <inheritdoc />
    public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
    {
        ArgumentNullException.ThrowIfNull(httpContext);
        ArgumentNullException.ThrowIfNull(candidates);

        // The request host/port is the same for every candidate, and reading it from
        // HostString can allocate, so resolve it once up front.
        var (requestHost, requestPort) = GetHostAndPort(httpContext);

        for (var i = 0; i < candidates.Count; i++)
        {
            if (!candidates.IsValidCandidate(i))
            {
                continue;
            }

            var hosts = candidates[i].Endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
            if (hosts == null || hosts.Count == 0)
            {
                // Can match any host.
                continue;
            }

            var matched = false;
            for (var j = 0; j < hosts.Count; j++)
            {
                if (HostEntryMatches(hosts[j], requestHost, requestPort))
                {
                    matched = true;
                    break;
                }
            }

            if (!matched)
            {
                candidates.SetValidity(i, false);
            }
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Determines whether one host metadata entry (e.g. <c>*.contoso.com:8080</c>) accepts
    /// the request's host and port.
    /// </summary>
    /// <param name="hostEntry">The raw metadata entry; <c>null</c> accepts everything.</param>
    /// <param name="requestHost">The request host.</param>
    /// <param name="requestPort">The request port, or <c>null</c> when it is unknown.</param>
    private static bool HostEntryMatches(string? hostEntry, string requestHost, int? requestPort)
    {
        if (hostEntry is null)
        {
            // A null entry can match any host and any port.
            return true;
        }

        var host = hostEntry.AsSpan();
        var port = ReadOnlySpan<char>.Empty;

        // Split into host and port at the first ':'.
        var pivot = host.IndexOf(':');
        if (pivot >= 0)
        {
            port = host.Slice(pivot + 1);
            host = host.Slice(0, pivot);
        }

        var hostMatches =
            // Can match any host.
            MemoryExtensions.Equals(host, WildcardHost, StringComparison.OrdinalIgnoreCase) ||
            // Suffix wildcard. Note that we only slice off the `*`. We want to match the leading `.` also.
            (host.StartsWith(WildcardPrefix, StringComparison.Ordinal) &&
                MemoryExtensions.EndsWith(requestHost, host.Slice(WildcardHost.Length), StringComparison.OrdinalIgnoreCase)) ||
            // Exact match.
            MemoryExtensions.Equals(requestHost, host, StringComparison.OrdinalIgnoreCase);

        if (!hostMatches)
        {
            return false;
        }

        if (port.IsEmpty || MemoryExtensions.Equals(port, WildcardHost, StringComparison.OrdinalIgnoreCase))
        {
            // No port, or a wildcard port: any port is allowed.
            return true;
        }

        // A specific port must parse and equal the request port (an unknown request port never matches).
        return TryParsePort(port, out var parsedPort) && parsedPort == requestPort;
    }

    // Port numbers are machine-readable data, so they are parsed culture-invariantly.
    private static bool TryParsePort(ReadOnlySpan<char> value, out int port)
    {
        return int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out port);
    }

    /// <summary>
    /// Parses a host metadata entry into an <see cref="EdgeKey"/>.
    /// </summary>
    /// <param name="host">
    /// An entry such as <c>www.contoso.com</c>, <c>*.contoso.com:8080</c>, <c>*:80</c> or
    /// <c>localhost:*</c>; <c>null</c> yields the match-everything key.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// The entry has an empty host part, more than one <c>:</c>, or a port that is neither an
    /// integer nor <c>*</c>.
    /// </exception>
    private static EdgeKey CreateEdgeKey(string? host)
    {
        if (host == null)
        {
            return EdgeKey.WildcardEdgeKey;
        }

        var hostParts = host.Split(':');
        if (hostParts.Length == 1 && !string.IsNullOrEmpty(hostParts[0]))
        {
            return new EdgeKey(hostParts[0], null);
        }

        if (hostParts.Length == 2 && !string.IsNullOrEmpty(hostParts[0]))
        {
            if (TryParsePort(hostParts[1], out var port))
            {
                return new EdgeKey(hostParts[0], port);
            }

            if (string.Equals(hostParts[1], WildcardHost, StringComparison.Ordinal))
            {
                return new EdgeKey(hostParts[0], null);
            }
        }

        throw new InvalidOperationException($"Could not parse host: {host}");
    }

    /// <inheritdoc />
    public IReadOnlyList<PolicyNodeEdge> GetEdges(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // The algorithm here is designed to preserve the order of the endpoints
        // while also being relatively simple. Preserving order is important.
        //
        // Parse each endpoint's host metadata once; an empty array means the endpoint
        // has no host restriction.
        var endpointKeys = new EdgeKey[endpoints.Count][];
        for (var i = 0; i < endpoints.Count; i++)
        {
            endpointKeys[i] = endpoints[i].Metadata.GetMetadata<IHostMetadata>()?.Hosts.Select(CreateEdgeKey).ToArray()
                ?? Array.Empty<EdgeKey>();
        }

        // First pass: build the set of edge keys present at this node. No endpoints are
        // added yet, so that every list later receives endpoints in their original order.
        var edges = new Dictionary<EdgeKey, List<Endpoint>>();
        for (var i = 0; i < endpointKeys.Length; i++)
        {
            var keys = endpointKeys[i];
            if (keys.Length == 0)
            {
                // Endpoints without host metadata contribute the match-everything edge.
                AddEdgeIfMissing(edges, EdgeKey.WildcardEdgeKey);
                continue;
            }

            for (var j = 0; j < keys.Length; j++)
            {
                AddEdgeIfMissing(edges, keys[j]);
            }
        }

        // Second pass: add each endpoint, in order, to every edge it belongs on.
        for (var i = 0; i < endpoints.Count; i++)
        {
            var endpoint = endpoints[i];
            var keys = endpointKeys[i];
            foreach (var (edgeKey, edgeEndpoints) in edges)
            {
                // An endpoint with no host restriction belongs on every edge.
                if (keys.Length == 0 || EdgeAcceptsAny(edgeKey, keys))
                {
                    edgeEndpoints.Add(endpoint);
                }
            }
        }

        return edges
            .Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value))
            .ToArray();
    }

    // Creates an empty endpoint list for 'key' unless one already exists.
    private static void AddEdgeIfMissing(Dictionary<EdgeKey, List<Endpoint>> edges, EdgeKey key)
    {
        if (!edges.ContainsKey(key))
        {
            edges.Add(key, new List<Endpoint>());
        }
    }

    /// <summary>
    /// Returns <c>true</c> when an endpoint carrying any of <paramref name="endpointKeys"/> belongs
    /// on the edge <paramref name="edgeKey"/>: either a key equals the edge key, or both are
    /// suffix wildcards with the same port and the endpoint's wildcard host ends with the edge's suffix.
    /// </summary>
    private static bool EdgeAcceptsAny(EdgeKey edgeKey, EdgeKey[] endpointKeys)
    {
        for (var j = 0; j < endpointKeys.Length; j++)
        {
            var endpointKey = endpointKeys[j];
            if (edgeKey.Equals(endpointKey))
            {
                return true;
            }

            // NOTE(review): this puts an endpoint with a narrower wildcard (e.g. *.sub.contoso.com)
            // on a broader wildcard edge (*.contoso.com), but not the reverse — confirm intended.
            if (edgeKey.HasHostWildcard && endpointKey.HasHostWildcard &&
                edgeKey.Port == endpointKey.Port && edgeKey.MatchHost(endpointKey.Host))
            {
                return true;
            }
        }

        return false;
    }

    /// <inheritdoc />
    public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList<PolicyJumpTableEdge> edges)
    {
        ArgumentNullException.ThrowIfNull(edges);

        // Since our 'edges' can have wildcards, we sort them by how wildcard-ey they are
        // and then test them in linear order. OrderBy is a stable sort, so edges with the
        // same score keep their incoming order.
        var ordered = edges
            .Select(e => (host: (EdgeKey)e.State, destination: e.Destination))
            .OrderBy(e => GetScore(e.host))
            .ToArray();

        return new HostPolicyJumpTable(exitDestination, ordered);
    }

    /// <summary>
    /// Ranks an edge key by specificity. Higher score == lower priority.
    /// </summary>
    private static int GetScore(in EdgeKey key)
    {
        // (restricts host, host is a suffix wildcard, restricts port)
        return (key.MatchesHost, key.HasHostWildcard, key.MatchesPort) switch
        {
            (true, false, true) => 1,  // Has host AND port, e.g. www.consoto.com:8080
            (true, false, false) => 2, // Has host, e.g. www.consoto.com
            (true, true, true) => 3,   // Has wildcard host AND port, e.g. *.consoto.com:8080
            (true, true, false) => 4,  // Has wildcard host, e.g. *.consoto.com
            (false, _, true) => 5,     // Has port, e.g. *:8080
            _ => 6,                    // Has neither, e.g. *:* (or no metadata)
        };
    }

    /// <summary>
    /// Reads the request host and port. When the Host header has no explicit port, 443 is
    /// assumed for <c>https</c> and 80 for <c>http</c>; for other schemes the port is <c>null</c>.
    /// </summary>
    private static (string host, int? port) GetHostAndPort(HttpContext httpContext)
    {
        // HostString can allocate when accessing the host or port, so read each once.
        var hostString = httpContext.Request.Host;
        var host = hostString.Host;
        var port = hostString.Port;
        if (port != null)
        {
            return (host, port);
        }

        var scheme = httpContext.Request.Scheme;
        if (string.Equals("https", scheme, StringComparison.OrdinalIgnoreCase))
        {
            return (host, 443);
        }

        if (string.Equals("http", scheme, StringComparison.OrdinalIgnoreCase))
        {
            return (host, 80);
        }

        return (host, null);
    }

    /// <summary>
    /// Compares endpoints by their <see cref="IHostMetadata"/>, treating metadata with an empty
    /// <see cref="IHostMetadata.Hosts"/> list as if no metadata were present.
    /// </summary>
    private sealed class HostMetadataEndpointComparer : EndpointMetadataComparer<IHostMetadata>
    {
        protected override int CompareMetadata(IHostMetadata? x, IHostMetadata? y)
        {
            // Ignore the metadata if it has an empty list of hosts.
            return base.CompareMetadata(
                x?.Hosts.Count > 0 ? x : null,
                y?.Hosts.Count > 0 ? y : null);
        }
    }

    /// <summary>
    /// Tests edges in the order given (already sorted by <see cref="BuildJumpTable"/>) and returns
    /// the destination of the first edge whose key accepts the request host and port, otherwise
    /// the exit destination.
    /// </summary>
    private sealed class HostPolicyJumpTable : PolicyJumpTable
    {
        private readonly (EdgeKey host, int destination)[] _destinations;
        private readonly int _exitDestination;

        public HostPolicyJumpTable(int exitDestination, (EdgeKey host, int destination)[] destinations)
        {
            _exitDestination = exitDestination;
            _destinations = destinations;
        }

        public override int GetDestination(HttpContext httpContext)
        {
            // HostString can allocate when accessing the host or port
            // Store host and port locally and reuse
            var (host, port) = GetHostAndPort(httpContext);

            var destinations = _destinations;
            for (var i = 0; i < destinations.Length; i++)
            {
                var destination = destinations[i];
                if ((!destination.host.MatchesPort || destination.host.Port == port) &&
                    destination.host.MatchHost(host))
                {
                    return destination.destination;
                }
            }

            return _exitDestination;
        }
    }

    /// <summary>
    /// A parsed host metadata entry, used as the DFA edge state and as the jump-table key.
    /// <see cref="Host"/> is <c>*</c> when any host is allowed and <see cref="Port"/> is
    /// <c>null</c> when any port is allowed.
    /// </summary>
    private readonly struct EdgeKey : IEquatable<EdgeKey>, IComparable<EdgeKey>, IComparable
    {
        internal static readonly EdgeKey WildcardEdgeKey = new EdgeKey(null, null);

        public readonly int? Port;
        public readonly string Host;

        // For "*.suffix" hosts, the ".suffix" a request host must end with; otherwise null.
        private readonly string? _wildcardEndsWith;

        public EdgeKey(string? host, int? port)
        {
            Host = host ?? WildcardHost;
            Port = port;

            HasHostWildcard = Host.StartsWith(WildcardPrefix, StringComparison.Ordinal);
            _wildcardEndsWith = HasHostWildcard ? Host.Substring(1) : null;
        }

        /// <summary>True for suffix-wildcard hosts such as <c>*.contoso.com</c>.</summary>
        public bool HasHostWildcard { get; }

        /// <summary>True when the key constrains the host (<see cref="Host"/> is not <c>*</c>).</summary>
        public bool MatchesHost => !string.Equals(Host, WildcardHost, StringComparison.Ordinal);

        /// <summary>True when the key constrains the port.</summary>
        public bool MatchesPort => Port != null;

        /// <summary>True when the key constrains neither host nor port.</summary>
        public bool MatchesAll => !MatchesHost && !MatchesPort;

        public int CompareTo(EdgeKey other)
        {
            // Ordinal, so ordering agrees with the ordinal Equals/GetHashCode below
            // (Comparer<string>.Default would be culture-sensitive).
            var result = string.CompareOrdinal(Host, other.Host);
            if (result != 0)
            {
                return result;
            }

            return Nullable.Compare(Port, other.Port);
        }

        public int CompareTo(object? obj)
        {
            return CompareTo((EdgeKey)obj!);
        }

        public bool Equals(EdgeKey other)
        {
            return string.Equals(Host, other.Host, StringComparison.Ordinal) && Port == other.Port;
        }

        /// <summary>
        /// Case-insensitively tests <paramref name="host"/> against this key's host: suffix match
        /// for wildcards, exact match otherwise, and always <c>true</c> when the host is <c>*</c>.
        /// </summary>
        public bool MatchHost(string host)
        {
            if (!MatchesHost)
            {
                return true;
            }

            return HasHostWildcard
                ? host.EndsWith(_wildcardEndsWith!, StringComparison.OrdinalIgnoreCase)
                : string.Equals(host, Host, StringComparison.OrdinalIgnoreCase);
        }

        public override int GetHashCode()
        {
            // Host can only be null for default(EdgeKey).
            return (Host?.GetHashCode() ?? 0) ^ (Port?.GetHashCode() ?? 0);
        }

        public override bool Equals(object? obj)
        {
            return obj is EdgeKey key && Equals(key);
        }

        public override string ToString()
        {
            return $"{Host}:{Port?.ToString(CultureInfo.InvariantCulture) ?? WildcardHost}";
        }
    }
}