Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit 524561f

Browse files
authored
Fix race condition on Unix pooling canceled Socket operations (#27866)
1 parent 59c2fab commit 524561f

File tree

2 files changed

+64
-24
lines changed

2 files changed

+64
-24
lines changed

src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,6 @@ public void SetWaiting()
208208
Volatile.Write(ref _state, (int)State.Waiting);
209209
}
210210

211-
public void DoCallback()
212-
{
213-
InvokeCallback();
214-
}
215-
216211
public bool TryCancel()
217212
{
218213
Trace("Enter");
@@ -267,8 +262,11 @@ public bool TryCancel()
267262
#if DEBUG
268263
Debug.Assert(Interlocked.CompareExchange(ref _callbackQueued, 1, 0) == 0, $"Unexpected _callbackQueued: {_callbackQueued}");
269264
#endif
270-
271-
ThreadPool.QueueUserWorkItem(o => ((AsyncOperation)o).InvokeCallback(), this);
265+
// We've marked the operation as canceled, and so should invoke the callback, but
266+
// we can't pool the object, as ProcessQueue may still have a reference to it, due to
267+
// using a pattern whereby it takes the lock to grab an item, but then releases the lock
268+
// to do further processing on the item that's still in the list.
269+
ThreadPool.QueueUserWorkItem(o => ((AsyncOperation)o).InvokeCallback(allowPooling: false), this);
272270
}
273271

274272
Trace("Exit");
@@ -289,7 +287,7 @@ public void DoAbort()
289287

290288
protected abstract bool DoTryComplete(SocketAsyncContext context);
291289

292-
protected abstract void InvokeCallback();
290+
public abstract void InvokeCallback(bool allowPooling);
293291

294292
[Conditional("SOCKETASYNCCONTEXT_TRACE")]
295293
public void Trace(string message, [CallerMemberName] string memberName = null)
@@ -332,7 +330,7 @@ public Action<int, byte[], int, SocketFlags, SocketError> Callback
332330
set => CallbackOrEvent = value;
333331
}
334332

335-
protected override void InvokeCallback() =>
333+
public override void InvokeCallback(bool allowPooling) =>
336334
((Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent)(BytesTransferred, SocketAddress, SocketAddressLen, SocketFlags.None, ErrorCode);
337335
}
338336

@@ -348,15 +346,18 @@ protected override bool DoTryComplete(SocketAsyncContext context)
348346
return SocketPal.TryCompleteSendTo(context._socket, Buffer.Span, null, ref bufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode);
349347
}
350348

351-
protected override void InvokeCallback()
349+
public override void InvokeCallback(bool allowPooling)
352350
{
353351
var cb = (Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent;
354352
int bt = BytesTransferred;
355353
byte[] sa = SocketAddress;
356354
int sal = SocketAddressLen;
357355
SocketError ec = ErrorCode;
358356

359-
AssociatedContext.ReturnOperation(this);
357+
if (allowPooling)
358+
{
359+
AssociatedContext.ReturnOperation(this);
360+
}
360361

361362
cb(bt, sa, sal, SocketFlags.None, ec);
362363
}
@@ -374,15 +375,18 @@ protected override bool DoTryComplete(SocketAsyncContext context)
374375
return SocketPal.TryCompleteSendTo(context._socket, default(ReadOnlySpan<byte>), Buffers, ref BufferIndex, ref Offset, ref Count, Flags, SocketAddress, SocketAddressLen, ref BytesTransferred, out ErrorCode);
375376
}
376377

377-
protected override void InvokeCallback()
378+
public override void InvokeCallback(bool allowPooling)
378379
{
379380
var cb = (Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent;
380381
int bt = BytesTransferred;
381382
byte[] sa = SocketAddress;
382383
int sal = SocketAddressLen;
383384
SocketError ec = ErrorCode;
384385

385-
AssociatedContext.ReturnOperation(this);
386+
if (allowPooling)
387+
{
388+
AssociatedContext.ReturnOperation(this);
389+
}
386390

387391
cb(bt, sa, sal, SocketFlags.None, ec);
388392
}
@@ -416,7 +420,7 @@ public Action<int, byte[], int, SocketFlags, SocketError> Callback
416420
set => CallbackOrEvent = value;
417421
}
418422

419-
protected override void InvokeCallback() =>
423+
public override void InvokeCallback(bool allowPooling) =>
420424
((Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent)(
421425
BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, ErrorCode);
422426
}
@@ -430,7 +434,7 @@ public BufferMemoryReceiveOperation(SocketAsyncContext context) : base(context)
430434
protected override bool DoTryComplete(SocketAsyncContext context) =>
431435
SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
432436

