Skip to content

Commit

Permalink
Socket: delete unix local endpoint filename on Close (#52103)
Browse files Browse the repository at this point in the history
Fixes #45537
  • Loading branch information
tmds committed May 31, 2021
1 parent 93cf5df commit d3ed5a9
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public async Task NamedPipeClient_Connects_With_UnixDomainSocketEndPointServer()
}
}

Assert.True(File.Exists(pipeName));
try { File.Delete(pipeName); } catch { }
Assert.False(File.Exists(pipeName));
}
}
}
90 changes: 49 additions & 41 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ public partial class Socket : IDisposable

private SafeSocketHandle _handle;

// _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the
// correct type (IPEndPoint, etc).
// _rightEndPoint is null if the socket has not been bound. Otherwise, it is an EndPoint of the
// correct type (IPEndPoint, etc). The Bind operation sets _rightEndPoint. Other operations must only set
// it when the value is still null.
// This enables tracking the file created by UnixDomainSocketEndPoint when the Socket is bound,
// and to delete that file when the Socket gets disposed.
internal EndPoint? _rightEndPoint;
internal EndPoint? _remoteEndPoint;

Expand Down Expand Up @@ -284,7 +287,7 @@ public int Available
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint = _nonBlockingConnectRightEndPoint;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
Expand Down Expand Up @@ -331,7 +334,7 @@ public int Available
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint = _nonBlockingConnectRightEndPoint;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
Expand Down Expand Up @@ -438,7 +441,7 @@ public bool Connected
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
_rightEndPoint = _nonBlockingConnectRightEndPoint;
_rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
Expand Down Expand Up @@ -799,11 +802,11 @@ private void DoBind(EndPoint endPointSnapshot, Internals.SocketAddress socketAdd
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
}

if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
// Save a copy of the EndPoint so we can use it for Create().
// For UnixDomainSocketEndPoint, track the file to delete on Dispose.
_rightEndPoint = endPointSnapshot is UnixDomainSocketEndPoint unixEndPoint ?
unixEndPoint.CreateBoundEndPoint() :
endPointSnapshot;
}

// Establishes a connection to a remote system.
Expand Down Expand Up @@ -1357,11 +1360,8 @@ public int SendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags,
if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent();
}

if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = remoteEP;
}
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= remoteEP;

if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer, offset, size);
return bytesTransferred;
Expand Down Expand Up @@ -1639,11 +1639,8 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla
catch
{
}
if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
Expand Down Expand Up @@ -1733,11 +1730,8 @@ public int ReceiveMessageFrom(Span<byte> buffer, ref SocketFlags socketFlags, re
catch
{
}
if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
Expand Down Expand Up @@ -1796,11 +1790,8 @@ public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFl
catch
{
}
if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;
}

if (socketException != null)
Expand Down Expand Up @@ -3121,10 +3112,7 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
e.StartOperationCommon(this, SocketAsyncOperation.SendTo);

EndPoint? oldEndPoint = _rightEndPoint;
if (_rightEndPoint == null)
{
_rightEndPoint = endPointSnapshot;
}
_rightEndPoint ??= endPointSnapshot;

SocketError socketError;
try
Expand All @@ -3133,7 +3121,7 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
}
catch
{
_rightEndPoint = null;
_rightEndPoint = oldEndPoint;
_localEndPoint = null;
// Clear in-use flag on event args object.
e.Complete();
Expand Down Expand Up @@ -3229,11 +3217,8 @@ private void DoConnect(EndPoint endPointSnapshot, Internals.SocketAddress socket

if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);

if (_rightEndPoint == null)
{
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint = endPointSnapshot;
}
// Save a copy of the EndPoint so we can use it for Create().
_rightEndPoint ??= endPointSnapshot;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");

Expand Down Expand Up @@ -3361,6 +3346,18 @@ protected virtual void Dispose(bool disposing)
{
}
}

// Delete file of bound UnixDomainSocketEndPoint.
if (_rightEndPoint is UnixDomainSocketEndPoint unixEndPoint &&
unixEndPoint.BoundFileName is not null)
{
try
{
File.Delete(unixEndPoint.BoundFileName);
}
catch
{ }
}
}

// Clean up any cached data
Expand Down Expand Up @@ -3615,9 +3612,20 @@ internal Socket UpdateAcceptSocket(Socket socket, EndPoint remoteEP)
socket._addressFamily = _addressFamily;
socket._socketType = _socketType;
socket._protocolType = _protocolType;
socket._rightEndPoint = _rightEndPoint;
socket._remoteEndPoint = remoteEP;

// If the _rightEndpoint tracks a UnixDomainSocketEndPoint to delete
// then create a new EndPoint.
if (_rightEndPoint is UnixDomainSocketEndPoint unixEndPoint &&
unixEndPoint.BoundFileName is not null)
{
socket._rightEndPoint = unixEndPoint.CreateUnboundEndPoint();
}
else
{
socket._rightEndPoint = _rightEndPoint;
}

// If the listener socket was bound to a wildcard address, then the `accept` system call
// will assign a specific address to the accept socket's local endpoint instead of a
// wildcard address. In that case we should not copy listener's wildcard local endpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics;
using System.Text;
using System.IO;

namespace System.Net.Sockets
{
Expand All @@ -14,13 +15,22 @@ public sealed partial class UnixDomainSocketEndPoint : EndPoint
private readonly string _path;
private readonly byte[] _encodedPath;

// Tracks the file Socket should delete on Dispose.
internal string? BoundFileName { get; }

public UnixDomainSocketEndPoint(string path)
: this(path, null)
{ }

private UnixDomainSocketEndPoint(string path, string? boundFileName)
{
if (path == null)
{
throw new ArgumentNullException(nameof(path));
}

BoundFileName = boundFileName;

// Pathname socket addresses should be null-terminated.
// Linux abstract socket addresses start with a zero byte, they must not be null-terminated.
bool isAbstract = IsAbstract(path);
Expand Down Expand Up @@ -120,6 +130,24 @@ public override string ToString()
}
}

internal UnixDomainSocketEndPoint CreateBoundEndPoint()
{
if (IsAbstract(_path))
{
return this;
}
return new UnixDomainSocketEndPoint(_path, Path.GetFullPath(_path));
}

internal UnixDomainSocketEndPoint CreateUnboundEndPoint()
{
if (IsAbstract(_path) || BoundFileName is null)
{
return this;
}
return new UnixDomainSocketEndPoint(_path, null);
}

private static bool IsAbstract(string path) => path.Length > 0 && path[0] == '\0';

private static bool IsAbstract(byte[] encodedPath) => encodedPath.Length > 0 && encodedPath[0] == 0;
Expand Down
Loading

0 comments on commit d3ed5a9

Please sign in to comment.