Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.Bin
~Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.SocketTransportFactory(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions!>! options, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void
static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder!
static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder, System.Action<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions!>! configureOptions) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder!
static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateDefaultBoundListenSocket(System.Net.EndPoint! endpoint) -> System.Net.Sockets.Socket!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.get -> System.Func<System.Net.EndPoint!, System.Net.Sockets.Socket!>!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Buffers;
using System.ComponentModel;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net;
Expand All @@ -23,7 +24,6 @@ internal sealed class SocketConnectionListener : IConnectionListener
private Socket? _listenSocket;
private int _settingsIndex;
private readonly SocketTransportOptions _options;
private SafeSocketHandle? _socketHandle;

public EndPoint EndPoint { get; private set; }

Expand Down Expand Up @@ -92,43 +92,13 @@ internal void Bind()
}

Socket listenSocket;

switch (EndPoint)
try
{
case FileHandleEndPoint fileHandle:
_socketHandle = new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true);
listenSocket = new Socket(_socketHandle);
break;
case UnixDomainSocketEndPoint unix:
listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified);
BindSocket();
break;
case IPEndPoint ip:
listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

// Kestrel expects IPv6Any to bind to both IPv6 and IPv4
if (ip.Address == IPAddress.IPv6Any)
{
listenSocket.DualMode = true;
}
BindSocket();
break;
default:
listenSocket = new Socket(EndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
BindSocket();
break;
listenSocket = _options.CreateBoundListenSocket(EndPoint);
}

void BindSocket()
catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
{
try
{
listenSocket.Bind(EndPoint);
}
catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
{
throw new AddressInUseException(e.Message, e);
}
throw new AddressInUseException(e.Message, e);
}

Debug.Assert(listenSocket.LocalEndPoint != null);
Expand Down Expand Up @@ -193,17 +163,13 @@ void BindSocket()
public ValueTask UnbindAsync(CancellationToken cancellationToken = default)
{
_listenSocket?.Dispose();

_socketHandle?.Dispose();
return default;
}

public ValueTask DisposeAsync()
{
_listenSocket?.Dispose();

_socketHandle?.Dispose();

// Dispose the memory pool
_memoryPool.Dispose();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

using System;
using System.Buffers;
using System.Net;
using System.Net.Sockets;
using Microsoft.AspNetCore.Connections;

namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
{
Expand Down Expand Up @@ -65,6 +68,78 @@ public class SocketTransportOptions
/// </remarks>
public bool UnsafePreferInlineScheduling { get; set; }

/// <summary>
/// A function used to create a new <see cref="Socket"/> to listen with. If
/// not set, <see cref="CreateDefaultBoundListenSocket" /> is used.
/// </summary>
/// <remarks>
/// Implementors are expected to call <see cref="Socket.Bind"/> on the
/// <see cref="Socket"/>. Please note that <see cref="CreateDefaultBoundListenSocket"/>
/// calls <see cref="Socket.Bind"/> as part of its implementation, so implementors
/// using this method do not need to call it again.
/// </remarks>
public Func<EndPoint, Socket> CreateBoundListenSocket { get; set; } = CreateDefaultBoundListenSocket;

/// <summary>
/// Creates a default instance of <see cref="Socket"/> for the given <see cref="EndPoint"/>
/// that can be used by a connection listener to listen for inbound requests. <see cref="Socket.Bind"/>
/// is called by this method.
/// </summary>
/// <param name="endpoint">
/// An <see cref="EndPoint"/>.
/// </param>
/// <returns>
/// A <see cref="Socket"/> instance.
/// </returns>
public static Socket CreateDefaultBoundListenSocket(EndPoint endpoint)
{
Socket listenSocket;
switch (endpoint)
{
case FileHandleEndPoint fileHandle:
// We're passing "ownsHandle: true" here even though we don't necessarily
// own the handle because Socket.Dispose will clean-up everything safely.
// If the handle was already closed or disposed then the socket will
// be torn down gracefully, and if the caller never cleans up their handle
// then we'll do it for them.
//
// If we don't do this then we run the risk of Kestrel hanging because the
// the underlying socket is never closed and the transport manager can hang
// when it attempts to stop.
listenSocket = new Socket(
new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true)
);
break;
case UnixDomainSocketEndPoint unix:
listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified);
break;
case IPEndPoint ip:
listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

// Kestrel expects IPv6Any to bind to both IPv6 and IPv4
if (ip.Address == IPAddress.IPv6Any)
{
listenSocket.DualMode = true;
}

break;
default:
listenSocket = new Socket(endpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
break;
}

// we only call Bind on sockets that were _not_ created
// using a file handle; the handle is already bound
// to an underlying socket so doing it again causes the
// underlying PAL call to throw
if (!(endpoint is FileHandleEndPoint))
{
listenSocket.Bind(endpoint);
}

return listenSocket;
}

internal Func<MemoryPool<byte>> MemoryPoolFactory { get; set; } = System.Buffers.PinnedBlockMemoryPoolFactory.Create;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Hosting;
using Xunit;

namespace Sockets.BindTests
{
public class SocketTransportOptionsTests : LoggedTestBase
{
[Theory]
[MemberData(nameof(GetEndpoints))]
public async Task SocketTransportCallsCreateBoundListenSocket(EndPoint endpointToTest)
{
var wasCalled = false;

Socket CreateListenSocket(EndPoint endpoint)
{
wasCalled = true;
return SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint);
}

using var host = CreateWebHost(
endpointToTest,
options =>
{
options.CreateBoundListenSocket = CreateListenSocket;
}
);

await host.StartAsync();
Assert.True(wasCalled, $"Expected {nameof(SocketTransportOptions.CreateBoundListenSocket)} to be called.");
await host.StopAsync();
}

[Theory]
[MemberData(nameof(GetEndpoints))]
public void CreateDefaultBoundListenSocket_BindsForAllEndPoints(EndPoint endpoint)
{
using var listenSocket = SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint);
Assert.NotNull(listenSocket.LocalEndPoint);
}

// static to ensure that the underlying handle doesn't get disposed
// when a local reference is GCed by the iterator in GetEndPoints
private static Socket _fileHandleSocket;

public static IEnumerable<object[]> GetEndpoints()
{
// IPv4
yield return new object[] {new IPEndPoint(IPAddress.Loopback, 0)};
// IPv6
yield return new object[] {new IPEndPoint(IPAddress.IPv6Loopback, 0)};
// Unix sockets
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
yield return new object[]
{
new UnixDomainSocketEndPoint($"/tmp/{DateTime.UtcNow:yyyyMMddTHHmmss.fff}.sock")
};
}

// file handle
// slightly messy but allows us to create a FileHandleEndPoint
// from the underlying OS handle used by the socket
_fileHandleSocket = new(
AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp
);
_fileHandleSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
yield return new object[]
{
new FileHandleEndPoint((ulong) _fileHandleSocket.Handle, FileHandleType.Auto)
};

// TODO: other endpoint types?
}

private IHost CreateWebHost(EndPoint endpoint, Action<SocketTransportOptions> configureSocketOptions) =>
TransportSelector.GetHostBuilder()
.ConfigureWebHost(
webHostBuilder =>
{
webHostBuilder
.UseSockets(configureSocketOptions)
.UseKestrel(options => options.Listen(endpoint))
.Configure(
app => app.Run(ctx => ctx.Response.WriteAsync("Hello World"))
);
}
)
.ConfigureServices(AddTestLogging)
.Build();
}
}