Skip to content

Commit

Permalink
Correct handling of thread switching with async
Browse files Browse the repository at this point in the history
  • Loading branch information
badrishc committed Aug 9, 2019
1 parent a486834 commit d72c92a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 48 deletions.
8 changes: 6 additions & 2 deletions cs/src/core/ClientSession/ClientSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ public bool CompletePending(bool spinWait = false)
public async ValueTask CompletePendingAsync()
{
ResumeThread();
await fht.CompletePendingAsync();
await fht.CompletePendingAsync(this);
ResumeThread();

SuspendThread();
}

Expand All @@ -177,7 +179,9 @@ public bool CompleteCheckpoint(bool spinWait = false)
internal async ValueTask CompleteCheckpointAsync()
{
ResumeThread();
await fht.CompleteCheckpointAsync();
await fht.CompleteCheckpointAsync(this);
ResumeThread();

SuspendThread();
}
}
Expand Down
64 changes: 29 additions & 35 deletions cs/src/core/ClientSession/FASTERAsync.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public partial class FasterKV<Key, Value, Input, Output, Context, Functions> : F
/// Complete outstanding pending operations
/// </summary>
/// <returns></returns>
internal async ValueTask CompletePendingAsync()
internal async ValueTask CompletePendingAsync(ClientSession<Key, Value, Input, Output, Context, Functions> clientSession)
{
do
{
Expand All @@ -32,10 +32,14 @@ internal async ValueTask CompletePendingAsync()
||
threadCtx.Value.phase == Phase.WAIT_PENDING)
{
await CompleteIOPendingRequestsAsync(prevThreadCtx.Value);
await CompleteIOPendingRequestsAsync(prevThreadCtx.Value, clientSession);
clientSession.ResumeThread();

Debug.Assert(prevThreadCtx.Value.ioPendingRequests.Count == 0);

await InternalRefreshAsync(clientSession);
clientSession.ResumeThread();

await InternalRefreshAsync();
CompleteRetryRequests(prevThreadCtx.Value);

done &= (prevThreadCtx.Value.ioPendingRequests.Count == 0);
Expand All @@ -47,10 +51,14 @@ internal async ValueTask CompletePendingAsync()
||
threadCtx.Value.phase == Phase.WAIT_PENDING))
{
await CompleteIOPendingRequestsAsync(threadCtx.Value);
await CompleteIOPendingRequestsAsync(threadCtx.Value, clientSession);
clientSession.ResumeThread();

Debug.Assert(threadCtx.Value.ioPendingRequests.Count == 0);
}
await InternalRefreshAsync();
await InternalRefreshAsync(clientSession);
clientSession.ResumeThread();

CompleteRetryRequests(threadCtx.Value);

done &= (threadCtx.Value.ioPendingRequests.Count == 0);
Expand All @@ -67,25 +75,28 @@ internal async ValueTask CompletePendingAsync()
/// Complete the ongoing checkpoint (if any)
/// </summary>
/// <returns></returns>
internal async ValueTask CompleteCheckpointAsync()
internal async ValueTask CompleteCheckpointAsync(ClientSession<Key, Value, Input, Output, Context, Functions> clientSession)
{
// Thread has an active session.
// So we need to constantly complete pending
// and refresh (done inside CompletePending)
// for the checkpoint to be proceed
do
{
await CompletePendingAsync();
await CompletePendingAsync(clientSession);
clientSession.ResumeThread();

if (_systemState.phase == Phase.REST)
{
await CompletePendingAsync();
await CompletePendingAsync(clientSession);
clientSession.ResumeThread();
return;
}
} while (true);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal async ValueTask InternalRefreshAsync()
internal async ValueTask InternalRefreshAsync(ClientSession<Key, Value, Input, Output, Context, Functions> clientSession)
{
epoch.ProtectAndDrain();

Expand All @@ -103,7 +114,8 @@ internal async ValueTask InternalRefreshAsync()
return;
}

await HandleCheckpointingPhasesAsync();
await HandleCheckpointingPhasesAsync(clientSession);
clientSession.ResumeThread();
}


Expand All @@ -121,7 +133,7 @@ private bool AtomicSwitch(FasterExecutionContext fromCtx, FasterExecutionContext
return false;
}

private async ValueTask HandleCheckpointingPhasesAsync()
private async ValueTask HandleCheckpointingPhasesAsync(ClientSession<Key, Value, Input, Output, Context, Functions> clientSession)
{
var previousState = SystemState.Make(threadCtx.Value.phase, threadCtx.Value.version);
var finalState = SystemState.Copy(ref _systemState);
Expand Down Expand Up @@ -190,15 +202,9 @@ private async ValueTask HandleCheckpointingPhasesAsync()

if (!IsIndexFuzzyCheckpointCompleted())
{
// Suspend
var prevThreadCtxCopy = prevThreadCtx.Value;
var threadCtxCopy = threadCtx.Value;
SuspendSession();

clientSession.SuspendThread();
await IsIndexFuzzyCheckpointCompletedAsync();

// Resume session
ResumeSession(prevThreadCtxCopy, threadCtxCopy);
clientSession.ResumeThread();
}
GlobalMoveToNextCheckpointState(currentState);

Expand Down Expand Up @@ -289,15 +295,9 @@ private async ValueTask HandleCheckpointingPhasesAsync()
{
Debug.Assert(_hybridLogCheckpoint.flushedSemaphore != null);

// Suspend
var prevThreadCtxCopy = prevThreadCtx.Value;
var threadCtxCopy = threadCtx.Value;
SuspendSession();

clientSession.SuspendThread();
await _hybridLogCheckpoint.flushedSemaphore.WaitAsync();

// Resume session
ResumeSession(prevThreadCtxCopy, threadCtxCopy);
clientSession.ResumeThread();

_hybridLogCheckpoint.flushedSemaphore.Release();

Expand All @@ -310,15 +310,9 @@ private async ValueTask HandleCheckpointingPhasesAsync()

if (!notify)
{
// Suspend
var prevThreadCtxCopy = prevThreadCtx.Value;
var threadCtxCopy = threadCtx.Value;
SuspendSession();

clientSession.SuspendThread();
await IsIndexFuzzyCheckpointCompletedAsync();

// Resume session
ResumeSession(prevThreadCtxCopy, threadCtxCopy);
clientSession.ResumeThread();

notify = true;
}
Expand Down
14 changes: 3 additions & 11 deletions cs/src/core/Index/FASTER/FASTERThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ internal void CompleteIOPendingRequests(FasterExecutionContext context)
}
}

internal async ValueTask CompleteIOPendingRequestsAsync(FasterExecutionContext context, CancellationToken token = default(CancellationToken))
internal async ValueTask CompleteIOPendingRequestsAsync(FasterExecutionContext context, ClientSession<Key, Value, Input, Output, Context, Functions> clientSession, CancellationToken token = default(CancellationToken))
{
while (context.ioPendingRequests.Count > 0)
{
Expand All @@ -233,17 +233,9 @@ internal void CompleteIOPendingRequests(FasterExecutionContext context)
}
else
{
// Save context on continuation stack (from thread local)
var prevThreadCtxCopy = prevThreadCtx.Value;
var threadCtxCopy = threadCtx.Value;

// Suspend epoch
SuspendSession();

clientSession.SuspendThread();
request = await context.readyResponses.DequeueAsync(token);

// Resume session
ResumeSession(prevThreadCtxCopy, threadCtxCopy);
clientSession.ResumeThread();
}

InternalContinuePendingRequestAndCallback(context, request);
Expand Down

0 comments on commit d72c92a

Please sign in to comment.