433-
protected override void InvokeCallback()
437+
public override void InvokeCallback(bool allowPooling)
434438
{
435439
var cb = (Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent;
436440
int bt = BytesTransferred;
@@ -439,7 +443,10 @@ protected override void InvokeCallback()
439443
SocketFlags rf = ReceivedFlags;
440444
SocketError ec = ErrorCode;
441445

442-
AssociatedContext.ReturnOperation(this);
446+
if (allowPooling)
447+
{
448+
AssociatedContext.ReturnOperation(this);
449+
}
443450

444451
cb(bt, sa, sal, rf, ec);
445452
}
@@ -454,7 +461,7 @@ public BufferListReceiveOperation(SocketAsyncContext context) : base(context) {
454461
protected override bool DoTryComplete(SocketAsyncContext context) =>
455462
SocketPal.TryCompleteReceiveFrom(context._socket, default(Span<byte>), Buffers, Flags, SocketAddress, ref SocketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
456463

457-
protected override void InvokeCallback()
464+
public override void InvokeCallback(bool allowPooling)
458465
{
459466
var cb = (Action<int, byte[], int, SocketFlags, SocketError>)CallbackOrEvent;
460467
int bt = BytesTransferred;
@@ -463,7 +470,10 @@ protected override void InvokeCallback()
463470
SocketFlags rf = ReceivedFlags;
464471
SocketError ec = ErrorCode;
465472

466-
AssociatedContext.ReturnOperation(this);
473+
if (allowPooling)
474+
{
475+
AssociatedContext.ReturnOperation(this);
476+
}
467477

468478
cb(bt, sa, sal, rf, ec);
469479
}
@@ -504,7 +514,7 @@ public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError> C
504514
protected override bool DoTryComplete(SocketAsyncContext context) =>
505515
SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode);
506516

507-
protected override void InvokeCallback() =>
517+
public override void InvokeCallback(bool allowPooling) =>
508518
((Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>)CallbackOrEvent)(
509519
BytesTransferred, SocketAddress, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode);
510520
}
@@ -530,15 +540,18 @@ protected override bool DoTryComplete(SocketAsyncContext context)
530540
return completed;
531541
}
532542

533-
protected override void InvokeCallback()
543+
public override void InvokeCallback(bool allowPooling)
534544
{
535545
var cb = (Action<IntPtr, byte[], int, SocketError>)CallbackOrEvent;
536546
IntPtr fd = AcceptedFileDescriptor;
537547
byte[] sa = SocketAddress;
538548
int sal = SocketAddressLen;
539549
SocketError ec = ErrorCode;
540550

541-
AssociatedContext.ReturnOperation(this);
551+
if (allowPooling)
552+
{
553+
AssociatedContext.ReturnOperation(this);
554+
}
542555

543556
cb(fd, sa, sal, ec);
544557
}
@@ -562,7 +575,7 @@ protected override bool DoTryComplete(SocketAsyncContext context)
562575
return result;
563576
}
564577

565-
protected override void InvokeCallback() =>
578+
public override void InvokeCallback(bool allowPooling) =>
566579
((Action<SocketError>)CallbackOrEvent)(ErrorCode);
567580
}
568581

@@ -582,7 +595,7 @@ public Action<long, SocketError> Callback
582595
set => CallbackOrEvent = value;
583596
}
584597

585-
protected override void InvokeCallback() =>
598+
public override void InvokeCallback(bool allowPooling) =>
586599
((Action<long, SocketError>)CallbackOrEvent)(BytesTransferred, ErrorCode);
587600

588601
protected override bool DoTryComplete(SocketAsyncContext context) =>
@@ -944,7 +957,10 @@ public void ProcessQueue(SocketAsyncContext context)
944957
ThreadPool.QueueUserWorkItem(s_processingCallback, context);
945958
}
946959

947-
op.DoCallback();
960+
// At this point, the operation has completed and it's no longer
961+
// in the queue / no one else has a reference to it. We can invoke
962+
// the callback and let it pool the object if appropriate.
963+
op.InvokeCallback(allowPooling: true);
948964
}
949965
else
950966
{

src/System.Net.Sockets/tests/FunctionalTests/Accept.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,30 @@ public async Task Accept_WithInUseTargetSocket_Fails()
276276
Assert.Throws<InvalidOperationException>(() => { AcceptAsync(listener, server); });
277277
}
278278
}
279+
280+
[Fact]
281+
public async Task AcceptAsync_MultipleAcceptsThenDispose_AcceptsThrowAfterDispose()
282+
{
283+
if (UsesSync)
284+
{
285+
return;
286+
}
287+
288+
for (int i = 0; i < 100; i++)
289+
{
290+
using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
291+
{
292+
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
293+
listener.Listen(2);
294+
295+
Task accept1 = AcceptAsync(listener);
296+
Task accept2 = AcceptAsync(listener);
297+
listener.Dispose();
298+
await Assert.ThrowsAnyAsync<Exception>(() => accept1);
299+
await Assert.ThrowsAnyAsync<Exception>(() => accept2);
300+
}
301+
}
302+
}
279303
}
280304

281305
public sealed class AcceptSync : Accept<SocketHelperArraySync> { }

0 commit comments

Comments
 (0)