Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,11 @@ private async Task CleanupConnectionAsync(List<Exception>? errors)
{
try
{
if (!childProcess.HasExited) childProcess.Kill();
if (!childProcess.HasExited)
{
childProcess.Kill(entireProcessTree: true);
Comment thread
stephentoub marked this conversation as resolved.
Comment thread
stephentoub marked this conversation as resolved.
await childProcess.WaitForExitAsync();
Comment thread
stephentoub marked this conversation as resolved.
Comment thread
stephentoub marked this conversation as resolved.
}
childProcess.Dispose();
}
catch (Exception ex) { errors?.Add(ex); }
Expand Down Expand Up @@ -1090,7 +1094,7 @@ internal static async Task<T> InvokeRpcAsync<T>(JsonRpc rpc, string method, obje

if (!string.IsNullOrEmpty(stderrOutput))
{
throw new IOException($"CLI process exited unexpectedly.\nstderr: {stderrOutput}", ex);
throw new IOException(FormatCliExitedMessage("CLI process exited unexpectedly.", stderrOutput), ex);
}
throw new IOException($"Communication error with Copilot CLI: {ex.Message}", ex);
}
Expand All @@ -1100,6 +1104,24 @@ internal static async Task<T> InvokeRpcAsync<T>(JsonRpc rpc, string method, obje
}
}

private static string FormatCliExitedMessage(string message, string stderrOutput)
{
return string.IsNullOrEmpty(stderrOutput)
? message
: $"{message}\nstderr: {stderrOutput}";
}

private static IOException CreateCliExitedException(string message, StringBuilder stderrBuffer)
{
string stderrOutput;
lock (stderrBuffer)
{
stderrOutput = stderrBuffer.ToString().Trim();
}

return new IOException(FormatCliExitedMessage(message, stderrOutput));
}

private Task<Connection> EnsureConnectedAsync(CancellationToken cancellationToken)
{
if (_connectionTask is null && !_options.AutoStart)
Expand Down Expand Up @@ -1152,7 +1174,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
connection.Rpc, "connect", [new ConnectRequest { Token = _effectiveConnectionToken }], connection.StderrBuffer, cancellationToken);
serverVersion = (int)connectResponse.ProtocolVersion;
}
catch (RemoteRpcException ex) when (ex.ErrorCode == RemoteRpcException.MethodNotFoundErrorCode)
catch (IOException ex) when (ex.InnerException is RemoteRpcException remoteEx && IsUnsupportedConnectMethod(remoteEx))
{
// Legacy server without `connect`; fall back to `ping`. A token, if any,
// is silently dropped — the legacy server can't enforce one.
Expand Down Expand Up @@ -1180,6 +1202,12 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
_negotiatedProtocolVersion = serverVersion.Value;
}

private static bool IsUnsupportedConnectMethod(RemoteRpcException ex)
Comment thread
stephentoub marked this conversation as resolved.
{
return ex.ErrorCode == RemoteRpcException.MethodNotFoundErrorCode
|| string.Equals(ex.Message, "Unhandled method connect", StringComparison.Ordinal);
}

private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, string? connectionToken, ILogger logger, CancellationToken cancellationToken)
{
// Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback
Expand Down Expand Up @@ -1282,22 +1310,24 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio

// Capture stderr for error messages and forward to logger
var stderrBuffer = new StringBuilder();
_ = Task.Run(async () =>
var stderrReader = Task.Run(async () =>
{
while (cliProcess != null && !cliProcess.HasExited)
while (true)
{
var line = await cliProcess.StandardError.ReadLineAsync(cancellationToken);
if (line != null)
if (line is null)
{
lock (stderrBuffer)
{
stderrBuffer.AppendLine(line);
}
break;
}

if (logger.IsEnabled(LogLevel.Debug))
{
logger.LogDebug("[CLI] {Line}", line);
}
lock (stderrBuffer)
{
stderrBuffer.AppendLine(line);
}

if (logger.IsEnabled(LogLevel.Debug))
{
logger.LogDebug("[CLI] {Line}", line);
}
}
}, cancellationToken);
Expand All @@ -1311,7 +1341,13 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio

while (!cts.Token.IsCancellationRequested)
{
var line = await cliProcess.StandardOutput.ReadLineAsync(cts.Token) ?? throw new IOException("CLI process exited unexpectedly");
var line = await cliProcess.StandardOutput.ReadLineAsync(cts.Token);
if (line is null)
{
await stderrReader;
throw CreateCliExitedException("CLI process exited unexpectedly", stderrBuffer);
}

if (ListeningOnPortRegex().Match(line) is { Success: true } match)
{
detectedLocalhostTcpPort = int.Parse(match.Groups[1].Value, CultureInfo.InvariantCulture);
Expand Down
138 changes: 138 additions & 0 deletions dotnet/test/E2E/AbortE2ETests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

using System.ComponentModel;
using GitHub.Copilot.SDK.Test.Harness;
using Microsoft.Extensions.AI;
using Xunit;
using Xunit.Abstractions;

namespace GitHub.Copilot.SDK.Test.E2E;

/// <summary>
/// Verifies that <see cref="CopilotSession.AbortAsync"/> cleanly interrupts an active
/// turn — both during streaming and during tool execution — without leaving dangling
/// state or causing exceptions in the event delivery pipeline.
/// </summary>
public class AbortE2ETests(E2ETestFixture fixture, ITestOutputHelper output)
: E2ETestBase(fixture, "abort", output)
{
[Fact]
public async Task Should_Abort_During_Active_Streaming()
{
var session = await CreateSessionAsync(new SessionConfig { Streaming = true });

var firstDeltaReceived = new TaskCompletionSource<AssistantMessageDeltaEvent>(TaskCreationOptions.RunContinuationsAsynchronously);
var allEvents = new List<SessionEvent>();

session.On(evt =>
{
lock (allEvents) { allEvents.Add(evt); }
if (evt is AssistantMessageDeltaEvent delta)
{
firstDeltaReceived.TrySetResult(delta);
}
});

// Fire-and-forget — we'll abort before it finishes
_ = session.SendAsync(new MessageOptions
{
Prompt = "Write a very long essay about the history of computing, covering every decade from the 1940s to the 2020s in great detail.",
});

// Wait for at least one delta to arrive (proves streaming started)
var delta = await firstDeltaReceived.Task.WaitAsync(TimeSpan.FromSeconds(60));
Assert.False(string.IsNullOrEmpty(delta.Data.DeltaContent));

// Now abort mid-stream
await session.AbortAsync();

List<SessionEvent> snapshot;
lock (allEvents) { snapshot = [.. allEvents]; }

// No session.idle should have appeared (abort cancels the turn)
// OR if idle DID appear, it should be after the abort, which is fine
// The key contract: no exceptions were thrown, and the session is usable afterwards
var types = snapshot.Select(e => e.Type).ToList();
Assert.Contains("assistant.message_delta", types);

// Session should be usable after abort — verify by listening for the
// recovery message rather than racing against a late idle from the
// aborted streaming turn.
var recoveryReceived = new TaskCompletionSource<AssistantMessageEvent>(TaskCreationOptions.RunContinuationsAsynchronously);
session.On(evt =>
{
if (evt is AssistantMessageEvent msg && (msg.Data.Content?.Contains("abort_recovery_ok") == true))
{
recoveryReceived.TrySetResult(msg);
}
});

await session.SendAsync(new MessageOptions
{
Prompt = "Say 'abort_recovery_ok'.",
});

var recoveryMessage = await recoveryReceived.Task.WaitAsync(TimeSpan.FromSeconds(60));
Assert.Contains("abort_recovery_ok", recoveryMessage.Data.Content?.ToLowerInvariant() ?? string.Empty);

await session.DisposeAsync();
}

[Fact]
public async Task Should_Abort_During_Active_Tool_Execution()
{
var toolStarted = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);
var releaseTool = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);

var session = await CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(SlowTool, "slow_analysis")],
OnPermissionRequest = PermissionHandler.ApproveAll,
});

// Fire-and-forget
_ = session.SendAsync(new MessageOptions
{
Prompt = "Use slow_analysis with value 'test_abort'. Wait for the result.",
});

// Wait for the tool to start executing
var toolValue = await toolStarted.Task.WaitAsync(TimeSpan.FromSeconds(60));
Assert.Equal("test_abort", toolValue);

// Abort while the tool is running
await session.AbortAsync();

// Release the tool so its task doesn't leak
releaseTool.TrySetResult("RELEASED_AFTER_ABORT");

// Session should be usable after abort — verify by listening for the right event
var recoveryReceived = new TaskCompletionSource<AssistantMessageEvent>(TaskCreationOptions.RunContinuationsAsynchronously);
session.On(evt =>
{
if (evt is AssistantMessageEvent msg && (msg.Data.Content?.Contains("tool_abort_recovery_ok") == true))
{
recoveryReceived.TrySetResult(msg);
}
});

await session.SendAsync(new MessageOptions
{
Prompt = "Say 'tool_abort_recovery_ok'.",
});

var recoveryMessage = await recoveryReceived.Task.WaitAsync(TimeSpan.FromSeconds(60));
Assert.Contains("tool_abort_recovery_ok", recoveryMessage.Data.Content?.ToLowerInvariant() ?? string.Empty);

await session.DisposeAsync();

[Description("A slow analysis tool that blocks until released")]
async Task<string> SlowTool([Description("Value to analyze")] string value)
{
toolStarted.TrySetResult(value);
return await releaseTool.Task;
}
}
}
Loading
Loading