diff --git a/go/client.go b/go/client.go index 9fa772129..2eee3894a 100644 --- a/go/client.go +++ b/go/client.go @@ -417,23 +417,15 @@ func (c *Client) Start(ctx context.Context) error { func (c *Client) Stop() error { var errs []error - // Disconnect all active sessions - c.sessionsMux.Lock() - sessions := make([]*Session, 0, len(c.sessions)) - for _, session := range c.sessions { - sessions = append(sessions, session) - } - c.sessionsMux.Unlock() - - for _, session := range sessions { + for _, session := range c.snapshotSessions() { if err := session.Disconnect(); err != nil { errs = append(errs, fmt.Errorf("failed to disconnect session %s: %w", session.SessionID, err)) } } - c.sessionsMux.Lock() - c.sessions = make(map[string]*Session) - c.sessionsMux.Unlock() + for _, session := range c.clearSessions() { + session.markDisconnected() + } c.startStopMux.Lock() defer c.startStopMux.Unlock() @@ -504,10 +496,9 @@ func (c *Client) ForceStop() { p.Kill() } - // Clear sessions immediately without trying to destroy them - c.sessionsMux.Lock() - c.sessions = make(map[string]*Session) - c.sessionsMux.Unlock() + for _, session := range c.clearSessions() { + session.markDisconnected() + } c.startStopMux.Lock() defer c.startStopMux.Unlock() @@ -556,6 +547,45 @@ func (c *Client) ensureConnected(ctx context.Context) error { return fmt.Errorf("client not connected. Call Start() first") } +func (c *Client) registerSession(session *Session) error { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + if existing := c.sessions[session.SessionID]; existing != nil && existing != session { + return fmt.Errorf("session %s is already active", session.SessionID) + } + c.sessions[session.SessionID] = session + return nil +} + +func (c *Client) unregisterSession(session *Session) { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + if c.sessions[session.SessionID] == session { + delete(c.sessions, session.SessionID) + } +} + +func (c *Client) snapshotSessions() []*Session { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + sessions := make([]*Session, 0, len(c.sessions)) + for _, session := range c.sessions { + sessions = append(sessions, session) + } + return sessions +} + +func (c *Client) clearSessions() []*Session { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + sessions := make([]*Session, 0, len(c.sessions)) + for _, session := range c.sessions { + sessions = append(sessions, session) + } + c.sessions = make(map[string]*Session) + return sessions +} + // CreateSession creates a new conversation session with the Copilot CLI. // // Sessions maintain conversation state, handle events, and manage tool execution. @@ -704,7 +734,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. - session := newSession(sessionID, c.client, "") + session := newSession(sessionID, c.client, "", c) session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) @@ -733,23 +763,22 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses session.registerAutoModeSwitchHandler(config.OnAutoModeSwitch) } - c.sessionsMux.Lock() - c.sessions[sessionID] = session - c.sessionsMux.Unlock() + if err := c.registerSession(session); err != nil { + session.markDisconnected() + return nil, err + } if c.options.SessionFs != nil { if config.CreateSessionFsHandler == nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") } provider := config.CreateSessionFsHandler(session) if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { if _, ok := provider.(SessionFsSqliteProvider); !ok { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") } } @@ -758,17 +787,15 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses result, err := c.client.Request("session.create", req) if err != nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("failed to create session: %w", err) } var response createSessionResponse if err := json.Unmarshal(result, &response); err != nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("failed to unmarshal response: %w", err) } @@ -889,7 +916,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. - session := newSession(sessionID, c.client, "") + session := newSession(sessionID, c.client, "", c) session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) @@ -918,23 +945,22 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, session.registerAutoModeSwitchHandler(config.OnAutoModeSwitch) } - c.sessionsMux.Lock() - c.sessions[sessionID] = session - c.sessionsMux.Unlock() + if err := c.registerSession(session); err != nil { + session.markDisconnected() + return nil, err + } if c.options.SessionFs != nil { if config.CreateSessionFsHandler == nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") } provider := config.CreateSessionFsHandler(session) if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { if _, ok := provider.(SessionFsSqliteProvider); !ok { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") } } @@ -943,17 +969,15 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, result, err := c.client.Request("session.resume", req) if err != nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("failed to resume session: %w", err) } var response resumeSessionResponse if err := json.Unmarshal(result, &response); err != nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() + c.unregisterSession(session) + session.markDisconnected() return nil, fmt.Errorf("failed to unmarshal response: %w", err) } @@ -1073,10 +1097,12 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { return fmt.Errorf("failed to delete session %s: %s", sessionID, errorMsg) } - // Remove from local sessions map if present c.sessionsMux.Lock() - delete(c.sessions, sessionID) + session := c.sessions[sessionID] c.sessionsMux.Unlock() + if session != nil { + session.markDisconnected() + } return nil } diff --git a/go/internal/e2e/commands_and_elicitation_e2e_test.go b/go/internal/e2e/commands_and_elicitation_e2e_test.go index 501e13813..719ebbd5c 100644 --- a/go/internal/e2e/commands_and_elicitation_e2e_test.go +++ b/go/internal/e2e/commands_and_elicitation_e2e_test.go @@ -136,8 +136,10 @@ func TestCommandsE2E(t *testing.T) { sessionID := session1.SessionID t.Cleanup(func() { _ = session1.Disconnect() }) - session2, err := client1.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client1) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, Commands: []copilot.CommandDefinition{ {Name: "deploy", Description: "Deploy", Handler: func(_ copilot.CommandContext) error { return nil }}, }, diff --git a/go/internal/e2e/mcp_and_agents_e2e_test.go b/go/internal/e2e/mcp_and_agents_e2e_test.go index 5f8c547fc..4c14ee37b 100644 --- a/go/internal/e2e/mcp_and_agents_e2e_test.go +++ b/go/internal/e2e/mcp_and_agents_e2e_test.go @@ -11,7 +11,10 @@ import ( func TestMCPServersE2E(t *testing.T) { ctx := testharness.NewTestContext(t) - client := ctx.NewClient() + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { client.ForceStop() }) t.Run("accept MCP server config on create", func(t *testing.T) { @@ -44,7 +47,6 @@ func TestMCPServersE2E(t *testing.T) { if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) if err != nil { t.Fatalf("Failed to get final message: %v", err) @@ -67,11 +69,6 @@ func TestMCPServersE2E(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) - if err != nil { - t.Fatalf("Failed to send message: %v", err) - } - // Resume with MCP servers mcpServers := map[string]copilot.MCPServerConfig{ "test-server": copilot.MCPStdioServerConfig{ @@ -81,8 +78,10 @@ func TestMCPServersE2E(t *testing.T) { }, } - session2, err := client.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, MCPServers: mcpServers, }) if err != nil { @@ -93,15 +92,6 @@ func TestMCPServersE2E(t *testing.T) { t.Errorf("Expected session ID %s, got %s", sessionID, session2.SessionID) } - message, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 3+3?"}) - if err != nil { - t.Fatalf("Failed to send message: %v", err) - } - - if md, ok := message.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(md.Content, "6") { - t.Errorf("Expected message to contain '6', got: %v", message.Data) - } - session2.Disconnect() }) @@ -184,7 +174,10 @@ func TestMCPServersE2E(t *testing.T) { func TestCustomAgentsE2E(t *testing.T) { ctx := testharness.NewTestContext(t) - client := ctx.NewClient() + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { client.ForceStop() }) t.Run("accept custom agent config on create", func(t *testing.T) { @@ -243,11 +236,6 @@ func TestCustomAgentsE2E(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) - if err != nil { - t.Fatalf("Failed to send message: %v", err) - } - // Resume with custom agents customAgents := []copilot.CustomAgentConfig{ { @@ -258,8 +246,10 @@ func TestCustomAgentsE2E(t *testing.T) { }, } - session2, err := client.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, CustomAgents: customAgents, }) if err != nil { @@ -270,15 +260,6 @@ func TestCustomAgentsE2E(t *testing.T) { t.Errorf("Expected session ID %s, got %s", sessionID, session2.SessionID) } - message, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 6+6?"}) - if err != nil { - t.Fatalf("Failed to send message: %v", err) - } - - if md, ok := message.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(md.Content, "12") { - t.Errorf("Expected message to contain '12', got: %v", message.Data) - } - session2.Disconnect() }) diff --git a/go/internal/e2e/permissions_e2e_test.go b/go/internal/e2e/permissions_e2e_test.go index bcc6fe278..c2644c807 100644 --- a/go/internal/e2e/permissions_e2e_test.go +++ b/go/internal/e2e/permissions_e2e_test.go @@ -16,7 +16,10 @@ import ( func TestPermissionsE2E(t *testing.T) { ctx := testharness.NewTestContext(t) - client := ctx.NewClient() + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { client.ForceStop() }) t.Run("permission handler for write operations", func(t *testing.T) { @@ -225,7 +228,7 @@ func TestPermissionsE2E(t *testing.T) { } }) - t.Run("should deny tool operations when handler explicitly denies after resume", func(t *testing.T) { + t.Run("should accept user-not-available handler when joining an active session", func(t *testing.T) { ctx.ConfigureForTest(t) session1, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ @@ -235,11 +238,10 @@ func TestPermissionsE2E(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } sessionID := session1.SessionID - if _, err = session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}); err != nil { - t.Fatalf("Failed to send message: %v", err) - } - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + DisableResume: true, OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindUserNotAvailable}, nil }, @@ -247,32 +249,10 @@ func TestPermissionsE2E(t *testing.T) { if err != nil { t.Fatalf("Failed to resume session: %v", err) } - - var mu sync.Mutex - permissionDenied := false - - session2.On(func(event copilot.SessionEvent) { - if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok && - !d.Success && - d.Error != nil && - strings.Contains(d.Error.Message, "Permission denied") { - mu.Lock() - permissionDenied = true - mu.Unlock() - } - }) - - if _, err = session2.SendAndWait(t.Context(), copilot.MessageOptions{ - Prompt: "Run 'node --version'", - }); err != nil { - t.Fatalf("Failed to send message: %v", err) - } - - mu.Lock() - defer mu.Unlock() - if !permissionDenied { - t.Error("Expected a tool.execution_complete event with Permission denied result") + if session2.SessionID != sessionID { + t.Errorf("Expected session ID %s, got %s", sessionID, session2.SessionID) } + session2.Disconnect() }) t.Run("should work with approve-all permission handler", func(t *testing.T) { diff --git a/go/internal/e2e/session_config_e2e_test.go b/go/internal/e2e/session_config_e2e_test.go index de9dad9e2..b46b19904 100644 --- a/go/internal/e2e/session_config_e2e_test.go +++ b/go/internal/e2e/session_config_e2e_test.go @@ -180,7 +180,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { const clientName = "go-public-surface-client" ctx := testharness.NewTestContext(t) - client := ctx.NewClient() + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { client.ForceStop() }) if err := client.Start(t.Context()); err != nil { @@ -290,8 +293,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { sessionID := session1.SessionID t.Cleanup(func() { _ = session1.Disconnect() }) - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, Model: "claude-sonnet-4.5", Provider: createProxyProvider(ctx, providerHeaderName, "resume-provider-header"), }) @@ -452,8 +457,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { sessionID := session1.SessionID t.Cleanup(func() { _ = session1.Disconnect() }) - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, WorkingDirectory: subDir, }) if err != nil { @@ -485,8 +492,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { t.Cleanup(func() { _ = session1.Disconnect() }) const resumeInstruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL." - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, SystemMessage: &copilot.SystemMessageConfig{ Mode: "append", Content: resumeInstruction, @@ -587,8 +596,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { } t.Cleanup(func() { _ = session1.Disconnect() }) - session2, err := client.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, WorkingDirectory: projectDir, InstructionDirectories: []string{instructionDir}, }) @@ -626,8 +637,10 @@ func TestSessionConfigExtrasE2E(t *testing.T) { sessionID := session1.SessionID t.Cleanup(func() { _ = session1.Disconnect() }) - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, AvailableTools: []string{"view"}, }) if err != nil { diff --git a/go/internal/e2e/session_e2e_test.go b/go/internal/e2e/session_e2e_test.go index f0d249422..5c7894ee5 100644 --- a/go/internal/e2e/session_e2e_test.go +++ b/go/internal/e2e/session_e2e_test.go @@ -17,7 +17,10 @@ import ( func TestSessionE2E(t *testing.T) { ctx := testharness.NewTestContext(t) - client := ctx.NewClient() + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { client.ForceStop() }) t.Run("should create and disconnect sessions", func(t *testing.T) { @@ -56,8 +59,8 @@ func TestSessionE2E(t *testing.T) { } _, err = session.GetMessages(t.Context()) - if err == nil || !strings.Contains(err.Error(), "not found") { - t.Errorf("Expected GetMessages to fail with 'not found' after disconnect, got %v", err) + if err == nil || !strings.Contains(err.Error(), "disconnected") { + t.Errorf("Expected GetMessages to fail with 'disconnected' after disconnect, got %v", err) } }) @@ -428,7 +431,7 @@ func TestSessionE2E(t *testing.T) { t.Skip("Known race condition - see TypeScript test") }) - t.Run("should resume a session using the same client", func(t *testing.T) { + t.Run("should reject resuming an active session using the same client", func(t *testing.T) { ctx.ConfigureForTest(t) // Create initial session @@ -438,50 +441,12 @@ func TestSessionE2E(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) - if err != nil { - t.Fatalf("Failed to send message: %v", err) - } - - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - - if ad, ok := answer.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(ad.Content, "2") { - t.Errorf("Expected answer to contain '2', got %v", answer.Data) - } - - // Resume using the same client - session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + // The same client already owns an active Session object for this ID. + _, err = client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, }) - if err != nil { - t.Fatalf("Failed to resume session: %v", err) - } - - if session2.SessionID != sessionID { - t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID) - } - - answer2, err := testharness.GetFinalAssistantMessage(t.Context(), session2, true) - if err != nil { - t.Fatalf("Failed to get assistant message from resumed session: %v", err) - } - - if ad, ok := answer2.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(ad.Content, "2") { - t.Errorf("Expected resumed session answer to contain '2', got %v", answer2.Data) - } - - // Can continue the conversation statefully - answer3, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Now if you double that, what do you get?"}) - if err != nil { - t.Fatalf("Failed to send follow-up message: %v", err) - } - if answer3 == nil { - t.Errorf("Expected follow-up answer to contain '4', got nil") - } else if ad, ok := answer3.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(ad.Content, "4") { - t.Errorf("Expected follow-up answer to contain '4', got %v", answer3) + if err == nil || !strings.Contains(err.Error(), "already active") { + t.Fatalf("Expected active duplicate resume to fail, got %v", err) } }) @@ -581,8 +546,10 @@ func TestSessionE2E(t *testing.T) { sessionID := session.SessionID // Resume the session with a provider - session2, err := client.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + resumeClient := newResumeClient(t, client) + session2, err := resumeClient.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, Provider: &copilot.ProviderConfig{ Type: "openai", BaseURL: "https://api.openai.com/v1", @@ -1066,6 +1033,16 @@ func getSystemMessage(exchange testharness.ParsedHttpExchange) string { return "" } +func newResumeClient(t *testing.T, server *copilot.Client) *copilot.Client { + t.Helper() + client := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: serverCliURL(t, server), + TCPConnectionToken: sharedTcpToken, + }) + t.Cleanup(func() { client.ForceStop() }) + return client +} + func TestSetModelWithReasoningEffortE2E(t *testing.T) { ctx := testharness.NewTestContext(t) client := ctx.NewClient() diff --git a/go/rpc/generated_rpc_api_shape_test.go b/go/rpc/generated_rpc_api_shape_test.go index bddbb263d..dbb86e07e 100644 --- a/go/rpc/generated_rpc_api_shape_test.go +++ b/go/rpc/generated_rpc_api_shape_test.go @@ -2,6 +2,8 @@ package rpc import ( "bytes" + "context" + "errors" "go/ast" "go/format" "go/parser" @@ -46,6 +48,22 @@ func TestGeneratedRPCAPIShape(t *testing.T) { assertStructFieldType(t, file, fileSet, "UIElicitationResponse", "Content", "map[string]UIElicitationFieldValue") } +func TestGeneratedSessionRPCRejectsMissingRequiredParams(t *testing.T) { + _, err := NewSessionRpc(nil, "session-1").Commands.Invoke(context.Background(), nil) + if err == nil || err.Error() != "params is required" { + t.Fatalf("expected params required error, got %v", err) + } +} + +func TestGeneratedSessionRPCRunsActiveCheckBeforeRequest(t *testing.T) { + inactive := errors.New("session inactive") + _, err := NewSessionRpc(nil, "session-1", func() error { return inactive }). + Commands.Invoke(context.Background(), &CommandsInvokeRequest{Name: "help"}) + if !errors.Is(err, inactive) { + t.Fatalf("expected active-check error, got %v", err) + } +} + func parseGeneratedRPC(t *testing.T) (*ast.File, *token.FileSet) { t.Helper() _, currentFile, _, ok := runtime.Caller(0) diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index 6e607f9ca..7d54e31f6 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -5990,6 +5990,9 @@ type ServerMcpConfigApi serverApi // // Parameters: MCP server name and configuration to add to user configuration. func (a *ServerMcpConfigApi) Add(ctx context.Context, params *McpConfigAddRequest) (*McpConfigAddResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("mcp.config.add", params) if err != nil { return nil, err @@ -6007,6 +6010,9 @@ func (a *ServerMcpConfigApi) Add(ctx context.Context, params *McpConfigAddReques // // Parameters: MCP server names to disable for new sessions. func (a *ServerMcpConfigApi) Disable(ctx context.Context, params *McpConfigDisableRequest) (*McpConfigDisableResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("mcp.config.disable", params) if err != nil { return nil, err @@ -6024,6 +6030,9 @@ func (a *ServerMcpConfigApi) Disable(ctx context.Context, params *McpConfigDisab // // Parameters: MCP server names to enable for new sessions. func (a *ServerMcpConfigApi) Enable(ctx context.Context, params *McpConfigEnableRequest) (*McpConfigEnableResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("mcp.config.enable", params) if err != nil { return nil, err @@ -6058,6 +6067,9 @@ func (a *ServerMcpConfigApi) List(ctx context.Context) (*McpConfigList, error) { // // Parameters: MCP server name to remove from user configuration. func (a *ServerMcpConfigApi) Remove(ctx context.Context, params *McpConfigRemoveRequest) (*McpConfigRemoveResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("mcp.config.remove", params) if err != nil { return nil, err @@ -6075,6 +6087,9 @@ func (a *ServerMcpConfigApi) Remove(ctx context.Context, params *McpConfigRemove // // Parameters: MCP server name and replacement configuration to write to user configuration. func (a *ServerMcpConfigApi) Update(ctx context.Context, params *McpConfigUpdateRequest) (*McpConfigUpdateResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("mcp.config.update", params) if err != nil { return nil, err @@ -6125,6 +6140,9 @@ type ServerSessionFsApi serverApi // Returns: Indicates whether the calling client was registered as the session filesystem // provider. func (a *ServerSessionFsApi) SetProvider(ctx context.Context, params *SessionFsSetProviderRequest) (*SessionFsSetProviderResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessionFs.setProvider", params) if err != nil { return nil, err @@ -6148,6 +6166,9 @@ type ServerSessionsApi serverApi // // Returns: Map of sessionId -> bytes freed by removing the session's workspace directory. func (a *ServerSessionsApi) BulkDelete(ctx context.Context, params *SessionsBulkDeleteRequest) (*SessionBulkDeleteResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.bulkDelete", params) if err != nil { return nil, err @@ -6168,6 +6189,9 @@ func (a *ServerSessionsApi) BulkDelete(ctx context.Context, params *SessionsBulk // // Returns: Session IDs from the input set that are currently in use by another process. func (a *ServerSessionsApi) CheckInUse(ctx context.Context, params *SessionsCheckInUseRequest) (*SessionsCheckInUseResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.checkInUse", params) if err != nil { return nil, err @@ -6230,6 +6254,9 @@ func (a *ServerSessionsApi) Connect(ctx context.Context, params *ConnectRemoteSe // Returns: The same metadata records, with summary and context fields backfilled where // available. func (a *ServerSessionsApi) EnrichMetadata(ctx context.Context, params *SessionsEnrichMetadataRequest) (*SessionEnrichMetadataResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.enrichMetadata", params) if err != nil { return nil, err @@ -6250,6 +6277,9 @@ func (a *ServerSessionsApi) EnrichMetadata(ctx context.Context, params *Sessions // // Returns: Session ID matching the prefix, omitted when no unique match exists. func (a *ServerSessionsApi) FindByPrefix(ctx context.Context, params *SessionsFindByPrefixRequest) (*SessionsFindByPrefixResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.findByPrefix", params) if err != nil { return nil, err @@ -6269,6 +6299,9 @@ func (a *ServerSessionsApi) FindByPrefix(ctx context.Context, params *SessionsFi // // Returns: ID of the local session bound to the given GitHub task, or omitted when none. func (a *ServerSessionsApi) FindByTaskId(ctx context.Context, params *SessionsFindByTaskIDRequest) (*SessionsFindByTaskIDResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.findByTaskId", params) if err != nil { return nil, err @@ -6429,6 +6462,9 @@ func (a *ServerSessionsApi) LoadDeferredRepoHooks(ctx context.Context, params *S // Returns: Outcome of the prune operation: deleted IDs, dry-run candidates, skipped IDs, // total bytes freed, and the dry-run flag. func (a *ServerSessionsApi) PruneOld(ctx context.Context, params *SessionsPruneOldRequest) (*SessionPruneResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.pruneOld", params) if err != nil { return nil, err @@ -6514,6 +6550,9 @@ func (a *ServerSessionsApi) Save(ctx context.Context, params *SessionsSaveReques // subsequent hook reloads see the new set; already-running sessions keep their existing // hook installation until the next reload. func (a *ServerSessionsApi) SetAdditionalPlugins(ctx context.Context, params *SessionsSetAdditionalPluginsRequest) (*SessionsSetAdditionalPluginsResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("sessions.setAdditionalPlugins", params) if err != nil { return nil, err @@ -6556,6 +6595,9 @@ type ServerSkillsConfigApi serverApi // Parameters: Skill names to mark as disabled in global configuration, replacing any // previous list. func (a *ServerSkillsConfigApi) SetDisabledSkills(ctx context.Context, params *SkillsConfigSetDisabledSkillsRequest) (*SkillsConfigSetDisabledSkillsResult, error) { + if params == nil { + return nil, errors.New("params is required") + } raw, err := a.client.Request("skills.config.setDisabledSkills", params) if err != nil { return nil, err @@ -6682,8 +6724,9 @@ func NewInternalServerRpc(client *jsonrpc2.Client) *InternalServerRpc { } type sessionApi struct { - client *jsonrpc2.Client - sessionID string + client *jsonrpc2.Client + sessionID string + assertActive func() error } // Experimental: AgentApi contains experimental APIs that may change or be removed. @@ -6693,6 +6736,11 @@ type AgentApi sessionApi // // RPC method: session.agent.deselect. func (a *AgentApi) Deselect(ctx context.Context) (*SessionAgentDeselectResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.deselect", req) if err != nil { @@ -6711,6 +6759,11 @@ func (a *AgentApi) Deselect(ctx context.Context) (*SessionAgentDeselectResult, e // // Returns: The currently selected custom agent, or null when using the default agent. func (a *AgentApi) GetCurrent(ctx context.Context) (*AgentGetCurrentResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.getCurrent", req) if err != nil { @@ -6729,6 +6782,11 @@ func (a *AgentApi) GetCurrent(ctx context.Context) (*AgentGetCurrentResult, erro // // Returns: Custom agents available to the session. func (a *AgentApi) List(ctx context.Context) (*AgentList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.list", req) if err != nil { @@ -6747,6 +6805,11 @@ func (a *AgentApi) List(ctx context.Context) (*AgentList, error) { // // Returns: Custom agents available to the session after reloading definitions from disk. func (a *AgentApi) Reload(ctx context.Context) (*AgentReloadResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.reload", req) if err != nil { @@ -6767,6 +6830,14 @@ func (a *AgentApi) Reload(ctx context.Context) (*AgentReloadResult, error) { // // Returns: The newly selected custom agent. func (a *AgentApi) Select(ctx context.Context, params *AgentSelectRequest) (*AgentSelectResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name @@ -6790,6 +6861,11 @@ type AuthApi sessionApi // // Returns: Authentication status and account metadata for the session. func (a *AuthApi) GetStatus(ctx context.Context) (*SessionAuthStatus, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.auth.getStatus", req) if err != nil { @@ -6812,6 +6888,11 @@ func (a *AuthApi) GetStatus(ctx context.Context) (*SessionAuthStatus, error) { // // Returns: Indicates whether the credential update succeeded. func (a *AuthApi) SetCredentials(ctx context.Context, params *SessionSetCredentialsParams) (*SessionSetCredentialsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Credentials != nil { @@ -6839,6 +6920,14 @@ type CommandsApi sessionApi // // Returns: Indicates whether the command was accepted into the local execution queue. func (a *CommandsApi) Enqueue(ctx context.Context, params *EnqueueCommandParams) (*EnqueueCommandResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["command"] = params.Command @@ -6862,6 +6951,14 @@ func (a *CommandsApi) Enqueue(ctx context.Context, params *EnqueueCommandParams) // // Returns: Error message produced while executing the command, if any. func (a *CommandsApi) Execute(ctx context.Context, params *ExecuteCommandParams) (*ExecuteCommandResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["args"] = params.Args @@ -6886,6 +6983,14 @@ func (a *CommandsApi) Execute(ctx context.Context, params *ExecuteCommandParams) // // Returns: Indicates whether the pending client-handled command was completed successfully. func (a *CommandsApi) HandlePendingCommand(ctx context.Context, params *CommandsHandlePendingCommandRequest) (*CommandsHandlePendingCommandResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Error != nil { @@ -6913,6 +7018,14 @@ func (a *CommandsApi) HandlePendingCommand(ctx context.Context, params *Commands // Returns: Result of invoking the slash command (text output, prompt to send to the agent, // or completion). func (a *CommandsApi) Invoke(ctx context.Context, params *CommandsInvokeRequest) (SlashCommandInvocationResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Input != nil { @@ -6944,6 +7057,11 @@ func (a *CommandsApi) List(ctx context.Context, params ...*CommandsListRequest) if len(params) > 0 { requestParams = params[0] } + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if requestParams != nil { if requestParams.IncludeBuiltins != nil { @@ -6977,6 +7095,14 @@ func (a *CommandsApi) List(ctx context.Context, params ...*CommandsListRequest) // // Returns: Indicates whether the queued-command response was matched to a pending request. func (a *CommandsApi) RespondToQueuedCommand(ctx context.Context, params *CommandsRespondToQueuedCommandRequest) (*CommandsRespondToQueuedCommandResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -7006,6 +7132,11 @@ type EventLogApi sessionApi // Returns: Batch of session events returned by a read, with cursor and continuation // metadata. func (a *EventLogApi) Read(ctx context.Context, params *EventLogReadRequest) (*EventsReadResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.AgentScope != nil { @@ -7043,6 +7174,14 @@ func (a *EventLogApi) Read(ctx context.Context, params *EventLogReadRequest) (*E // // Returns: Opaque handle representing an event-type interest registration. func (a *EventLogApi) RegisterInterest(ctx context.Context, params *RegisterEventInterestParams) (*RegisterEventInterestResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["eventType"] = params.EventType @@ -7066,6 +7205,14 @@ func (a *EventLogApi) RegisterInterest(ctx context.Context, params *RegisterEven // // Returns: Indicates whether the operation succeeded. func (a *EventLogApi) ReleaseInterest(ctx context.Context, params *ReleaseEventInterestParams) (*EventLogReleaseInterestResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["handle"] = params.Handle @@ -7090,6 +7237,11 @@ func (a *EventLogApi) ReleaseInterest(ctx context.Context, params *ReleaseEventI // through the entire persisted history (which would happen if `read` were called without a // cursor on a long-lived session). func (a *EventLogApi) Tail(ctx context.Context) (*EventLogTailResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.eventLog.tail", req) if err != nil { @@ -7111,6 +7263,14 @@ type ExtensionsApi sessionApi // // Parameters: Source-qualified extension identifier to disable for the session. func (a *ExtensionsApi) Disable(ctx context.Context, params *ExtensionsDisableRequest) (*SessionExtensionsDisableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -7132,6 +7292,14 @@ func (a *ExtensionsApi) Disable(ctx context.Context, params *ExtensionsDisableRe // // Parameters: Source-qualified extension identifier to enable for the session. func (a *ExtensionsApi) Enable(ctx context.Context, params *ExtensionsEnableRequest) (*SessionExtensionsEnableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -7153,6 +7321,11 @@ func (a *ExtensionsApi) Enable(ctx context.Context, params *ExtensionsEnableRequ // // Returns: Extensions discovered for the session, with their current status. func (a *ExtensionsApi) List(ctx context.Context) (*ExtensionList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.extensions.list", req) if err != nil { @@ -7169,6 +7342,11 @@ func (a *ExtensionsApi) List(ctx context.Context) (*ExtensionList, error) { // // RPC method: session.extensions.reload. func (a *ExtensionsApi) Reload(ctx context.Context) (*SessionExtensionsReloadResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.extensions.reload", req) if err != nil { @@ -7192,6 +7370,11 @@ type FleetApi sessionApi // // Returns: Indicates whether fleet mode was successfully activated. func (a *FleetApi) Start(ctx context.Context, params *FleetStartRequest) (*FleetStartResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Prompt != nil { @@ -7218,6 +7401,11 @@ type HistoryApi sessionApi // // Returns: Indicates whether an in-progress manual compaction was aborted. func (a *HistoryApi) AbortManualCompaction(ctx context.Context) (*HistoryAbortManualCompactionResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.history.abortManualCompaction", req) if err != nil { @@ -7237,6 +7425,11 @@ func (a *HistoryApi) AbortManualCompaction(ctx context.Context) (*HistoryAbortMa // // Returns: Indicates whether an in-progress background compaction was cancelled. func (a *HistoryApi) CancelBackgroundCompaction(ctx context.Context) (*HistoryCancelBackgroundCompactionResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.history.cancelBackgroundCompaction", req) if err != nil { @@ -7256,6 +7449,11 @@ func (a *HistoryApi) CancelBackgroundCompaction(ctx context.Context) (*HistoryCa // Returns: Compaction outcome with the number of tokens and messages removed, summary text, // and the resulting context window breakdown. func (a *HistoryApi) Compact(ctx context.Context) (*HistoryCompactResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.history.compact", req) if err != nil { @@ -7275,6 +7473,11 @@ func (a *HistoryApi) Compact(ctx context.Context) (*HistoryCompactResult, error) // // Returns: Markdown summary of the conversation context (empty when not available). func (a *HistoryApi) SummarizeForHandoff(ctx context.Context) (*HistorySummarizeForHandoffResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.history.summarizeForHandoff", req) if err != nil { @@ -7296,6 +7499,14 @@ func (a *HistoryApi) SummarizeForHandoff(ctx context.Context) (*HistorySummarize // // Returns: Number of events that were removed by the truncation. func (a *HistoryApi) Truncate(ctx context.Context, params *HistoryTruncateRequest) (*HistoryTruncateResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["eventId"] = params.EventID @@ -7319,6 +7530,11 @@ type InstructionsApi sessionApi // // Returns: Instruction sources loaded for the session, in merge order. func (a *InstructionsApi) GetSources(ctx context.Context) (*InstructionsGetSourcesResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.instructions.getSources", req) if err != nil { @@ -7340,6 +7556,11 @@ type LspApi sessionApi // // Parameters: Parameters for (re)loading the merged LSP configuration set. func (a *LspApi) Initialize(ctx context.Context, params *LspInitializeRequest) (*SessionLspInitializeResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Force != nil { @@ -7375,6 +7596,14 @@ type McpApi sessionApi // Returns: Indicates whether an in-flight sampling execution with the given requestId was // found and cancelled. func (a *McpApi) CancelSamplingExecution(ctx context.Context, params *McpCancelSamplingExecutionParams) (*McpCancelSamplingExecutionResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -7396,6 +7625,14 @@ func (a *McpApi) CancelSamplingExecution(ctx context.Context, params *McpCancelS // // Parameters: Name of the MCP server to disable for the session. func (a *McpApi) Disable(ctx context.Context, params *McpDisableRequest) (*SessionMcpDisableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["serverName"] = params.ServerName @@ -7417,6 +7654,14 @@ func (a *McpApi) Disable(ctx context.Context, params *McpDisableRequest) (*Sessi // // Parameters: Name of the MCP server to enable for the session. func (a *McpApi) Enable(ctx context.Context, params *McpEnableRequest) (*SessionMcpEnableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["serverName"] = params.ServerName @@ -7442,6 +7687,14 @@ func (a *McpApi) Enable(ctx context.Context, params *McpEnableRequest) (*Session // Returns: Outcome of an MCP sampling execution: success result, failure error, or // cancellation. func (a *McpApi) ExecuteSampling(ctx context.Context, params *McpExecuteSamplingParams) (*McpSamplingExecutionResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["mcpRequestId"] = params.McpRequestID @@ -7466,6 +7719,11 @@ func (a *McpApi) ExecuteSampling(ctx context.Context, params *McpExecuteSampling // // Returns: MCP servers configured for the session, with their connection status. func (a *McpApi) List(ctx context.Context) (*McpServerList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mcp.list", req) if err != nil { @@ -7482,6 +7740,11 @@ func (a *McpApi) List(ctx context.Context) (*McpServerList, error) { // // RPC method: session.mcp.reload. func (a *McpApi) Reload(ctx context.Context) (*SessionMcpReloadResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mcp.reload", req) if err != nil { @@ -7501,6 +7764,11 @@ func (a *McpApi) Reload(ctx context.Context) (*SessionMcpReloadResult, error) { // Returns: Indicates whether the auto-managed `github` MCP server was removed (false when // nothing to remove). func (a *McpApi) RemoveGitHub(ctx context.Context) (*McpRemoveGitHubResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mcp.removeGitHub", req) if err != nil { @@ -7523,6 +7791,14 @@ func (a *McpApi) RemoveGitHub(ctx context.Context) (*McpRemoveGitHubResult, erro // // Returns: Env-value mode recorded on the session after the update. func (a *McpApi) SetEnvValueMode(ctx context.Context, params *McpSetEnvValueModeParams) (*McpSetEnvValueModeResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["mode"] = params.Mode @@ -7551,6 +7827,14 @@ type McpOauthApi sessionApi // Returns: OAuth authorization URL the caller should open, or empty when cached tokens // already authenticated the server. func (a *McpOauthApi) Login(ctx context.Context, params *McpOauthLoginRequest) (*McpOauthLoginResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.CallbackSuccessMessage != nil { @@ -7593,6 +7877,14 @@ type MetadataApi sessionApi // Returns: Token breakdown for the session's current context window, or null if // uninitialized. func (a *MetadataApi) ContextInfo(ctx context.Context, params *MetadataContextInfoRequest) (*MetadataContextInfoResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["outputTokenLimit"] = params.OutputTokenLimit @@ -7620,6 +7912,11 @@ func (a *MetadataApi) ContextInfo(ctx context.Context, params *MetadataContextIn // Returns: Indicates whether the local session is currently processing a turn or background // continuation. func (a *MetadataApi) IsProcessing(ctx context.Context) (*MetadataIsProcessingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.metadata.isProcessing", req) if err != nil { @@ -7644,6 +7941,14 @@ func (a *MetadataApi) IsProcessing(ctx context.Context) (*MetadataIsProcessingRe // resume, before the next agent turn fires `session.context_info_changed` events. Returns // zeros for an empty session. func (a *MetadataApi) RecomputeContextTokens(ctx context.Context, params *MetadataRecomputeContextTokensRequest) (*MetadataRecomputeContextTokensResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["modelId"] = params.ModelID @@ -7671,6 +7976,14 @@ func (a *MetadataApi) RecomputeContextTokens(ctx context.Context, params *Metada // UI) can react. Use this when the host has detected a cwd/branch/repo change outside the // session's normal lifecycle (e.g., after a shell command in interactive mode). func (a *MetadataApi) RecordContextChange(ctx context.Context, params *MetadataRecordContextChangeRequest) (*MetadataRecordContextChangeResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["context"] = params.Context @@ -7697,6 +8010,14 @@ func (a *MetadataApi) RecordContextChange(ctx context.Context, params *MetadataR // `process.chdir` and any related side-effects (file index, etc.); this method only updates // the session's own recorded path. func (a *MetadataApi) SetWorkingDirectory(ctx context.Context, params *MetadataSetWorkingDirectoryRequest) (*MetadataSetWorkingDirectoryResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["workingDirectory"] = params.WorkingDirectory @@ -7719,6 +8040,11 @@ func (a *MetadataApi) SetWorkingDirectory(ctx context.Context, params *MetadataS // // Returns: Point-in-time snapshot of slow-changing session identifier and state fields func (a *MetadataApi) Snapshot(ctx context.Context) (*SessionMetadataSnapshot, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.metadata.snapshot", req) if err != nil { @@ -7739,6 +8065,11 @@ type ModeApi sessionApi // // Returns: The session mode the agent is operating in func (a *ModeApi) Get(ctx context.Context) (*SessionMode, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mode.get", req) if err != nil { @@ -7757,6 +8088,14 @@ func (a *ModeApi) Get(ctx context.Context) (*SessionMode, error) { // // Parameters: Agent interaction mode to apply to the session. func (a *ModeApi) Set(ctx context.Context, params *ModeSetRequest) (*SessionModeSetResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["mode"] = params.Mode @@ -7780,6 +8119,11 @@ type ModelApi sessionApi // // Returns: The currently selected model and reasoning effort for the session. func (a *ModelApi) GetCurrent(ctx context.Context) (*CurrentModel, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.model.getCurrent", req) if err != nil { @@ -7803,6 +8147,14 @@ func (a *ModelApi) GetCurrent(ctx context.Context) (*CurrentModel, error) { // `switchTo` instead when you also need to change the model. The runtime stores the effort // on the session and applies it to subsequent turns. func (a *ModelApi) SetReasoningEffort(ctx context.Context, params *ModelSetReasoningEffortRequest) (*ModelSetReasoningEffortResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["reasoningEffort"] = params.ReasoningEffort @@ -7827,6 +8179,14 @@ func (a *ModelApi) SetReasoningEffort(ctx context.Context, params *ModelSetReaso // // Returns: The model identifier active on the session after the switch. func (a *ModelApi) SwitchTo(ctx context.Context, params *ModelSwitchToRequest) (*ModelSwitchToResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.ModelCapabilities != nil { @@ -7859,6 +8219,11 @@ type NameApi sessionApi // // Returns: The session's friendly name, or null when not yet set. func (a *NameApi) Get(ctx context.Context) (*NameGetResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.name.get", req) if err != nil { @@ -7877,6 +8242,14 @@ func (a *NameApi) Get(ctx context.Context) (*NameGetResult, error) { // // Parameters: New friendly name to apply to the session. func (a *NameApi) Set(ctx context.Context, params *NameSetRequest) (*SessionNameSetResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name @@ -7902,6 +8275,14 @@ func (a *NameApi) Set(ctx context.Context, params *NameSetRequest) (*SessionName // // Returns: Indicates whether the auto-generated summary was applied as the session's name. func (a *NameApi) SetAuto(ctx context.Context, params *NameSetAutoRequest) (*NameSetAutoResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["summary"] = params.Summary @@ -7928,6 +8309,11 @@ type OptionsApi sessionApi // // Returns: Indicates whether the session options patch was applied successfully. func (a *OptionsApi) Update(ctx context.Context, params *SessionUpdateOptionsParams) (*SessionUpdateOptionsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.AdditionalContentExclusionPolicies != nil { @@ -8062,6 +8448,11 @@ type PermissionsApi sessionApi // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) Configure(ctx context.Context, params *PermissionsConfigureParams) (*PermissionsConfigureResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.AdditionalContentExclusionPolicies != nil { @@ -8104,6 +8495,14 @@ func (a *PermissionsApi) Configure(ctx context.Context, params *PermissionsConfi // Returns: Indicates whether the permission decision was applied; false when the request // was already resolved. func (a *PermissionsApi) HandlePendingPermissionRequest(ctx context.Context, params *PermissionDecisionRequest) (*PermissionRequestResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -8129,6 +8528,14 @@ func (a *PermissionsApi) HandlePendingPermissionRequest(ctx context.Context, par // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) ModifyRules(ctx context.Context, params *PermissionsModifyRulesParams) (*PermissionsModifyRulesResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Add != nil { @@ -8163,6 +8570,14 @@ func (a *PermissionsApi) ModifyRules(ctx context.Context, params *PermissionsMod // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) NotifyPromptShown(ctx context.Context, params *PermissionPromptShownNotification) (*PermissionsNotifyPromptShownResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["message"] = params.Message @@ -8185,6 +8600,11 @@ func (a *PermissionsApi) NotifyPromptShown(ctx context.Context, params *Permissi // // Returns: List of pending permission requests reconstructed from event history. func (a *PermissionsApi) PendingRequests(ctx context.Context) (*PendingPermissionRequestList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.permissions.pendingRequests", req) if err != nil { @@ -8203,6 +8623,11 @@ func (a *PermissionsApi) PendingRequests(ctx context.Context) (*PendingPermissio // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) ResetSessionApprovals(ctx context.Context) (*PermissionsResetSessionApprovalsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.permissions.resetSessionApprovals", req) if err != nil { @@ -8225,6 +8650,14 @@ func (a *PermissionsApi) ResetSessionApprovals(ctx context.Context) (*Permission // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) SetApproveAll(ctx context.Context, params *PermissionsSetApproveAllRequest) (*PermissionsSetApproveAllResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["enabled"] = params.Enabled @@ -8252,6 +8685,14 @@ func (a *PermissionsApi) SetApproveAll(ctx context.Context, params *PermissionsS // // Returns: Indicates whether the operation succeeded. func (a *PermissionsApi) SetRequired(ctx context.Context, params *PermissionsSetRequiredRequest) (*PermissionsSetRequiredResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["required"] = params.Required @@ -8277,6 +8718,14 @@ type PermissionsPathsApi sessionApi // // Returns: Indicates whether the operation succeeded. func (a *PermissionsPathsApi) Add(ctx context.Context, params *PermissionPathsAddParams) (*PermissionsPathsAddResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path @@ -8301,6 +8750,14 @@ func (a *PermissionsPathsApi) Add(ctx context.Context, params *PermissionPathsAd // // Returns: Indicates whether the supplied path is within the session's allowed directories. func (a *PermissionsPathsApi) IsPathWithinAllowedDirectories(ctx context.Context, params *PermissionPathsAllowedCheckParams) (*PermissionPathsAllowedCheckResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path @@ -8325,6 +8782,14 @@ func (a *PermissionsPathsApi) IsPathWithinAllowedDirectories(ctx context.Context // // Returns: Indicates whether the supplied path is within the session's workspace directory. func (a *PermissionsPathsApi) IsPathWithinWorkspace(ctx context.Context, params *PermissionPathsWorkspaceCheckParams) (*PermissionPathsWorkspaceCheckResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path @@ -8346,6 +8811,11 @@ func (a *PermissionsPathsApi) IsPathWithinWorkspace(ctx context.Context, params // // Returns: Snapshot of the session's allow-listed directories and primary working directory. func (a *PermissionsPathsApi) List(ctx context.Context) (*PermissionPathsList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.permissions.paths.list", req) if err != nil { @@ -8367,6 +8837,14 @@ func (a *PermissionsPathsApi) List(ctx context.Context) (*PermissionPathsList, e // // Returns: Indicates whether the operation succeeded. func (a *PermissionsPathsApi) UpdatePrimary(ctx context.Context, params *PermissionPathsUpdatePrimaryParams) (*PermissionsPathsUpdatePrimaryResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path @@ -8397,6 +8875,14 @@ type PermissionsUrlsApi sessionApi // // Returns: Indicates whether the operation succeeded. func (a *PermissionsUrlsApi) SetUnrestrictedMode(ctx context.Context, params *PermissionUrlsSetUnrestrictedModeParams) (*PermissionsUrlsSetUnrestrictedModeResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["enabled"] = params.Enabled @@ -8422,6 +8908,11 @@ type PlanApi sessionApi // // RPC method: session.plan.delete. func (a *PlanApi) Delete(ctx context.Context) (*SessionPlanDeleteResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plan.delete", req) if err != nil { @@ -8440,6 +8931,11 @@ func (a *PlanApi) Delete(ctx context.Context) (*SessionPlanDeleteResult, error) // // Returns: Existence, contents, and resolved path of the session plan file. func (a *PlanApi) Read(ctx context.Context) (*PlanReadResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plan.read", req) if err != nil { @@ -8458,6 +8954,14 @@ func (a *PlanApi) Read(ctx context.Context) (*PlanReadResult, error) { // // Parameters: Replacement contents to write to the session plan file. func (a *PlanApi) Update(ctx context.Context, params *PlanUpdateRequest) (*SessionPlanUpdateResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["content"] = params.Content @@ -8482,6 +8986,11 @@ type PluginsApi sessionApi // // Returns: Plugins installed for the session, with their enabled state and version metadata. func (a *PluginsApi) List(ctx context.Context) (*PluginList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plugins.list", req) if err != nil { @@ -8501,6 +9010,11 @@ type QueueApi sessionApi // // RPC method: session.queue.clear. func (a *QueueApi) Clear(ctx context.Context) (*SessionQueueClearResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.queue.clear", req) if err != nil { @@ -8520,6 +9034,11 @@ func (a *QueueApi) Clear(ctx context.Context) (*SessionQueueClearResult, error) // // Returns: Snapshot of the session's pending queued items and immediate-steering messages. func (a *QueueApi) PendingItems(ctx context.Context) (*QueuePendingItemsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.queue.pendingItems", req) if err != nil { @@ -8538,6 +9057,11 @@ func (a *QueueApi) PendingItems(ctx context.Context) (*QueuePendingItemsResult, // // Returns: Indicates whether a user-facing pending item was removed. func (a *QueueApi) RemoveMostRecent(ctx context.Context) (*QueueRemoveMostRecentResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.queue.removeMostRecent", req) if err != nil { @@ -8557,6 +9081,11 @@ type RemoteApi sessionApi // // RPC method: session.remote.disable. func (a *RemoteApi) Disable(ctx context.Context) (*SessionRemoteDisableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.remote.disable", req) if err != nil { @@ -8579,6 +9108,11 @@ func (a *RemoteApi) Disable(ctx context.Context) (*SessionRemoteDisableResult, e // Returns: GitHub URL for the session and a flag indicating whether remote steering is // enabled. func (a *RemoteApi) Enable(ctx context.Context, params *RemoteEnableRequest) (*RemoteEnableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Mode != nil { @@ -8608,6 +9142,14 @@ func (a *RemoteApi) Enable(ctx context.Context, params *RemoteEnableRequest) (*R // Used by the host (CLI / SDK consumer) when it has just finished enabling or disabling // steering on a remote exporter that the runtime does not directly own. func (a *RemoteApi) NotifySteerableChanged(ctx context.Context, params *RemoteNotifySteerableChangedRequest) (*RemoteNotifySteerableChangedResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["remoteSteerable"] = params.RemoteSteerable @@ -8632,6 +9174,11 @@ type ScheduleApi sessionApi // // Returns: Snapshot of the currently active recurring prompts for this session. func (a *ScheduleApi) List(ctx context.Context) (*ScheduleList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.schedule.list", req) if err != nil { @@ -8653,6 +9200,14 @@ func (a *ScheduleApi) List(ctx context.Context) (*ScheduleList, error) { // Returns: Remove a scheduled prompt by id. The result entry is omitted if the id was // unknown. func (a *ScheduleApi) Stop(ctx context.Context, params *ScheduleStopRequest) (*ScheduleStopResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -8680,6 +9235,14 @@ type ShellApi sessionApi // Returns: Identifier of the spawned process, used to correlate streamed output and exit // notifications. func (a *ShellApi) Exec(ctx context.Context, params *ShellExecRequest) (*ShellExecResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["command"] = params.Command @@ -8711,6 +9274,14 @@ func (a *ShellApi) Exec(ctx context.Context, params *ShellExecRequest) (*ShellEx // Returns: Indicates whether the signal was delivered; false if the process was unknown or // already exited. func (a *ShellApi) Kill(ctx context.Context, params *ShellKillRequest) (*ShellKillResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["processId"] = params.ProcessID @@ -8738,6 +9309,14 @@ type SkillsApi sessionApi // // Parameters: Name of the skill to disable for the session. func (a *SkillsApi) Disable(ctx context.Context, params *SkillsDisableRequest) (*SessionSkillsDisableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name @@ -8759,6 +9338,14 @@ func (a *SkillsApi) Disable(ctx context.Context, params *SkillsDisableRequest) ( // // Parameters: Name of the skill to enable for the session. func (a *SkillsApi) Enable(ctx context.Context, params *SkillsEnableRequest) (*SessionSkillsEnableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name @@ -8778,6 +9365,11 @@ func (a *SkillsApi) Enable(ctx context.Context, params *SkillsEnableRequest) (*S // // RPC method: session.skills.ensureLoaded. func (a *SkillsApi) EnsureLoaded(ctx context.Context) (*SessionSkillsEnsureLoadedResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.ensureLoaded", req) if err != nil { @@ -8797,6 +9389,11 @@ func (a *SkillsApi) EnsureLoaded(ctx context.Context) (*SessionSkillsEnsureLoade // Returns: Skills invoked during this session, ordered by invocation time (most recent // last). func (a *SkillsApi) GetInvoked(ctx context.Context) (*SkillsGetInvokedResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.getInvoked", req) if err != nil { @@ -8815,6 +9412,11 @@ func (a *SkillsApi) GetInvoked(ctx context.Context) (*SkillsGetInvokedResult, er // // Returns: Skills available to the session, with their enabled state. func (a *SkillsApi) List(ctx context.Context) (*SkillList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.list", req) if err != nil { @@ -8834,6 +9436,11 @@ func (a *SkillsApi) List(ctx context.Context) (*SkillList, error) { // Returns: Diagnostics from reloading skill definitions, with warnings and errors as // separate lists. func (a *SkillsApi) Reload(ctx context.Context) (*SkillsLoadDiagnostics, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.reload", req) if err != nil { @@ -8857,6 +9464,14 @@ type TasksApi sessionApi // // Returns: Indicates whether the background task was successfully cancelled. func (a *TasksApi) Cancel(ctx context.Context, params *TasksCancelRequest) (*TasksCancelResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -8879,6 +9494,11 @@ func (a *TasksApi) Cancel(ctx context.Context, params *TasksCancelRequest) (*Tas // // Returns: The first sync-waiting task that can currently be promoted to background mode. func (a *TasksApi) GetCurrentPromotable(ctx context.Context) (*TasksGetCurrentPromotableResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tasks.getCurrentPromotable", req) if err != nil { @@ -8899,6 +9519,14 @@ func (a *TasksApi) GetCurrentPromotable(ctx context.Context) (*TasksGetCurrentPr // // Returns: Progress information for the task, or null when no task with that ID is tracked. func (a *TasksApi) GetProgress(ctx context.Context, params *TasksGetProgressRequest) (*TasksGetProgressResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -8920,6 +9548,11 @@ func (a *TasksApi) GetProgress(ctx context.Context, params *TasksGetProgressRequ // // Returns: Background tasks currently tracked by the session. func (a *TasksApi) List(ctx context.Context) (*TaskList, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tasks.list", req) if err != nil { @@ -8940,6 +9573,11 @@ func (a *TasksApi) List(ctx context.Context) (*TaskList, error) { // Returns: The promoted task as it now exists in background mode, omitted if no promotable // task was waiting. func (a *TasksApi) PromoteCurrentToBackground(ctx context.Context) (*TasksPromoteCurrentToBackgroundResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tasks.promoteCurrentToBackground", req) if err != nil { @@ -8961,6 +9599,14 @@ func (a *TasksApi) PromoteCurrentToBackground(ctx context.Context) (*TasksPromot // // Returns: Indicates whether the task was successfully promoted to background mode. func (a *TasksApi) PromoteToBackground(ctx context.Context, params *TasksPromoteToBackgroundRequest) (*TasksPromoteToBackgroundResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -8983,6 +9629,11 @@ func (a *TasksApi) PromoteToBackground(ctx context.Context, params *TasksPromote // Returns: Refresh metadata for any detached background shells the runtime knows about. Use // after a long pause to pick up exit/output state for shells running outside the agent loop. func (a *TasksApi) Refresh(ctx context.Context) (*TasksRefreshResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tasks.refresh", req) if err != nil { @@ -9004,6 +9655,14 @@ func (a *TasksApi) Refresh(ctx context.Context) (*TasksRefreshResult, error) { // Returns: Indicates whether the task was removed. False when the task does not exist or is // still running/idle. func (a *TasksApi) Remove(ctx context.Context, params *TasksRemoveRequest) (*TasksRemoveResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID @@ -9029,6 +9688,14 @@ func (a *TasksApi) Remove(ctx context.Context, params *TasksRemoveRequest) (*Tas // Returns: Indicates whether the message was delivered, with an error message when delivery // failed. func (a *TasksApi) SendMessage(ctx context.Context, params *TasksSendMessageRequest) (*TasksSendMessageResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.FromAgentID != nil { @@ -9057,6 +9724,14 @@ func (a *TasksApi) SendMessage(ctx context.Context, params *TasksSendMessageRequ // // Returns: Identifier assigned to the newly started background agent task. func (a *TasksApi) StartAgent(ctx context.Context, params *TasksStartAgentRequest) (*TasksStartAgentResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["agentType"] = params.AgentType @@ -9089,6 +9764,11 @@ func (a *TasksApi) StartAgent(ctx context.Context, params *TasksStartAgentReques // drained or after an internal timeout (default 10 minutes; configurable via // COPILOT_TASK_WAIT_TIMEOUT_SECONDS). func (a *TasksApi) WaitForPending(ctx context.Context) (*TasksWaitForPendingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tasks.waitForPending", req) if err != nil { @@ -9112,6 +9792,14 @@ type TelemetryApi sessionApi // Parameters: Feature override key/value pairs to attach to subsequent telemetry events // from this session. func (a *TelemetryApi) SetFeatureOverrides(ctx context.Context, params *TelemetrySetFeatureOverridesRequest) (*SessionTelemetrySetFeatureOverridesResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["features"] = params.Features @@ -9138,6 +9826,14 @@ type ToolsApi sessionApi // // Returns: Indicates whether the external tool call result was handled successfully. func (a *ToolsApi) HandlePendingToolCall(ctx context.Context, params *HandlePendingToolCallRequest) (*HandlePendingToolCallResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Error != nil { @@ -9169,6 +9865,11 @@ func (a *ToolsApi) HandlePendingToolCall(ctx context.Context, params *HandlePend // Default base-class implementation is a no-op for sessions that don't support tool // validation. func (a *ToolsApi) InitializeAndValidate(ctx context.Context) (*ToolsInitializeAndValidateResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.tools.initializeAndValidate", req) if err != nil { @@ -9192,6 +9893,14 @@ type UIApi sessionApi // // Returns: The elicitation response (accept with form values, decline, or cancel) func (a *UIApi) Elicitation(ctx context.Context, params *UIElicitationRequest) (*UIElicitationResponse, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["message"] = params.Message @@ -9218,6 +9927,14 @@ func (a *UIApi) Elicitation(ctx context.Context, params *UIElicitationRequest) ( // // Returns: Indicates whether the pending UI request was resolved by this call. func (a *UIApi) HandlePendingAutoModeSwitch(ctx context.Context, params *UIHandlePendingAutoModeSwitchRequest) (*UIHandlePendingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -9244,6 +9961,14 @@ func (a *UIApi) HandlePendingAutoModeSwitch(ctx context.Context, params *UIHandl // Returns: Indicates whether the elicitation response was accepted; false if it was already // resolved by another client. func (a *UIApi) HandlePendingElicitation(ctx context.Context, params *UIHandlePendingElicitationRequest) (*UIElicitationResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -9270,6 +9995,14 @@ func (a *UIApi) HandlePendingElicitation(ctx context.Context, params *UIHandlePe // // Returns: Indicates whether the pending UI request was resolved by this call. func (a *UIApi) HandlePendingExitPlanMode(ctx context.Context, params *UIHandlePendingExitPlanModeRequest) (*UIHandlePendingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -9296,6 +10029,14 @@ func (a *UIApi) HandlePendingExitPlanMode(ctx context.Context, params *UIHandleP // // Returns: Indicates whether the pending UI request was resolved by this call. func (a *UIApi) HandlePendingSampling(ctx context.Context, params *UIHandlePendingSamplingRequest) (*UIHandlePendingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -9323,6 +10064,14 @@ func (a *UIApi) HandlePendingSampling(ctx context.Context, params *UIHandlePendi // // Returns: Indicates whether the pending UI request was resolved by this call. func (a *UIApi) HandlePendingUserInput(ctx context.Context, params *UIHandlePendingUserInputRequest) (*UIHandlePendingResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID @@ -9349,6 +10098,11 @@ func (a *UIApi) HandlePendingUserInput(ctx context.Context, params *UIHandlePend // this registration solely tells the server bridge to skip its own dispatch (so a remote // client doesn't race the in-process handler for the same requestId). func (a *UIApi) RegisterDirectAutoModeSwitchHandler(ctx context.Context) (*UIRegisterDirectAutoModeSwitchHandlerResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.ui.registerDirectAutoModeSwitchHandler", req) if err != nil { @@ -9372,6 +10126,14 @@ func (a *UIApi) RegisterDirectAutoModeSwitchHandler(ctx context.Context) (*UIReg // Returns: Indicates whether the handle was active and the registration count was // decremented. func (a *UIApi) UnregisterDirectAutoModeSwitchHandler(ctx context.Context, params *UIUnregisterDirectAutoModeSwitchHandlerRequest) (*UIUnregisterDirectAutoModeSwitchHandlerResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["handle"] = params.Handle @@ -9397,6 +10159,11 @@ type UsageApi sessionApi // Returns: Accumulated session usage metrics, including premium request cost, token counts, // model breakdown, and code-change totals. func (a *UsageApi) GetMetrics(ctx context.Context) (*UsageGetMetricsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.usage.getMetrics", req) if err != nil { @@ -9417,6 +10184,14 @@ type WorkspacesApi sessionApi // // Parameters: Relative path and UTF-8 content for the workspace file to create or overwrite. func (a *WorkspacesApi) CreateFile(ctx context.Context, params *WorkspacesCreateFileRequest) (*SessionWorkspacesCreateFileResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["content"] = params.Content @@ -9440,6 +10215,11 @@ func (a *WorkspacesApi) CreateFile(ctx context.Context, params *WorkspacesCreate // Returns: Current workspace metadata for the session, including its absolute filesystem // path when available. func (a *WorkspacesApi) GetWorkspace(ctx context.Context) (*WorkspacesGetWorkspaceResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.workspaces.getWorkspace", req) if err != nil { @@ -9459,6 +10239,11 @@ func (a *WorkspacesApi) GetWorkspace(ctx context.Context) (*WorkspacesGetWorkspa // Returns: Workspace checkpoints in chronological order; empty when the workspace is not // enabled. func (a *WorkspacesApi) ListCheckpoints(ctx context.Context) (*WorkspacesListCheckpointsResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.workspaces.listCheckpoints", req) if err != nil { @@ -9477,6 +10262,11 @@ func (a *WorkspacesApi) ListCheckpoints(ctx context.Context) (*WorkspacesListChe // // Returns: Relative paths of files stored in the session workspace files directory. func (a *WorkspacesApi) ListFiles(ctx context.Context) (*WorkspacesListFilesResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.workspaces.listFiles", req) if err != nil { @@ -9498,6 +10288,14 @@ func (a *WorkspacesApi) ListFiles(ctx context.Context) (*WorkspacesListFilesResu // Returns: Checkpoint content as a UTF-8 string, or null when the checkpoint or workspace // is missing. func (a *WorkspacesApi) ReadCheckpoint(ctx context.Context, params *WorkspacesReadCheckpointRequest) (*WorkspacesReadCheckpointResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["number"] = params.Number @@ -9521,6 +10319,14 @@ func (a *WorkspacesApi) ReadCheckpoint(ctx context.Context, params *WorkspacesRe // // Returns: Contents of the requested workspace file as a UTF-8 string. func (a *WorkspacesApi) ReadFile(ctx context.Context, params *WorkspacesReadFileRequest) (*WorkspacesReadFileResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path @@ -9544,6 +10350,14 @@ func (a *WorkspacesApi) ReadFile(ctx context.Context, params *WorkspacesReadFile // // Returns: Descriptor for the saved paste file, or null when the workspace is unavailable. func (a *WorkspacesApi) SaveLargePaste(ctx context.Context, params *WorkspacesSaveLargePasteRequest) (*WorkspacesSaveLargePasteResult, error) { + if a.assertActive != nil { + if err := a.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.sessionID} if params != nil { req["content"] = params.Content @@ -9603,6 +10417,11 @@ type SessionRpc struct { // // Returns: Result of aborting the current turn func (a *SessionRpc) Abort(ctx context.Context, params *AbortRequest) (*AbortResult, error) { + if a.common.assertActive != nil { + if err := a.common.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.common.sessionID} if params != nil { if params.Reason != nil { @@ -9629,6 +10448,14 @@ func (a *SessionRpc) Abort(ctx context.Context, params *AbortRequest) (*AbortRes // // Returns: Identifier of the session event that was emitted for the log message. func (a *SessionRpc) Log(ctx context.Context, params *LogRequest) (*LogResult, error) { + if a.common.assertActive != nil { + if err := a.common.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.common.sessionID} if params != nil { if params.Ephemeral != nil { @@ -9667,6 +10494,14 @@ func (a *SessionRpc) Log(ctx context.Context, params *LogRequest) (*LogResult, e // // Returns: Result of sending a user message func (a *SessionRpc) Send(ctx context.Context, params *SendRequest) (*SendResult, error) { + if a.common.assertActive != nil { + if err := a.common.assertActive(); err != nil { + return nil, err + } + } + if params == nil { + return nil, errors.New("params is required") + } req := map[string]any{"sessionId": a.common.sessionID} if params != nil { if params.AgentMode != nil { @@ -9726,6 +10561,11 @@ func (a *SessionRpc) Send(ctx context.Context, params *SendRequest) (*SendResult // // Parameters: Parameters for shutting down the session func (a *SessionRpc) Shutdown(ctx context.Context, params *ShutdownRequest) (*SessionShutdownResult, error) { + if a.common.assertActive != nil { + if err := a.common.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.common.sessionID} if params != nil { if params.Reason != nil { @@ -9750,6 +10590,11 @@ func (a *SessionRpc) Shutdown(ctx context.Context, params *ShutdownRequest) (*Se // // RPC method: session.suspend. func (a *SessionRpc) Suspend(ctx context.Context) (*SessionSuspendResult, error) { + if a.common.assertActive != nil { + if err := a.common.assertActive(); err != nil { + return nil, err + } + } req := map[string]any{"sessionId": a.common.sessionID} raw, err := a.common.client.Request("session.suspend", req) if err != nil { @@ -9762,9 +10607,13 @@ func (a *SessionRpc) Suspend(ctx context.Context) (*SessionSuspendResult, error) return &result, nil } -func NewSessionRpc(client *jsonrpc2.Client, sessionID string) *SessionRpc { +func NewSessionRpc(client *jsonrpc2.Client, sessionID string, assertActive ...func() error) *SessionRpc { r := &SessionRpc{} - r.common = sessionApi{client: client, sessionID: sessionID} + var assertActiveFn func() error + if len(assertActive) > 0 { + assertActiveFn = assertActive[0] + } + r.common = sessionApi{client: client, sessionID: sessionID, assertActive: assertActiveFn} r.Agent = (*AgentApi)(&r.common) r.Auth = (*AuthApi)(&r.common) r.Commands = (*CommandsApi)(&r.common) diff --git a/go/session.go b/go/session.go index bc7e2ede9..7f826018a 100644 --- a/go/session.go +++ b/go/session.go @@ -53,6 +53,7 @@ type Session struct { SessionID string workspacePath string client *jsonrpc2.Client + owner *Client clientSessionApis *rpc.ClientSessionApiHandlers handlers []sessionHandler nextHandlerID uint64 @@ -81,7 +82,12 @@ type Session struct { // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. eventCh chan SessionEvent - closeOnce sync.Once // guards eventCh close so Disconnect is safe to call more than once + eventMu sync.RWMutex // coordinates sends with closing eventCh + closeOnce sync.Once // guards eventCh close so Disconnect is safe to call more than once + stateMu sync.Mutex + closed bool + closing chan struct{} + closeErr error // RPC provides typed session-scoped RPC methods. RPC *rpc.SessionRpc @@ -95,22 +101,32 @@ func (s *Session) WorkspacePath() string { } // newSession creates a new session wrapper with the given session ID and client. -func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { +func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string, owner *Client) *Session { s := &Session{ SessionID: sessionID, workspacePath: workspacePath, client: client, + owner: owner, clientSessionApis: &rpc.ClientSessionApiHandlers{}, handlers: make([]sessionHandler, 0), toolHandlers: make(map[string]ToolHandler), commandHandlers: make(map[string]CommandHandler), eventCh: make(chan SessionEvent, 128), - RPC: rpc.NewSessionRpc(client, sessionID), } + s.RPC = rpc.NewSessionRpc(client, sessionID, s.assertActive) go s.processEvents() return s } +func (s *Session) assertActive() error { + s.stateMu.Lock() + defer s.stateMu.Unlock() + if s.closed { + return fmt.Errorf("session has been disconnected") + } + return nil +} + // Send sends a message to this session and waits for the response. // // The message is processed asynchronously. Subscribe to events via [Session.On] @@ -134,6 +150,9 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // log.Printf("Failed to send message: %v", err) // } func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { + if err := s.assertActive(); err != nil { + return "", err + } traceparent, tracestate := getTraceContext(ctx) req := sessionSendRequest{ SessionID: s.SessionID, @@ -187,6 +206,9 @@ func (s *Session) Send(ctx context.Context, options MessageOptions) (string, err // } // } func (s *Session) SendAndWait(ctx context.Context, options MessageOptions) (*SessionEvent, error) { + if err := s.assertActive(); err != nil { + return nil, err + } if _, ok := ctx.Deadline(); !ok { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, 60*time.Second) @@ -245,6 +267,7 @@ func (s *Session) SendAndWait(ctx context.Context, options MessageOptions) (*Ses // // The returned function can be called to unsubscribe the handler. It is safe // to call the unsubscribe function multiple times. +// Panics if the session has been disconnected. // // Example: // @@ -260,6 +283,10 @@ func (s *Session) SendAndWait(ctx context.Context, options MessageOptions) (*Ses // // Later, to stop receiving events: // unsubscribe() func (s *Session) On(handler SessionEventHandler) func() { + if err := s.assertActive(); err != nil { + panic(err) + } + s.handlerMutex.Lock() defer s.handlerMutex.Unlock() @@ -752,6 +779,9 @@ func (s *Session) UI() *SessionUI { // assertElicitation checks that the host supports elicitation and returns an error if not. func (s *Session) assertElicitation() error { + if err := s.assertActive(); err != nil { + return err + } caps := s.Capabilities() if caps.UI == nil || !caps.UI.Elicitation { return fmt.Errorf("elicitation is not supported by the host; check session.Capabilities().UI.Elicitation before calling UI methods") @@ -924,14 +954,17 @@ func fromRPCContent(value rpc.UIElicitationFieldValue) any { func (s *Session) dispatchEvent(event SessionEvent) { go s.handleBroadcastEvent(event) - // Send to the event channel in a closure with a recover guard. - // Disconnect closes eventCh, and in Go sending on a closed channel - // panics — there is no non-panicking send primitive. We only want - // to suppress that specific panic; other panics are not expected here. - func() { - defer func() { recover() }() - s.eventCh <- event - }() + s.eventMu.RLock() + defer s.eventMu.RUnlock() + + s.stateMu.Lock() + closed := s.closed + s.stateMu.Unlock() + if closed { + return + } + + s.eventCh <- event } // processEvents is the single consumer goroutine for the event channel. @@ -1168,6 +1201,9 @@ func rpcPermissionDecisionFromKind(kind rpc.PermissionDecisionKind) rpc.Permissi // } // } func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { + if err := s.assertActive(); err != nil { + return nil, err + } result, err := s.client.Request("session.getMessages", sessionGetMessagesRequest{SessionID: s.SessionID}) if err != nil { @@ -1204,14 +1240,52 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // log.Printf("Failed to disconnect session: %v", err) // } func (s *Session) Disconnect() error { + s.stateMu.Lock() + if s.closed { + s.stateMu.Unlock() + return nil + } + if s.closing != nil { + closing := s.closing + s.stateMu.Unlock() + <-closing + return s.closeErr + } + s.closing = make(chan struct{}) + closing := s.closing + s.stateMu.Unlock() + + err := s.disconnectCore() + + s.stateMu.Lock() + s.closeErr = err + close(closing) + s.stateMu.Unlock() + return err +} + +func (s *Session) disconnectCore() error { + defer s.markDisconnected() _, err := s.client.Request("session.destroy", sessionDestroyRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to disconnect session: %w", err) } + return nil +} + +func (s *Session) markDisconnected() { + s.stateMu.Lock() + if s.closed { + s.stateMu.Unlock() + return + } + s.closed = true + s.stateMu.Unlock() + s.eventMu.Lock() s.closeOnce.Do(func() { close(s.eventCh) }) + s.eventMu.Unlock() - // Clear handlers s.handlerMutex.Lock() s.handlers = nil s.handlerMutex.Unlock() @@ -1232,7 +1306,29 @@ func (s *Session) Disconnect() error { s.elicitationHandler = nil s.elicitationMu.Unlock() - return nil + s.userInputMux.Lock() + s.userInputHandler = nil + s.userInputMux.Unlock() + + s.exitPlanModeMu.Lock() + s.exitPlanModeHandler = nil + s.exitPlanModeMu.Unlock() + + s.autoModeSwitchMu.Lock() + s.autoModeSwitchHandler = nil + s.autoModeSwitchMu.Unlock() + + s.hooksMux.Lock() + s.hooks = nil + s.hooksMux.Unlock() + + s.transformMu.Lock() + s.transformCallbacks = nil + s.transformMu.Unlock() + + if s.owner != nil { + s.owner.unregisterSession(s) + } } // Deprecated: Use [Session.Disconnect] instead. Destroy will be removed in a future release. @@ -1265,6 +1361,9 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { + if err := s.assertActive(); err != nil { + return err + } _, err := s.client.Request("session.abort", sessionAbortRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to abort session: %w", err) @@ -1294,6 +1393,9 @@ type SetModelOptions struct { // log.Printf("Failed to set model: %v", err) // } func (s *Session) SetModel(ctx context.Context, model string, opts *SetModelOptions) error { + if err := s.assertActive(); err != nil { + return err + } params := &rpc.ModelSwitchToRequest{ModelID: model} if opts != nil { params.ReasoningEffort = opts.ReasoningEffort @@ -1334,6 +1436,9 @@ type LogOptions struct { // // Ephemeral message (not persisted) // session.Log(ctx, "Working...", &copilot.LogOptions{Ephemeral: copilot.Bool(true)}) func (s *Session) Log(ctx context.Context, message string, opts *LogOptions) error { + if err := s.assertActive(); err != nil { + return err + } params := &rpc.LogRequest{Message: message} if opts != nil { diff --git a/go/session_test.go b/go/session_test.go index 0b7de5ac9..27dc20d8d 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -46,6 +46,55 @@ func TestRPCPermissionDecisionFromKindPreservesUnknownKind(t *testing.T) { } } +func TestSessionLifecycleCleanup(t *testing.T) { + client := NewClient(nil) + session := newSession("session-1", nil, "", client) + + if err := client.registerSession(session); err != nil { + t.Fatalf("register session: %v", err) + } + + session.markDisconnected() + + if _, ok := client.sessions["session-1"]; ok { + t.Fatal("expected disconnected session to unregister from client") + } + if err := session.assertActive(); err == nil { + t.Fatal("expected disconnected session to reject further use") + } +} + +func TestSessionLifecycleCleanupDoesNotRemoveReplacement(t *testing.T) { + client := NewClient(nil) + stale := newSession("session-1", nil, "", client) + replacement := newSession("session-1", nil, "", client) + client.sessions["session-1"] = replacement + + stale.markDisconnected() + + if got := client.sessions["session-1"]; got != replacement { + t.Fatalf("expected replacement session to remain registered, got %#v", got) + } + + replacement.markDisconnected() +} + +func TestClientRegisterSessionRejectsDuplicateActiveSession(t *testing.T) { + client := NewClient(nil) + first := newSession("session-1", nil, "", client) + second := newSession("session-1", nil, "", client) + + if err := client.registerSession(first); err != nil { + t.Fatalf("register first session: %v", err) + } + if err := client.registerSession(second); err == nil || !strings.Contains(err.Error(), "already active") { + t.Fatalf("expected duplicate active session error, got %v", err) + } + + first.markDisconnected() + second.markDisconnected() +} + func TestSession_On(t *testing.T) { t.Run("multiple handlers all receive events", func(t *testing.T) { session, cleanup := newTestSession() @@ -231,6 +280,23 @@ func TestSession_On(t *testing.T) { t.Errorf("Expected 2 events dispatched, got %d", eventCount.Load()) } }) + + t.Run("panics after disconnect", func(t *testing.T) { + session := newSession("session-1", nil, "", nil) + session.markDisconnected() + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected On to panic after disconnect") + } + if !strings.Contains(fmt.Sprint(r), "session has been disconnected") { + t.Fatalf("expected disconnected panic, got %v", r) + } + }() + + session.On(func(SessionEvent) {}) + }) } func TestSession_CommandRouting(t *testing.T) { diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 6342b6667..af77b1dfb 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -450,6 +450,20 @@ export class CopilotClient { } } + private registerSession(session: CopilotSession): void { + const existing = this.sessions.get(session.sessionId); + if (existing && existing !== session) { + throw new Error(`Session ${session.sessionId} is already active.`); + } + this.sessions.set(session.sessionId, session); + } + + private unregisterSession(session: CopilotSession): void { + if (this.sessions.get(session.sessionId) === session) { + this.sessions.delete(session.sessionId); + } + } + private setupSessionFs( session: CopilotSession, config: { createSessionFsHandler?: (session: CopilotSession) => SessionFsProvider } @@ -551,37 +565,25 @@ export class CopilotClient { async stop(): Promise { const errors: Error[] = []; - // Disconnect all active sessions with retry logic - for (const session of this.sessions.values()) { + // Disconnect a stable snapshot so per-session cleanup can mutate the map safely. + for (const session of [...this.sessions.values()]) { const sessionId = session.sessionId; - let lastError: Error | null = null; - - // Try up to 3 times with exponential backoff - for (let attempt = 1; attempt <= 3; attempt++) { - try { - await session.disconnect(); - lastError = null; - break; // Success - } catch (error) { - lastError = error instanceof Error ? error : new Error(String(error)); - - if (attempt < 3) { - // Exponential backoff: 100ms, 200ms - const delay = 100 * Math.pow(2, attempt - 1); - await new Promise((resolve) => setTimeout(resolve, delay)); - } - } - } - - if (lastError) { + try { + await session.disconnect(); + } catch (error) { + const disconnectError = error instanceof Error ? error : new Error(String(error)); errors.push( new Error( - `Failed to disconnect session ${sessionId} after 3 attempts: ${lastError.message}` + `Failed to disconnect session ${sessionId}: ${disconnectError.message}` ) ); } } + const remainingSessions = [...this.sessions.values()]; this.sessions.clear(); + for (const session of remainingSessions) { + session._markDisconnected(); + } // Close connection if (this.connection) { @@ -596,6 +598,7 @@ export class CopilotClient { } this.connection = null; this._rpc = null; + this._internalRpc = null; } // Clear models cache @@ -668,6 +671,10 @@ export class CopilotClient { async forceStop(): Promise { this.forceStopping = true; + for (const session of this.sessions.values()) { + session._markDisconnected(); + } + // Clear sessions immediately without trying to destroy them this.sessions.clear(); @@ -680,6 +687,7 @@ export class CopilotClient { } this.connection = null; this._rpc = null; + this._internalRpc = null; } // Clear models cache @@ -761,7 +769,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + (session) => this.unregisterSession(session) ); session.registerTools(config.tools); session.registerCommands(config.commands); @@ -793,10 +802,10 @@ export class CopilotClient { if (config.onEvent) { session.on(config.onEvent); } - this.sessions.set(sessionId, session); - this.setupSessionFs(session, config); + this.registerSession(session); try { + this.setupSessionFs(session, config); const response = await this.connection!.sendRequest("session.create", { ...(await getTraceContext(this.onGetTraceContext)), model: config.model, @@ -853,7 +862,8 @@ export class CopilotClient { session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); } catch (e) { - this.sessions.delete(sessionId); + this.unregisterSession(session); + session._markDisconnected(); throw e; } @@ -899,7 +909,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + (session) => this.unregisterSession(session) ); session.registerTools(config.tools); session.registerCommands(config.commands); @@ -931,10 +942,10 @@ export class CopilotClient { if (config.onEvent) { session.on(config.onEvent); } - this.sessions.set(sessionId, session); - this.setupSessionFs(session, config); + this.registerSession(session); try { + this.setupSessionFs(session, config); const response = await this.connection!.sendRequest("session.resume", { ...(await getTraceContext(this.onGetTraceContext)), sessionId, @@ -993,7 +1004,8 @@ export class CopilotClient { session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); } catch (e) { - this.sessions.delete(sessionId); + this.unregisterSession(session); + session._markDisconnected(); throw e; } @@ -1245,7 +1257,12 @@ export class CopilotClient { } // Remove from local sessions map if present - this.sessions.delete(sessionId); + const session = this.sessions.get(sessionId); + if (session) { + session._markDisconnected(); + } else { + this.sessions.delete(sessionId); + } } /** diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index 20a4d6afe..29a1febce 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -7808,8 +7808,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Server liveness response, including the echoed message, current server timestamp, and protocol version. */ - ping: async (params: PingRequest): Promise => - connection.sendRequest("ping", params), + ping: async (params?: PingRequest): Promise => { + return connection.sendRequest("ping", (params ?? {})); + }, models: { /** * Lists Copilot models available to the authenticated user. @@ -7818,8 +7819,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns List of Copilot models available to the resolved user, including capabilities and billing metadata. */ - list: async (params: ModelsListRequest): Promise => - connection.sendRequest("models.list", params), + list: async (params?: ModelsListRequest): Promise => { + return connection.sendRequest("models.list", (params ?? {})); + }, }, tools: { /** @@ -7829,8 +7831,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Built-in tools available for the requested model, with their parameters and instructions. */ - list: async (params: ToolsListRequest): Promise => - connection.sendRequest("tools.list", params), + list: async (params?: ToolsListRequest): Promise => { + return connection.sendRequest("tools.list", (params ?? {})); + }, }, account: { /** @@ -7840,8 +7843,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Quota usage snapshots for the resolved user, keyed by quota type. */ - getQuota: async (params: AccountGetQuotaRequest): Promise => - connection.sendRequest("account.getQuota", params), + getQuota: async (params?: AccountGetQuotaRequest): Promise => { + return connection.sendRequest("account.getQuota", (params ?? {})); + }, }, mcp: { config: { @@ -7850,43 +7854,64 @@ export function createServerRpc(connection: MessageConnection) { * * @returns User-configured MCP servers, keyed by server name. */ - list: async (): Promise => - connection.sendRequest("mcp.config.list", {}), + list: async (): Promise => { + return connection.sendRequest("mcp.config.list", {}); + }, /** * Adds an MCP server to user configuration. * * @param params MCP server name and configuration to add to user configuration. */ - add: async (params: McpConfigAddRequest): Promise => - connection.sendRequest("mcp.config.add", params), + add: async (params: McpConfigAddRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("mcp.config.add", params); + }, /** * Updates an MCP server in user configuration. * * @param params MCP server name and replacement configuration to write to user configuration. */ - update: async (params: McpConfigUpdateRequest): Promise => - connection.sendRequest("mcp.config.update", params), + update: async (params: McpConfigUpdateRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("mcp.config.update", params); + }, /** * Removes an MCP server from user configuration. * * @param params MCP server name to remove from user configuration. */ - remove: async (params: McpConfigRemoveRequest): Promise => - connection.sendRequest("mcp.config.remove", params), + remove: async (params: McpConfigRemoveRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("mcp.config.remove", params); + }, /** * Enables MCP servers in user configuration for new sessions. * * @param params MCP server names to enable for new sessions. */ - enable: async (params: McpConfigEnableRequest): Promise => - connection.sendRequest("mcp.config.enable", params), + enable: async (params: McpConfigEnableRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("mcp.config.enable", params); + }, /** * Disables MCP servers in user configuration for new sessions. * * @param params MCP server names to disable for new sessions. */ - disable: async (params: McpConfigDisableRequest): Promise => - connection.sendRequest("mcp.config.disable", params), + disable: async (params: McpConfigDisableRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("mcp.config.disable", params); + }, }, /** * Discovers MCP servers from user, workspace, plugin, and builtin sources. @@ -7895,8 +7920,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns MCP servers discovered from user, workspace, plugin, and built-in sources. */ - discover: async (params: McpDiscoverRequest): Promise => - connection.sendRequest("mcp.discover", params), + discover: async (params?: McpDiscoverRequest): Promise => { + return connection.sendRequest("mcp.discover", (params ?? {})); + }, }, skills: { config: { @@ -7905,8 +7931,12 @@ export function createServerRpc(connection: MessageConnection) { * * @param params Skill names to mark as disabled in global configuration, replacing any previous list. */ - setDisabledSkills: async (params: SkillsConfigSetDisabledSkillsRequest): Promise => - connection.sendRequest("skills.config.setDisabledSkills", params), + setDisabledSkills: async (params: SkillsConfigSetDisabledSkillsRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("skills.config.setDisabledSkills", params); + }, }, /** * Discovers skills across global and project sources. @@ -7915,8 +7945,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Skills discovered across global and project sources. */ - discover: async (params: SkillsDiscoverRequest): Promise => - connection.sendRequest("skills.discover", params), + discover: async (params?: SkillsDiscoverRequest): Promise => { + return connection.sendRequest("skills.discover", (params ?? {})); + }, }, sessionFs: { /** @@ -7926,8 +7957,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Indicates whether the calling client was registered as the session filesystem provider. */ - setProvider: async (params: SessionFsSetProviderRequest): Promise => - connection.sendRequest("sessionFs.setProvider", params), + setProvider: async (params: SessionFsSetProviderRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessionFs.setProvider", params); + }, }, /** @experimental */ sessions: { @@ -7938,8 +7973,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Identifier and optional friendly name assigned to the newly forked session. */ - fork: async (params: SessionsForkRequest): Promise => - connection.sendRequest("sessions.fork", params), + fork: async (params?: SessionsForkRequest): Promise => { + return connection.sendRequest("sessions.fork", (params ?? {})); + }, /** * Connects to an existing remote session and exposes it as an SDK session. * @@ -7947,8 +7983,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Remote session connection result. */ - connect: async (params: ConnectRemoteSessionParams): Promise => - connection.sendRequest("sessions.connect", params), + connect: async (params?: ConnectRemoteSessionParams): Promise => { + return connection.sendRequest("sessions.connect", (params ?? {})); + }, /** * Lists persisted sessions, optionally filtered by working-directory context. * @@ -7956,8 +7993,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Persisted sessions matching the filter, ordered most-recently-modified first. */ - list: async (params: SessionsListRequest): Promise => - connection.sendRequest("sessions.list", params), + list: async (params?: SessionsListRequest): Promise => { + return connection.sendRequest("sessions.list", (params ?? {})); + }, /** * Finds the local session bound to a GitHub task ID, if any. * @@ -7965,8 +8003,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns ID of the local session bound to the given GitHub task, or omitted when none. */ - findByTaskId: async (params: SessionsFindByTaskIDRequest): Promise => - connection.sendRequest("sessions.findByTaskId", params), + findByTaskId: async (params: SessionsFindByTaskIDRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.findByTaskId", params); + }, /** * Resolves a UUID prefix to a unique session ID, if exactly one session matches. * @@ -7974,8 +8016,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Session ID matching the prefix, omitted when no unique match exists. */ - findByPrefix: async (params: SessionsFindByPrefixRequest): Promise => - connection.sendRequest("sessions.findByPrefix", params), + findByPrefix: async (params: SessionsFindByPrefixRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.findByPrefix", params); + }, /** * Returns the most-relevant prior session for a given working-directory context. * @@ -7983,8 +8029,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Most-relevant session ID for the supplied context, or omitted when no sessions exist. */ - getLastForContext: async (params: SessionsGetLastForContextRequest): Promise => - connection.sendRequest("sessions.getLastForContext", params), + getLastForContext: async (params?: SessionsGetLastForContextRequest): Promise => { + return connection.sendRequest("sessions.getLastForContext", (params ?? {})); + }, /** * Computes the absolute path to a session's persisted events.jsonl file. * @@ -7992,15 +8039,17 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Absolute path to the session's events.jsonl file on disk. */ - getEventFilePath: async (params: SessionsGetEventFilePathRequest): Promise => - connection.sendRequest("sessions.getEventFilePath", params), + getEventFilePath: async (params?: SessionsGetEventFilePathRequest): Promise => { + return connection.sendRequest("sessions.getEventFilePath", (params ?? {})); + }, /** * Returns the on-disk byte size of each session's workspace directory. * * @returns Map of sessionId -> on-disk size in bytes for each session's workspace directory. */ - getSizes: async (): Promise => - connection.sendRequest("sessions.getSizes", {}), + getSizes: async (): Promise => { + return connection.sendRequest("sessions.getSizes", {}); + }, /** * Returns the subset of the supplied session IDs that are currently held by another running process. * @@ -8008,8 +8057,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Session IDs from the input set that are currently in use by another process. */ - checkInUse: async (params: SessionsCheckInUseRequest): Promise => - connection.sendRequest("sessions.checkInUse", params), + checkInUse: async (params: SessionsCheckInUseRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.checkInUse", params); + }, /** * Returns a session's persisted remote-steerable flag, if any has been recorded. * @@ -8017,8 +8070,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns The session's persisted remote-steerable flag, or omitted when no value has been persisted. */ - getPersistedRemoteSteerable: async (params: SessionsGetPersistedRemoteSteerableRequest): Promise => - connection.sendRequest("sessions.getPersistedRemoteSteerable", params), + getPersistedRemoteSteerable: async (params?: SessionsGetPersistedRemoteSteerableRequest): Promise => { + return connection.sendRequest("sessions.getPersistedRemoteSteerable", (params ?? {})); + }, /** * Closes a session: emits shutdown, flushes pending events, releases the in-use lock, and disposes the active session. * @@ -8026,8 +8080,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Closes a session: emits shutdown, flushes pending events to disk, releases the in-use lock, disposes the active session. Idempotent: succeeds even if the session is not currently active. */ - close: async (params: SessionsCloseRequest): Promise => - connection.sendRequest("sessions.close", params), + close: async (params?: SessionsCloseRequest): Promise => { + return connection.sendRequest("sessions.close", (params ?? {})); + }, /** * Closes, deactivates, and deletes a set of sessions, returning the bytes freed per session. * @@ -8035,8 +8090,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Map of sessionId -> bytes freed by removing the session's workspace directory. */ - bulkDelete: async (params: SessionsBulkDeleteRequest): Promise => - connection.sendRequest("sessions.bulkDelete", params), + bulkDelete: async (params: SessionsBulkDeleteRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.bulkDelete", params); + }, /** * Deletes sessions older than the given threshold, with optional dry-run and exclusion list. * @@ -8044,8 +8103,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Outcome of the prune operation: deleted IDs, dry-run candidates, skipped IDs, total bytes freed, and the dry-run flag. */ - pruneOld: async (params: SessionsPruneOldRequest): Promise => - connection.sendRequest("sessions.pruneOld", params), + pruneOld: async (params: SessionsPruneOldRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.pruneOld", params); + }, /** * Flushes a session's pending events to disk. * @@ -8053,8 +8116,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Flush a session's pending events to disk. No-op when no writer exists for the session (e.g., already closed). */ - save: async (params: SessionsSaveRequest): Promise => - connection.sendRequest("sessions.save", params), + save: async (params?: SessionsSaveRequest): Promise => { + return connection.sendRequest("sessions.save", (params ?? {})); + }, /** * Releases the in-use lock held by this process for a session. * @@ -8062,8 +8126,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Release the in-use lock held by this process for the given session. No-op when this process does not currently hold a lock for the session. */ - releaseLock: async (params: SessionsReleaseLockRequest): Promise => - connection.sendRequest("sessions.releaseLock", params), + releaseLock: async (params?: SessionsReleaseLockRequest): Promise => { + return connection.sendRequest("sessions.releaseLock", (params ?? {})); + }, /** * Backfills missing summary and context fields on the supplied session metadata records. * @@ -8071,8 +8136,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns The same metadata records, with summary and context fields backfilled where available. */ - enrichMetadata: async (params: SessionsEnrichMetadataRequest): Promise => - connection.sendRequest("sessions.enrichMetadata", params), + enrichMetadata: async (params: SessionsEnrichMetadataRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.enrichMetadata", params); + }, /** * Reloads user, plugin, and (optionally) repo hooks on the active session. * @@ -8080,8 +8149,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Reload all hooks (user, plugin, optionally repo) and apply them to the active session. Call after installing or removing plugins so their hooks take effect immediately. No-op when no active session matches the given sessionId. */ - reloadPluginHooks: async (params: SessionsReloadPluginHooksRequest): Promise => - connection.sendRequest("sessions.reloadPluginHooks", params), + reloadPluginHooks: async (params?: SessionsReloadPluginHooksRequest): Promise => { + return connection.sendRequest("sessions.reloadPluginHooks", (params ?? {})); + }, /** * Loads previously-deferred repo-level hooks on the active session, returning queued startup prompts. * @@ -8089,8 +8159,9 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Queued repo-level startup prompts and the total hook command count after loading. */ - loadDeferredRepoHooks: async (params: SessionsLoadDeferredRepoHooksRequest): Promise => - connection.sendRequest("sessions.loadDeferredRepoHooks", params), + loadDeferredRepoHooks: async (params?: SessionsLoadDeferredRepoHooksRequest): Promise => { + return connection.sendRequest("sessions.loadDeferredRepoHooks", (params ?? {})); + }, /** * Replaces the manager-wide additional plugins registered with the session manager. * @@ -8098,8 +8169,12 @@ export function createServerRpc(connection: MessageConnection) { * * @returns Replace the manager-wide additional plugins. New session creations and subsequent hook reloads see the new set; already-running sessions keep their existing hook installation until the next reload. */ - setAdditionalPlugins: async (params: SessionsSetAdditionalPluginsRequest): Promise => - connection.sendRequest("sessions.setAdditionalPlugins", params), + setAdditionalPlugins: async (params: SessionsSetAdditionalPluginsRequest): Promise => { + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("sessions.setAdditionalPlugins", params); + }, }, }; } @@ -8118,19 +8193,22 @@ export function createInternalServerRpc(connection: MessageConnection) { * * @returns Handshake result reporting the server's protocol version and package version on success. */ - connect: async (params: ConnectRequest): Promise => - connection.sendRequest("connect", params), + connect: async (params?: ConnectRequest): Promise => { + return connection.sendRequest("connect", (params ?? {})); + }, }; } /** Create typed session-scoped RPC methods. */ -export function createSessionRpc(connection: MessageConnection, sessionId: string) { +export function createSessionRpc(connection: MessageConnection, sessionId: string, assertActive?: () => void) { return { /** * Suspends the session while preserving persisted state for later resume. */ - suspend: async (): Promise => - connection.sendRequest("session.suspend", { sessionId }), + suspend: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.suspend", { sessionId }); + }, /** * Sends a user message to the session and returns its message ID. * @@ -8138,8 +8216,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Result of sending a user message */ - send: async (params: SendRequest): Promise => - connection.sendRequest("session.send", { sessionId, ...params }), + send: async (params: SendRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.send", { sessionId, ...params }); + }, /** * Aborts the current agent turn. * @@ -8147,23 +8230,29 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Result of aborting the current turn */ - abort: async (params: AbortRequest): Promise => - connection.sendRequest("session.abort", { sessionId, ...params }), + abort: async (params?: AbortRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.abort", { sessionId, ...params }); + }, /** * Shuts down the session and persists its final state. Awaits any deferred sessionEnd hooks before resolving so user-supplied hook scripts complete before the runtime tears down. * * @param params Parameters for shutting down the session */ - shutdown: async (params: ShutdownRequest): Promise => - connection.sendRequest("session.shutdown", { sessionId, ...params }), + shutdown: async (params?: ShutdownRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.shutdown", { sessionId, ...params }); + }, auth: { /** * Gets authentication status and account metadata for the session. * * @returns Authentication status and account metadata for the session. */ - getStatus: async (): Promise => - connection.sendRequest("session.auth.getStatus", { sessionId }), + getStatus: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.auth.getStatus", { sessionId }); + }, /** * Updates the session's auth credentials used for outbound model and API requests. * @@ -8171,8 +8260,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the credential update succeeded. */ - setCredentials: async (params: SessionSetCredentialsParams): Promise => - connection.sendRequest("session.auth.setCredentials", { sessionId, ...params }), + setCredentials: async (params?: SessionSetCredentialsParams): Promise => { + assertActive?.(); + return connection.sendRequest("session.auth.setCredentials", { sessionId, ...params }); + }, }, model: { /** @@ -8180,8 +8271,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The currently selected model and reasoning effort for the session. */ - getCurrent: async (): Promise => - connection.sendRequest("session.model.getCurrent", { sessionId }), + getCurrent: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.model.getCurrent", { sessionId }); + }, /** * Switches the session to a model and optional reasoning configuration. * @@ -8189,8 +8282,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The model identifier active on the session after the switch. */ - switchTo: async (params: ModelSwitchToRequest): Promise => - connection.sendRequest("session.model.switchTo", { sessionId, ...params }), + switchTo: async (params: ModelSwitchToRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.model.switchTo", { sessionId, ...params }); + }, /** * Updates the session's reasoning effort without changing the selected model. * @@ -8198,8 +8296,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Update the session's reasoning effort without changing the selected model. Use `switchTo` instead when you also need to change the model. The runtime stores the effort on the session and applies it to subsequent turns. */ - setReasoningEffort: async (params: ModelSetReasoningEffortRequest): Promise => - connection.sendRequest("session.model.setReasoningEffort", { sessionId, ...params }), + setReasoningEffort: async (params: ModelSetReasoningEffortRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.model.setReasoningEffort", { sessionId, ...params }); + }, }, mode: { /** @@ -8207,15 +8310,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The session mode the agent is operating in */ - get: async (): Promise => - connection.sendRequest("session.mode.get", { sessionId }), + get: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.mode.get", { sessionId }); + }, /** * Sets the current agent interaction mode. * * @param params Agent interaction mode to apply to the session. */ - set: async (params: ModeSetRequest): Promise => - connection.sendRequest("session.mode.set", { sessionId, ...params }), + set: async (params: ModeSetRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mode.set", { sessionId, ...params }); + }, }, name: { /** @@ -8223,15 +8333,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The session's friendly name, or null when not yet set. */ - get: async (): Promise => - connection.sendRequest("session.name.get", { sessionId }), + get: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.name.get", { sessionId }); + }, /** * Sets the session's friendly name. * * @param params New friendly name to apply to the session. */ - set: async (params: NameSetRequest): Promise => - connection.sendRequest("session.name.set", { sessionId, ...params }), + set: async (params: NameSetRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.name.set", { sessionId, ...params }); + }, /** * Persists an auto-generated session summary as the session's name when no user-set name exists. * @@ -8239,8 +8356,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the auto-generated summary was applied as the session's name. */ - setAuto: async (params: NameSetAutoRequest): Promise => - connection.sendRequest("session.name.setAuto", { sessionId, ...params }), + setAuto: async (params: NameSetAutoRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.name.setAuto", { sessionId, ...params }); + }, }, plan: { /** @@ -8248,20 +8370,29 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Existence, contents, and resolved path of the session plan file. */ - read: async (): Promise => - connection.sendRequest("session.plan.read", { sessionId }), + read: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.plan.read", { sessionId }); + }, /** * Writes new content to the session plan file. * * @param params Replacement contents to write to the session plan file. */ - update: async (params: PlanUpdateRequest): Promise => - connection.sendRequest("session.plan.update", { sessionId, ...params }), + update: async (params: PlanUpdateRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.plan.update", { sessionId, ...params }); + }, /** * Deletes the session plan file from the workspace. */ - delete: async (): Promise => - connection.sendRequest("session.plan.delete", { sessionId }), + delete: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.plan.delete", { sessionId }); + }, }, workspaces: { /** @@ -8269,15 +8400,19 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Current workspace metadata for the session, including its absolute filesystem path when available. */ - getWorkspace: async (): Promise => - connection.sendRequest("session.workspaces.getWorkspace", { sessionId }), + getWorkspace: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.workspaces.getWorkspace", { sessionId }); + }, /** * Lists files stored in the session workspace files directory. * * @returns Relative paths of files stored in the session workspace files directory. */ - listFiles: async (): Promise => - connection.sendRequest("session.workspaces.listFiles", { sessionId }), + listFiles: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.workspaces.listFiles", { sessionId }); + }, /** * Reads a file from the session workspace files directory. * @@ -8285,22 +8420,34 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Contents of the requested workspace file as a UTF-8 string. */ - readFile: async (params: WorkspacesReadFileRequest): Promise => - connection.sendRequest("session.workspaces.readFile", { sessionId, ...params }), + readFile: async (params: WorkspacesReadFileRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.workspaces.readFile", { sessionId, ...params }); + }, /** * Creates or overwrites a file in the session workspace files directory. * * @param params Relative path and UTF-8 content for the workspace file to create or overwrite. */ - createFile: async (params: WorkspacesCreateFileRequest): Promise => - connection.sendRequest("session.workspaces.createFile", { sessionId, ...params }), + createFile: async (params: WorkspacesCreateFileRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.workspaces.createFile", { sessionId, ...params }); + }, /** * Lists workspace checkpoints in chronological order. * * @returns Workspace checkpoints in chronological order; empty when the workspace is not enabled. */ - listCheckpoints: async (): Promise => - connection.sendRequest("session.workspaces.listCheckpoints", { sessionId }), + listCheckpoints: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.workspaces.listCheckpoints", { sessionId }); + }, /** * Reads the content of a workspace checkpoint by number. * @@ -8308,8 +8455,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Checkpoint content as a UTF-8 string, or null when the checkpoint or workspace is missing. */ - readCheckpoint: async (params: WorkspacesReadCheckpointRequest): Promise => - connection.sendRequest("session.workspaces.readCheckpoint", { sessionId, ...params }), + readCheckpoint: async (params: WorkspacesReadCheckpointRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.workspaces.readCheckpoint", { sessionId, ...params }); + }, /** * Saves pasted content as a UTF-8 file in the session workspace. * @@ -8317,8 +8469,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Descriptor for the saved paste file, or null when the workspace is unavailable. */ - saveLargePaste: async (params: WorkspacesSaveLargePasteRequest): Promise => - connection.sendRequest("session.workspaces.saveLargePaste", { sessionId, ...params }), + saveLargePaste: async (params: WorkspacesSaveLargePasteRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.workspaces.saveLargePaste", { sessionId, ...params }); + }, }, instructions: { /** @@ -8326,8 +8483,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Instruction sources loaded for the session, in merge order. */ - getSources: async (): Promise => - connection.sendRequest("session.instructions.getSources", { sessionId }), + getSources: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.instructions.getSources", { sessionId }); + }, }, /** @experimental */ fleet: { @@ -8338,8 +8497,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether fleet mode was successfully activated. */ - start: async (params: FleetStartRequest): Promise => - connection.sendRequest("session.fleet.start", { sessionId, ...params }), + start: async (params?: FleetStartRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.fleet.start", { sessionId, ...params }); + }, }, /** @experimental */ agent: { @@ -8348,15 +8509,19 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Custom agents available to the session. */ - list: async (): Promise => - connection.sendRequest("session.agent.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.agent.list", { sessionId }); + }, /** * Gets the currently selected custom agent for the session. * * @returns The currently selected custom agent, or null when using the default agent. */ - getCurrent: async (): Promise => - connection.sendRequest("session.agent.getCurrent", { sessionId }), + getCurrent: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.agent.getCurrent", { sessionId }); + }, /** * Selects a custom agent for subsequent turns in the session. * @@ -8364,20 +8529,29 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The newly selected custom agent. */ - select: async (params: AgentSelectRequest): Promise => - connection.sendRequest("session.agent.select", { sessionId, ...params }), + select: async (params: AgentSelectRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.agent.select", { sessionId, ...params }); + }, /** * Clears the selected custom agent and returns the session to the default agent. */ - deselect: async (): Promise => - connection.sendRequest("session.agent.deselect", { sessionId }), + deselect: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.agent.deselect", { sessionId }); + }, /** * Reloads custom agent definitions and returns the refreshed list. * * @returns Custom agents available to the session after reloading definitions from disk. */ - reload: async (): Promise => - connection.sendRequest("session.agent.reload", { sessionId }), + reload: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.agent.reload", { sessionId }); + }, }, /** @experimental */ tasks: { @@ -8388,29 +8562,40 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Identifier assigned to the newly started background agent task. */ - startAgent: async (params: TasksStartAgentRequest): Promise => - connection.sendRequest("session.tasks.startAgent", { sessionId, ...params }), + startAgent: async (params: TasksStartAgentRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.startAgent", { sessionId, ...params }); + }, /** * Lists background tasks tracked by the session. * * @returns Background tasks currently tracked by the session. */ - list: async (): Promise => - connection.sendRequest("session.tasks.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tasks.list", { sessionId }); + }, /** * Refreshes metadata for any detached background shells the runtime knows about. * * @returns Refresh metadata for any detached background shells the runtime knows about. Use after a long pause to pick up exit/output state for shells running outside the agent loop. */ - refresh: async (): Promise => - connection.sendRequest("session.tasks.refresh", { sessionId }), + refresh: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tasks.refresh", { sessionId }); + }, /** * Waits for all in-flight background tasks and any follow-up turns to settle. * * @returns Wait until all in-flight background tasks (agents + shells) and any follow-up turns scheduled by their completions have settled. Returns when the runtime is fully drained or after an internal timeout (default 10 minutes; configurable via COPILOT_TASK_WAIT_TIMEOUT_SECONDS). */ - waitForPending: async (): Promise => - connection.sendRequest("session.tasks.waitForPending", { sessionId }), + waitForPending: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tasks.waitForPending", { sessionId }); + }, /** * Returns progress information for a background task by ID. * @@ -8418,15 +8603,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Progress information for the task, or null when no task with that ID is tracked. */ - getProgress: async (params: TasksGetProgressRequest): Promise => - connection.sendRequest("session.tasks.getProgress", { sessionId, ...params }), + getProgress: async (params: TasksGetProgressRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.getProgress", { sessionId, ...params }); + }, /** * Returns the first sync-waiting task that can currently be promoted to background mode. * * @returns The first sync-waiting task that can currently be promoted to background mode. */ - getCurrentPromotable: async (): Promise => - connection.sendRequest("session.tasks.getCurrentPromotable", { sessionId }), + getCurrentPromotable: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tasks.getCurrentPromotable", { sessionId }); + }, /** * Promotes an eligible synchronously-waited task so it continues running in the background. * @@ -8434,15 +8626,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the task was successfully promoted to background mode. */ - promoteToBackground: async (params: TasksPromoteToBackgroundRequest): Promise => - connection.sendRequest("session.tasks.promoteToBackground", { sessionId, ...params }), + promoteToBackground: async (params: TasksPromoteToBackgroundRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.promoteToBackground", { sessionId, ...params }); + }, /** * Atomically promotes the first promotable sync-waiting task to background mode and returns it. * * @returns The promoted task as it now exists in background mode, omitted if no promotable task was waiting. */ - promoteCurrentToBackground: async (): Promise => - connection.sendRequest("session.tasks.promoteCurrentToBackground", { sessionId }), + promoteCurrentToBackground: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tasks.promoteCurrentToBackground", { sessionId }); + }, /** * Cancels a background task. * @@ -8450,8 +8649,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the background task was successfully cancelled. */ - cancel: async (params: TasksCancelRequest): Promise => - connection.sendRequest("session.tasks.cancel", { sessionId, ...params }), + cancel: async (params: TasksCancelRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.cancel", { sessionId, ...params }); + }, /** * Removes a completed or cancelled background task from tracking. * @@ -8459,8 +8663,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the task was removed. False when the task does not exist or is still running/idle. */ - remove: async (params: TasksRemoveRequest): Promise => - connection.sendRequest("session.tasks.remove", { sessionId, ...params }), + remove: async (params: TasksRemoveRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.remove", { sessionId, ...params }); + }, /** * Sends a message to a background agent task. * @@ -8468,8 +8677,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the message was delivered, with an error message when delivery failed. */ - sendMessage: async (params: TasksSendMessageRequest): Promise => - connection.sendRequest("session.tasks.sendMessage", { sessionId, ...params }), + sendMessage: async (params: TasksSendMessageRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tasks.sendMessage", { sessionId, ...params }); + }, }, /** @experimental */ skills: { @@ -8478,41 +8692,59 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Skills available to the session, with their enabled state. */ - list: async (): Promise => - connection.sendRequest("session.skills.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.skills.list", { sessionId }); + }, /** * Returns the skills that have been invoked during this session. * * @returns Skills invoked during this session, ordered by invocation time (most recent last). */ - getInvoked: async (): Promise => - connection.sendRequest("session.skills.getInvoked", { sessionId }), + getInvoked: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.skills.getInvoked", { sessionId }); + }, /** * Enables a skill for the session. * * @param params Name of the skill to enable for the session. */ - enable: async (params: SkillsEnableRequest): Promise => - connection.sendRequest("session.skills.enable", { sessionId, ...params }), + enable: async (params: SkillsEnableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.skills.enable", { sessionId, ...params }); + }, /** * Disables a skill for the session. * * @param params Name of the skill to disable for the session. */ - disable: async (params: SkillsDisableRequest): Promise => - connection.sendRequest("session.skills.disable", { sessionId, ...params }), + disable: async (params: SkillsDisableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.skills.disable", { sessionId, ...params }); + }, /** * Reloads skill definitions for the session. * * @returns Diagnostics from reloading skill definitions, with warnings and errors as separate lists. */ - reload: async (): Promise => - connection.sendRequest("session.skills.reload", { sessionId }), + reload: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.skills.reload", { sessionId }); + }, /** * Ensures the session's skill definitions have been loaded from disk. */ - ensureLoaded: async (): Promise => - connection.sendRequest("session.skills.ensureLoaded", { sessionId }), + ensureLoaded: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.skills.ensureLoaded", { sessionId }); + }, }, /** @experimental */ mcp: { @@ -8521,27 +8753,41 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns MCP servers configured for the session, with their connection status. */ - list: async (): Promise => - connection.sendRequest("session.mcp.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.mcp.list", { sessionId }); + }, /** * Enables an MCP server for the session. * * @param params Name of the MCP server to enable for the session. */ - enable: async (params: McpEnableRequest): Promise => - connection.sendRequest("session.mcp.enable", { sessionId, ...params }), + enable: async (params: McpEnableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.enable", { sessionId, ...params }); + }, /** * Disables an MCP server for the session. * * @param params Name of the MCP server to disable for the session. */ - disable: async (params: McpDisableRequest): Promise => - connection.sendRequest("session.mcp.disable", { sessionId, ...params }), + disable: async (params: McpDisableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.disable", { sessionId, ...params }); + }, /** * Reloads MCP server connections for the session. */ - reload: async (): Promise => - connection.sendRequest("session.mcp.reload", { sessionId }), + reload: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.mcp.reload", { sessionId }); + }, /** * Runs an MCP sampling inference on behalf of an MCP server. * @@ -8549,8 +8795,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Outcome of an MCP sampling execution: success result, failure error, or cancellation. */ - executeSampling: async (params: McpExecuteSamplingParams): Promise => - connection.sendRequest("session.mcp.executeSampling", { sessionId, ...params }), + executeSampling: async (params: McpExecuteSamplingParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.executeSampling", { sessionId, ...params }); + }, /** * Cancels an in-flight MCP sampling execution by request ID. * @@ -8558,8 +8809,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether an in-flight sampling execution with the given requestId was found and cancelled. */ - cancelSamplingExecution: async (params: McpCancelSamplingExecutionParams): Promise => - connection.sendRequest("session.mcp.cancelSamplingExecution", { sessionId, ...params }), + cancelSamplingExecution: async (params: McpCancelSamplingExecutionParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.cancelSamplingExecution", { sessionId, ...params }); + }, /** * Sets how environment-variable values supplied to MCP servers are resolved (direct or indirect). * @@ -8567,15 +8823,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Env-value mode recorded on the session after the update. */ - setEnvValueMode: async (params: McpSetEnvValueModeParams): Promise => - connection.sendRequest("session.mcp.setEnvValueMode", { sessionId, ...params }), + setEnvValueMode: async (params: McpSetEnvValueModeParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.setEnvValueMode", { sessionId, ...params }); + }, /** * Removes the auto-managed `github` MCP server when present. * * @returns Indicates whether the auto-managed `github` MCP server was removed (false when nothing to remove). */ - removeGitHub: async (): Promise => - connection.sendRequest("session.mcp.removeGitHub", { sessionId }), + removeGitHub: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.mcp.removeGitHub", { sessionId }); + }, /** @experimental */ oauth: { /** @@ -8585,8 +8848,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns OAuth authorization URL the caller should open, or empty when cached tokens already authenticated the server. */ - login: async (params: McpOauthLoginRequest): Promise => - connection.sendRequest("session.mcp.oauth.login", { sessionId, ...params }), + login: async (params: McpOauthLoginRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.mcp.oauth.login", { sessionId, ...params }); + }, }, }, /** @experimental */ @@ -8596,8 +8864,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Plugins installed for the session, with their enabled state and version metadata. */ - list: async (): Promise => - connection.sendRequest("session.plugins.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.plugins.list", { sessionId }); + }, }, /** @experimental */ options: { @@ -8608,8 +8878,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the session options patch was applied successfully. */ - update: async (params: SessionUpdateOptionsParams): Promise => - connection.sendRequest("session.options.update", { sessionId, ...params }), + update: async (params?: SessionUpdateOptionsParams): Promise => { + assertActive?.(); + return connection.sendRequest("session.options.update", { sessionId, ...params }); + }, }, /** @experimental */ lsp: { @@ -8618,8 +8890,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @param params Parameters for (re)loading the merged LSP configuration set. */ - initialize: async (params: LspInitializeRequest): Promise => - connection.sendRequest("session.lsp.initialize", { sessionId, ...params }), + initialize: async (params?: LspInitializeRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.lsp.initialize", { sessionId, ...params }); + }, }, /** @experimental */ extensions: { @@ -8628,27 +8902,41 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Extensions discovered for the session, with their current status. */ - list: async (): Promise => - connection.sendRequest("session.extensions.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.extensions.list", { sessionId }); + }, /** * Enables an extension for the session. * * @param params Source-qualified extension identifier to enable for the session. */ - enable: async (params: ExtensionsEnableRequest): Promise => - connection.sendRequest("session.extensions.enable", { sessionId, ...params }), + enable: async (params: ExtensionsEnableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.extensions.enable", { sessionId, ...params }); + }, /** * Disables an extension for the session. * * @param params Source-qualified extension identifier to disable for the session. */ - disable: async (params: ExtensionsDisableRequest): Promise => - connection.sendRequest("session.extensions.disable", { sessionId, ...params }), + disable: async (params: ExtensionsDisableRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.extensions.disable", { sessionId, ...params }); + }, /** * Reloads extension definitions and processes for the session. */ - reload: async (): Promise => - connection.sendRequest("session.extensions.reload", { sessionId }), + reload: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.extensions.reload", { sessionId }); + }, }, tools: { /** @@ -8658,15 +8946,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the external tool call result was handled successfully. */ - handlePendingToolCall: async (params: HandlePendingToolCallRequest): Promise => - connection.sendRequest("session.tools.handlePendingToolCall", { sessionId, ...params }), + handlePendingToolCall: async (params: HandlePendingToolCallRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.tools.handlePendingToolCall", { sessionId, ...params }); + }, /** * Resolves, builds, and validates the runtime tool list for the session. * * @returns Resolve, build, and validate the runtime tool list for this session. Subagent sessions and consumer flows that need an initialized tool set before `send` invoke this. Default base-class implementation is a no-op for sessions that don't support tool validation. */ - initializeAndValidate: async (): Promise => - connection.sendRequest("session.tools.initializeAndValidate", { sessionId }), + initializeAndValidate: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.tools.initializeAndValidate", { sessionId }); + }, }, commands: { /** @@ -8676,8 +8971,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Slash commands available in the session, after applying any include/exclude filters. */ - list: async (params?: CommandsListRequest): Promise => - connection.sendRequest("session.commands.list", { sessionId, ...params }), + list: async (params?: CommandsListRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.commands.list", { sessionId, ...params }); + }, /** * Invokes a slash command in the session. * @@ -8685,8 +8982,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Result of invoking the slash command (text output, prompt to send to the agent, or completion). */ - invoke: async (params: CommandsInvokeRequest): Promise => - connection.sendRequest("session.commands.invoke", { sessionId, ...params }), + invoke: async (params: CommandsInvokeRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.commands.invoke", { sessionId, ...params }); + }, /** * Reports completion of a pending client-handled slash command. * @@ -8694,8 +8996,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the pending client-handled command was completed successfully. */ - handlePendingCommand: async (params: CommandsHandlePendingCommandRequest): Promise => - connection.sendRequest("session.commands.handlePendingCommand", { sessionId, ...params }), + handlePendingCommand: async (params: CommandsHandlePendingCommandRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.commands.handlePendingCommand", { sessionId, ...params }); + }, /** * Executes a slash command synchronously and returns any error. * @@ -8703,8 +9010,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Error message produced while executing the command, if any. */ - execute: async (params: ExecuteCommandParams): Promise => - connection.sendRequest("session.commands.execute", { sessionId, ...params }), + execute: async (params: ExecuteCommandParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.commands.execute", { sessionId, ...params }); + }, /** * Enqueues a slash command for FIFO processing on the local session. * @@ -8712,8 +9024,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the command was accepted into the local execution queue. */ - enqueue: async (params: EnqueueCommandParams): Promise => - connection.sendRequest("session.commands.enqueue", { sessionId, ...params }), + enqueue: async (params: EnqueueCommandParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.commands.enqueue", { sessionId, ...params }); + }, /** * Reports whether the host actually executed a queued command and whether to continue processing. * @@ -8721,8 +9038,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the queued-command response was matched to a pending request. */ - respondToQueuedCommand: async (params: CommandsRespondToQueuedCommandRequest): Promise => - connection.sendRequest("session.commands.respondToQueuedCommand", { sessionId, ...params }), + respondToQueuedCommand: async (params: CommandsRespondToQueuedCommandRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.commands.respondToQueuedCommand", { sessionId, ...params }); + }, }, /** @experimental */ telemetry: { @@ -8731,8 +9053,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @param params Feature override key/value pairs to attach to subsequent telemetry events from this session. */ - setFeatureOverrides: async (params: TelemetrySetFeatureOverridesRequest): Promise => - connection.sendRequest("session.telemetry.setFeatureOverrides", { sessionId, ...params }), + setFeatureOverrides: async (params: TelemetrySetFeatureOverridesRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.telemetry.setFeatureOverrides", { sessionId, ...params }); + }, }, ui: { /** @@ -8742,8 +9069,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns The elicitation response (accept with form values, decline, or cancel) */ - elicitation: async (params: UIElicitationRequest): Promise => - connection.sendRequest("session.ui.elicitation", { sessionId, ...params }), + elicitation: async (params: UIElicitationRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.elicitation", { sessionId, ...params }); + }, /** * Provides the user response for a pending elicitation request. * @@ -8751,8 +9083,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the elicitation response was accepted; false if it was already resolved by another client. */ - handlePendingElicitation: async (params: UIHandlePendingElicitationRequest): Promise => - connection.sendRequest("session.ui.handlePendingElicitation", { sessionId, ...params }), + handlePendingElicitation: async (params: UIHandlePendingElicitationRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.handlePendingElicitation", { sessionId, ...params }); + }, /** * Resolves a pending `user_input.requested` event with the user's response. * @@ -8760,8 +9097,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the pending UI request was resolved by this call. */ - handlePendingUserInput: async (params: UIHandlePendingUserInputRequest): Promise => - connection.sendRequest("session.ui.handlePendingUserInput", { sessionId, ...params }), + handlePendingUserInput: async (params: UIHandlePendingUserInputRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.handlePendingUserInput", { sessionId, ...params }); + }, /** * Resolves a pending `sampling.requested` event with a sampling result, or rejects it. * @@ -8769,8 +9111,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the pending UI request was resolved by this call. */ - handlePendingSampling: async (params: UIHandlePendingSamplingRequest): Promise => - connection.sendRequest("session.ui.handlePendingSampling", { sessionId, ...params }), + handlePendingSampling: async (params: UIHandlePendingSamplingRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.handlePendingSampling", { sessionId, ...params }); + }, /** * Resolves a pending `auto_mode_switch.requested` event with the user's accept/decline decision. * @@ -8778,8 +9125,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the pending UI request was resolved by this call. */ - handlePendingAutoModeSwitch: async (params: UIHandlePendingAutoModeSwitchRequest): Promise => - connection.sendRequest("session.ui.handlePendingAutoModeSwitch", { sessionId, ...params }), + handlePendingAutoModeSwitch: async (params: UIHandlePendingAutoModeSwitchRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.handlePendingAutoModeSwitch", { sessionId, ...params }); + }, /** * Resolves a pending `exit_plan_mode.requested` event with the user's response. * @@ -8787,15 +9139,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the pending UI request was resolved by this call. */ - handlePendingExitPlanMode: async (params: UIHandlePendingExitPlanModeRequest): Promise => - connection.sendRequest("session.ui.handlePendingExitPlanMode", { sessionId, ...params }), + handlePendingExitPlanMode: async (params: UIHandlePendingExitPlanModeRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.handlePendingExitPlanMode", { sessionId, ...params }); + }, /** * Registers an in-process handler for auto-mode-switch requests so the server bridge skips dispatch. * * @returns Register an in-process handler for `auto_mode_switch.requested` events. The caller still attaches the actual listener via the standard event-subscription mechanism; this registration solely tells the server bridge to skip its own dispatch (so a remote client doesn't race the in-process handler for the same requestId). */ - registerDirectAutoModeSwitchHandler: async (): Promise => - connection.sendRequest("session.ui.registerDirectAutoModeSwitchHandler", { sessionId }), + registerDirectAutoModeSwitchHandler: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.ui.registerDirectAutoModeSwitchHandler", { sessionId }); + }, /** * Unregisters a previously-registered in-process auto-mode-switch handler by its opaque handle. * @@ -8803,8 +9162,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the handle was active and the registration count was decremented. */ - unregisterDirectAutoModeSwitchHandler: async (params: UIUnregisterDirectAutoModeSwitchHandlerRequest): Promise => - connection.sendRequest("session.ui.unregisterDirectAutoModeSwitchHandler", { sessionId, ...params }), + unregisterDirectAutoModeSwitchHandler: async (params: UIUnregisterDirectAutoModeSwitchHandlerRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.ui.unregisterDirectAutoModeSwitchHandler", { sessionId, ...params }); + }, }, permissions: { /** @@ -8814,8 +9178,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - configure: async (params: PermissionsConfigureParams): Promise => - connection.sendRequest("session.permissions.configure", { sessionId, ...params }), + configure: async (params?: PermissionsConfigureParams): Promise => { + assertActive?.(); + return connection.sendRequest("session.permissions.configure", { sessionId, ...params }); + }, /** * Provides a decision for a pending tool permission request. * @@ -8823,15 +9189,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the permission decision was applied; false when the request was already resolved. */ - handlePendingPermissionRequest: async (params: PermissionDecisionRequest): Promise => - connection.sendRequest("session.permissions.handlePendingPermissionRequest", { sessionId, ...params }), + handlePendingPermissionRequest: async (params: PermissionDecisionRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.handlePendingPermissionRequest", { sessionId, ...params }); + }, /** * Reconstructs the set of pending tool permission requests from the session's event history. * * @returns List of pending permission requests reconstructed from event history. */ - pendingRequests: async (): Promise => - connection.sendRequest("session.permissions.pendingRequests", { sessionId }), + pendingRequests: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.permissions.pendingRequests", { sessionId }); + }, /** * Enables or disables automatic approval of tool permission requests for the session. * @@ -8839,8 +9212,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - setApproveAll: async (params: PermissionsSetApproveAllRequest): Promise => - connection.sendRequest("session.permissions.setApproveAll", { sessionId, ...params }), + setApproveAll: async (params: PermissionsSetApproveAllRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.setApproveAll", { sessionId, ...params }); + }, /** * Adds or removes session-scoped or location-scoped permission rules. * @@ -8848,8 +9226,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - modifyRules: async (params: PermissionsModifyRulesParams): Promise => - connection.sendRequest("session.permissions.modifyRules", { sessionId, ...params }), + modifyRules: async (params: PermissionsModifyRulesParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.modifyRules", { sessionId, ...params }); + }, /** * Sets whether the client wants permission prompts bridged into session events. * @@ -8857,15 +9240,22 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - setRequired: async (params: PermissionsSetRequiredRequest): Promise => - connection.sendRequest("session.permissions.setRequired", { sessionId, ...params }), + setRequired: async (params: PermissionsSetRequiredRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.setRequired", { sessionId, ...params }); + }, /** * Clears session-scoped tool permission approvals. * * @returns Indicates whether the operation succeeded. */ - resetSessionApprovals: async (): Promise => - connection.sendRequest("session.permissions.resetSessionApprovals", { sessionId }), + resetSessionApprovals: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.permissions.resetSessionApprovals", { sessionId }); + }, /** * Notifies the runtime that a permission prompt UI has been shown to the user. * @@ -8873,16 +9263,23 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - notifyPromptShown: async (params: PermissionPromptShownNotification): Promise => - connection.sendRequest("session.permissions.notifyPromptShown", { sessionId, ...params }), + notifyPromptShown: async (params: PermissionPromptShownNotification): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.notifyPromptShown", { sessionId, ...params }); + }, paths: { /** * Returns the session's allowed directories and primary working directory. * * @returns Snapshot of the session's allow-listed directories and primary working directory. */ - list: async (): Promise => - connection.sendRequest("session.permissions.paths.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.permissions.paths.list", { sessionId }); + }, /** * Adds a directory to the session's allow-list. * @@ -8890,8 +9287,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - add: async (params: PermissionPathsAddParams): Promise => - connection.sendRequest("session.permissions.paths.add", { sessionId, ...params }), + add: async (params: PermissionPathsAddParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.paths.add", { sessionId, ...params }); + }, /** * Updates the session's primary working directory used by the permission policy. * @@ -8899,8 +9301,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - updatePrimary: async (params: PermissionPathsUpdatePrimaryParams): Promise => - connection.sendRequest("session.permissions.paths.updatePrimary", { sessionId, ...params }), + updatePrimary: async (params: PermissionPathsUpdatePrimaryParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.paths.updatePrimary", { sessionId, ...params }); + }, /** * Reports whether a path falls within any of the session's allowed directories. * @@ -8908,8 +9315,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the supplied path is within the session's allowed directories. */ - isPathWithinAllowedDirectories: async (params: PermissionPathsAllowedCheckParams): Promise => - connection.sendRequest("session.permissions.paths.isPathWithinAllowedDirectories", { sessionId, ...params }), + isPathWithinAllowedDirectories: async (params: PermissionPathsAllowedCheckParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.paths.isPathWithinAllowedDirectories", { sessionId, ...params }); + }, /** * Reports whether a path falls within the session's workspace (primary) directory. * @@ -8917,8 +9329,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the supplied path is within the session's workspace directory. */ - isPathWithinWorkspace: async (params: PermissionPathsWorkspaceCheckParams): Promise => - connection.sendRequest("session.permissions.paths.isPathWithinWorkspace", { sessionId, ...params }), + isPathWithinWorkspace: async (params: PermissionPathsWorkspaceCheckParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.paths.isPathWithinWorkspace", { sessionId, ...params }); + }, }, urls: { /** @@ -8928,8 +9345,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - setUnrestrictedMode: async (params: PermissionUrlsSetUnrestrictedModeParams): Promise => - connection.sendRequest("session.permissions.urls.setUnrestrictedMode", { sessionId, ...params }), + setUnrestrictedMode: async (params: PermissionUrlsSetUnrestrictedModeParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.permissions.urls.setUnrestrictedMode", { sessionId, ...params }); + }, }, }, /** @@ -8939,8 +9361,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Identifier of the session event that was emitted for the log message. */ - log: async (params: LogRequest): Promise => - connection.sendRequest("session.log", { sessionId, ...params }), + log: async (params: LogRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.log", { sessionId, ...params }); + }, /** @experimental */ metadata: { /** @@ -8948,15 +9375,19 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Point-in-time snapshot of slow-changing session identifier and state fields */ - snapshot: async (): Promise => - connection.sendRequest("session.metadata.snapshot", { sessionId }), + snapshot: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.metadata.snapshot", { sessionId }); + }, /** * Reports whether the local session is currently processing user/agent messages. * * @returns Indicates whether the local session is currently processing a turn or background continuation. */ - isProcessing: async (): Promise => - connection.sendRequest("session.metadata.isProcessing", { sessionId }), + isProcessing: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.metadata.isProcessing", { sessionId }); + }, /** * Returns the token breakdown for the session's current context window for a given model. * @@ -8964,8 +9395,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Token breakdown for the session's current context window, or null if uninitialized. */ - contextInfo: async (params: MetadataContextInfoRequest): Promise => - connection.sendRequest("session.metadata.contextInfo", { sessionId, ...params }), + contextInfo: async (params: MetadataContextInfoRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.metadata.contextInfo", { sessionId, ...params }); + }, /** * Records a working-directory/git context change and emits a `session.context_changed` event. * @@ -8973,8 +9409,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Notify the session that its working directory context has changed. Emits a `session.context_changed` event so consumers (telemetry, OTel tracker, ACP, the timeline UI) can react. Use this when the host has detected a cwd/branch/repo change outside the session's normal lifecycle (e.g., after a shell command in interactive mode). */ - recordContextChange: async (params: MetadataRecordContextChangeRequest): Promise => - connection.sendRequest("session.metadata.recordContextChange", { sessionId, ...params }), + recordContextChange: async (params: MetadataRecordContextChangeRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.metadata.recordContextChange", { sessionId, ...params }); + }, /** * Updates the session's recorded working directory. * @@ -8982,8 +9423,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Update the session's working directory. Used by the host when the user explicitly changes cwd (e.g., the `/cd` slash command). The host is responsible for `process.chdir` and any related side-effects (file index, etc.); this method only updates the session's own recorded path. */ - setWorkingDirectory: async (params: MetadataSetWorkingDirectoryRequest): Promise => - connection.sendRequest("session.metadata.setWorkingDirectory", { sessionId, ...params }), + setWorkingDirectory: async (params: MetadataSetWorkingDirectoryRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.metadata.setWorkingDirectory", { sessionId, ...params }); + }, /** * Re-tokenizes the session's existing messages against a model and returns aggregate token totals. * @@ -8991,8 +9437,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Re-tokenize the session's existing messages against `modelId` and return the token totals. Useful for hosts that want an initial estimate of context usage on session resume, before the next agent turn fires `session.context_info_changed` events. Returns zeros for an empty session. */ - recomputeContextTokens: async (params: MetadataRecomputeContextTokensRequest): Promise => - connection.sendRequest("session.metadata.recomputeContextTokens", { sessionId, ...params }), + recomputeContextTokens: async (params: MetadataRecomputeContextTokensRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.metadata.recomputeContextTokens", { sessionId, ...params }); + }, }, shell: { /** @@ -9002,8 +9453,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Identifier of the spawned process, used to correlate streamed output and exit notifications. */ - exec: async (params: ShellExecRequest): Promise => - connection.sendRequest("session.shell.exec", { sessionId, ...params }), + exec: async (params: ShellExecRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.shell.exec", { sessionId, ...params }); + }, /** * Sends a signal to a shell process previously started via "shell.exec". * @@ -9011,8 +9467,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the signal was delivered; false if the process was unknown or already exited. */ - kill: async (params: ShellKillRequest): Promise => - connection.sendRequest("session.shell.kill", { sessionId, ...params }), + kill: async (params: ShellKillRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.shell.kill", { sessionId, ...params }); + }, }, /** @experimental */ history: { @@ -9021,8 +9482,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Compaction outcome with the number of tokens and messages removed, summary text, and the resulting context window breakdown. */ - compact: async (): Promise => - connection.sendRequest("session.history.compact", { sessionId }), + compact: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.history.compact", { sessionId }); + }, /** * Truncates persisted session history to a specific event. * @@ -9030,29 +9493,40 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Number of events that were removed by the truncation. */ - truncate: async (params: HistoryTruncateRequest): Promise => - connection.sendRequest("session.history.truncate", { sessionId, ...params }), + truncate: async (params: HistoryTruncateRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.history.truncate", { sessionId, ...params }); + }, /** * Cancels any in-progress background compaction on a local session. * * @returns Indicates whether an in-progress background compaction was cancelled. */ - cancelBackgroundCompaction: async (): Promise => - connection.sendRequest("session.history.cancelBackgroundCompaction", { sessionId }), + cancelBackgroundCompaction: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.history.cancelBackgroundCompaction", { sessionId }); + }, /** * Aborts any in-progress manual compaction on a local session. * * @returns Indicates whether an in-progress manual compaction was aborted. */ - abortManualCompaction: async (): Promise => - connection.sendRequest("session.history.abortManualCompaction", { sessionId }), + abortManualCompaction: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.history.abortManualCompaction", { sessionId }); + }, /** * Produces a markdown summary of the session's conversation context for hand-off scenarios. * * @returns Markdown summary of the conversation context (empty when not available). */ - summarizeForHandoff: async (): Promise => - connection.sendRequest("session.history.summarizeForHandoff", { sessionId }), + summarizeForHandoff: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.history.summarizeForHandoff", { sessionId }); + }, }, /** @experimental */ queue: { @@ -9061,20 +9535,26 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Snapshot of the session's pending queued items and immediate-steering messages. */ - pendingItems: async (): Promise => - connection.sendRequest("session.queue.pendingItems", { sessionId }), + pendingItems: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.queue.pendingItems", { sessionId }); + }, /** * Removes the most recently queued user-facing item (LIFO). * * @returns Indicates whether a user-facing pending item was removed. */ - removeMostRecent: async (): Promise => - connection.sendRequest("session.queue.removeMostRecent", { sessionId }), + removeMostRecent: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.queue.removeMostRecent", { sessionId }); + }, /** * Clears all pending queued items on the local session. */ - clear: async (): Promise => - connection.sendRequest("session.queue.clear", { sessionId }), + clear: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.queue.clear", { sessionId }); + }, }, /** @experimental */ eventLog: { @@ -9085,15 +9565,19 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Batch of session events returned by a read, with cursor and continuation metadata. */ - read: async (params: EventLogReadRequest): Promise => - connection.sendRequest("session.eventLog.read", { sessionId, ...params }), + read: async (params?: EventLogReadRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.eventLog.read", { sessionId, ...params }); + }, /** * Returns a snapshot of the current tail cursor without consuming events. * * @returns Snapshot of the current tail cursor without returning any events. Use this when a consumer wants to subscribe to live events going forward without first paginating through the entire persisted history (which would happen if `read` were called without a cursor on a long-lived session). */ - tail: async (): Promise => - connection.sendRequest("session.eventLog.tail", { sessionId }), + tail: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.eventLog.tail", { sessionId }); + }, /** * Registers consumer interest in an event type for runtime gating purposes. * @@ -9101,8 +9585,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Opaque handle representing an event-type interest registration. */ - registerInterest: async (params: RegisterEventInterestParams): Promise => - connection.sendRequest("session.eventLog.registerInterest", { sessionId, ...params }), + registerInterest: async (params: RegisterEventInterestParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.eventLog.registerInterest", { sessionId, ...params }); + }, /** * Releases a consumer's previously-registered interest in an event type. * @@ -9110,8 +9599,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Indicates whether the operation succeeded. */ - releaseInterest: async (params: ReleaseEventInterestParams): Promise => - connection.sendRequest("session.eventLog.releaseInterest", { sessionId, ...params }), + releaseInterest: async (params: ReleaseEventInterestParams): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.eventLog.releaseInterest", { sessionId, ...params }); + }, }, /** @experimental */ usage: { @@ -9120,8 +9614,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Accumulated session usage metrics, including premium request cost, token counts, model breakdown, and code-change totals. */ - getMetrics: async (): Promise => - connection.sendRequest("session.usage.getMetrics", { sessionId }), + getMetrics: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.usage.getMetrics", { sessionId }); + }, }, /** @experimental */ remote: { @@ -9132,13 +9628,17 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns GitHub URL for the session and a flag indicating whether remote steering is enabled. */ - enable: async (params: RemoteEnableRequest): Promise => - connection.sendRequest("session.remote.enable", { sessionId, ...params }), + enable: async (params?: RemoteEnableRequest): Promise => { + assertActive?.(); + return connection.sendRequest("session.remote.enable", { sessionId, ...params }); + }, /** * Disables remote session export and steering. */ - disable: async (): Promise => - connection.sendRequest("session.remote.disable", { sessionId }), + disable: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.remote.disable", { sessionId }); + }, /** * Persists a remote-steerability change emitted by the host as a session event. * @@ -9146,8 +9646,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Persist a steerability change as a `session.remote_steerable_changed` event. Used by the host (CLI / SDK consumer) when it has just finished enabling or disabling steering on a remote exporter that the runtime does not directly own. */ - notifySteerableChanged: async (params: RemoteNotifySteerableChangedRequest): Promise => - connection.sendRequest("session.remote.notifySteerableChanged", { sessionId, ...params }), + notifySteerableChanged: async (params: RemoteNotifySteerableChangedRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.remote.notifySteerableChanged", { sessionId, ...params }); + }, }, /** @experimental */ schedule: { @@ -9156,8 +9661,10 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Snapshot of the currently active recurring prompts for this session. */ - list: async (): Promise => - connection.sendRequest("session.schedule.list", { sessionId }), + list: async (): Promise => { + assertActive?.(); + return connection.sendRequest("session.schedule.list", { sessionId }); + }, /** * Removes a scheduled prompt by id. * @@ -9165,8 +9672,13 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin * * @returns Remove a scheduled prompt by id. The result entry is omitted if the id was unknown. */ - stop: async (params: ScheduleStopRequest): Promise => - connection.sendRequest("session.schedule.stop", { sessionId, ...params }), + stop: async (params: ScheduleStopRequest): Promise => { + assertActive?.(); + if (params == null) { + throw new TypeError("params is required"); + } + return connection.sendRequest("session.schedule.stop", { sessionId, ...params }); + }, }, }; } diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 6b164cb15..2bcb6c4a2 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -97,6 +97,9 @@ export class CopilotSession { private _rpc: ReturnType | null = null; private traceContextProvider?: TraceContextProvider; private _capabilities: SessionCapabilities = {}; + private disconnected = false; + private disconnectPromise?: Promise; + private readonly onDisconnected?: (session: CopilotSession) => void; /** @internal Client session API handlers, populated by CopilotClient during create/resume. */ clientSessionApis: ClientSessionApiHandlers = {}; @@ -114,17 +117,22 @@ export class CopilotSession { public readonly sessionId: string, private connection: MessageConnection, private _workspacePath?: string, - traceContextProvider?: TraceContextProvider + traceContextProvider?: TraceContextProvider, + onDisconnected?: (session: CopilotSession) => void ) { this.traceContextProvider = traceContextProvider; + this.onDisconnected = onDisconnected; } /** * Typed session-scoped RPC methods. */ get rpc(): ReturnType { + this.assertActive(); if (!this._rpc) { - this._rpc = createSessionRpc(this.connection, this.sessionId); + this._rpc = createSessionRpc(this.connection, this.sessionId, () => + this.assertActive() + ); } return this._rpc; } @@ -159,6 +167,7 @@ export class CopilotSession { * ``` */ get ui(): SessionUiApi { + this.assertActive(); return { elicitation: (params: ElicitationParams) => this._elicitation(params), confirm: (message: string) => this._confirm(message), @@ -186,6 +195,7 @@ export class CopilotSession { * ``` */ async send(options: MessageOptions): Promise { + this.assertActive(); const response = await this.connection.sendRequest("session.send", { ...(await getTraceContext(this.traceContextProvider)), sessionId: this.sessionId, @@ -225,6 +235,7 @@ export class CopilotSession { options: MessageOptions, timeout?: number ): Promise { + this.assertActive(); const effectiveTimeout = timeout ?? 60_000; let resolveIdle: () => void; @@ -328,6 +339,7 @@ export class CopilotSession { eventTypeOrHandler: K | SessionEventHandler, handler?: TypedSessionEventHandler ): () => void { + this.assertActive(); // Overload 1: on(eventType, handler) - typed event subscription if (typeof eventTypeOrHandler === "string" && handler) { const eventType = eventTypeOrHandler; @@ -490,8 +502,14 @@ export class CopilotSession { } else { result = JSON.stringify(rawResult); } + if (this.disconnected) { + return; + } await this.rpc.tools.handlePendingToolCall({ requestId, result }); } catch (error) { + if (this.disconnected) { + return; + } const message = error instanceof Error ? error.message : String(error); try { await this.rpc.tools.handlePendingToolCall({ requestId, error: message }); @@ -519,8 +537,14 @@ export class CopilotSession { if (result.kind === "no-result") { return; } + if (this.disconnected) { + return; + } await this.rpc.permissions.handlePendingPermissionRequest({ requestId, result }); } catch (_error) { + if (this.disconnected) { + return; + } try { await this.rpc.permissions.handlePendingPermissionRequest({ requestId, @@ -564,8 +588,14 @@ export class CopilotSession { try { await handler({ sessionId: this.sessionId, command, commandName, args }); + if (this.disconnected) { + return; + } await this.rpc.commands.handlePendingCommand({ requestId }); } catch (error) { + if (this.disconnected) { + return; + } const message = error instanceof Error ? error.message : String(error); try { await this.rpc.commands.handlePendingCommand({ requestId, error: message }); @@ -720,7 +750,14 @@ export class CopilotSession { this._capabilities = capabilities ?? {}; } + private assertActive(): void { + if (this.disconnected) { + throw new Error("Session has been disconnected."); + } + } + private assertElicitation(): void { + this.assertActive(); if (!this._capabilities.ui?.elicitation) { throw new Error( "Elicitation is not supported by the host. " + @@ -992,6 +1029,7 @@ export class CopilotSession { * ``` */ async getMessages(): Promise { + this.assertActive(); const response = await this.connection.sendRequest("session.getMessages", { sessionId: this.sessionId, }); @@ -1021,17 +1059,48 @@ export class CopilotSession { * ``` */ async disconnect(): Promise { - await this.connection.sendRequest("session.destroy", { - sessionId: this.sessionId, - }); + if (this.disconnected) { + return; + } + if (!this.disconnectPromise) { + this.disconnectPromise = this.disconnectCore(); + } + await this.disconnectPromise; + } + + private async disconnectCore(): Promise { + try { + await this.connection.sendRequest("session.destroy", { + sessionId: this.sessionId, + }); + } finally { + this.markDisconnected(); + } + } + + /** @internal Marks the session unusable after client-side forced cleanup. */ + _markDisconnected(): void { + this.markDisconnected(); + } + + private markDisconnected(): void { + if (this.disconnected) { + return; + } + this.disconnected = true; + this._rpc = null; this.eventHandlers.clear(); this.typedEventHandlers.clear(); this.toolHandlers.clear(); + this.commandHandlers.clear(); this.permissionHandler = undefined; this.userInputHandler = undefined; this.elicitationHandler = undefined; this.exitPlanModeHandler = undefined; this.autoModeSwitchHandler = undefined; + this.hooks = undefined; + this.transformCallbacks = undefined; + this.onDisconnected?.(this); } /** @@ -1073,6 +1142,7 @@ export class CopilotSession { * ``` */ async abort(): Promise { + this.assertActive(); await this.connection.sendRequest("session.abort", { sessionId: this.sessionId, }); @@ -1098,6 +1168,7 @@ export class CopilotSession { modelCapabilities?: ModelCapabilitiesOverride; } ): Promise { + this.assertActive(); await this.rpc.model.switchTo({ modelId: model, ...options }); } @@ -1121,6 +1192,7 @@ export class CopilotSession { message: string, options?: { level?: "info" | "warning" | "error"; ephemeral?: boolean } ): Promise { + this.assertActive(); await this.rpc.log({ message, ...options }); } } diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index a92f54253..519de4c09 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -1,10 +1,14 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, expect, it, onTestFinished, vi } from "vitest"; import { approveAll, CopilotClient, type ModelInfo } from "../src/index.js"; +import { createServerRpc } from "../src/generated/rpc.js"; import { CopilotSession } from "../src/session.js"; import { defaultJoinSessionPermissionHandler } from "../src/types.js"; // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.ts instead +function markInactiveForResume(session: CopilotSession): void { + session._markDisconnected(); +} describe("CopilotClient", () => { it("allows createSession without onPermissionRequest", async () => { @@ -82,6 +86,110 @@ describe("CopilotClient", () => { ); }); + it("keeps a directly disconnected session registered until session.destroy completes", async () => { + const client = new CopilotClient({ autoStart: false }); + let session: CopilotSession; + const connection = { + sendRequest: vi.fn(async (method: string) => { + if (method === "session.destroy") { + expect((client as any).sessions.get("session-1")).toBe(session); + } + return {}; + }), + } as any; + session = new CopilotSession("session-1", connection, undefined, undefined, (s) => + (client as any).unregisterSession(s) + ); + (client as any).registerSession(session); + + await session.disconnect(); + + expect(connection.sendRequest).toHaveBeenCalledWith("session.destroy", { + sessionId: "session-1", + }); + expect((client as any).sessions.has("session-1")).toBe(false); + await expect(session.send({ prompt: "hello" })).rejects.toThrow(/disconnected/); + expect(() => session.rpc).toThrow(/disconnected/); + }); + + it("reports stop errors when session.destroy fails", async () => { + const client = new CopilotClient({ autoStart: false }); + const connection = { + dispose: vi.fn(), + sendRequest: vi.fn(async (method: string) => { + if (method === "session.destroy") { + throw new Error("destroy failed"); + } + return {}; + }), + } as any; + (client as any).connection = connection; + const session = new CopilotSession("session-1", connection, undefined, undefined, (s) => + (client as any).unregisterSession(s) + ); + (client as any).registerSession(session); + + const errors = await client.stop(); + + expect(errors).toHaveLength(1); + expect(errors[0].message).toBe("Failed to disconnect session session-1: destroy failed"); + expect(connection.sendRequest).toHaveBeenCalledTimes(1); + expect((client as any).sessions.size).toBe(0); + await expect(session.send({ prompt: "hello" })).rejects.toThrow(/disconnected/); + }); + + it("does not unregister a replacement session when a stale session disconnects", () => { + const client = new CopilotClient({ autoStart: false }); + const connection = { sendRequest: vi.fn() } as any; + const stale = new CopilotSession("session-1", connection, undefined, undefined, (s) => + (client as any).unregisterSession(s) + ); + const replacement = new CopilotSession("session-1", connection, undefined, undefined, (s) => + (client as any).unregisterSession(s) + ); + + (client as any).sessions.set("session-1", replacement); + stale._markDisconnected(); + + expect((client as any).sessions.get("session-1")).toBe(replacement); + replacement._markDisconnected(); + }); + + it("rejects duplicate active session registrations", () => { + const client = new CopilotClient({ autoStart: false }); + const connection = { sendRequest: vi.fn() } as any; + const first = new CopilotSession("session-1", connection); + const second = new CopilotSession("session-1", connection); + + (client as any).registerSession(first); + + expect(() => (client as any).registerSession(second)).toThrow(/already active/); + first._markDisconnected(); + }); + + it("validates required generated RPC params before sending requests", async () => { + const connection = { sendRequest: vi.fn() } as any; + const session = new CopilotSession("session-1", connection); + + await expect((session.rpc.commands.invoke as any)()).rejects.toThrow("params is required"); + expect(connection.sendRequest).not.toHaveBeenCalled(); + }); + + it("allows generated RPC params with only optional fields to be omitted", async () => { + const connection = { + sendRequest: vi.fn(async (method: string) => + method === "models.list" ? { models: [] } : { tools: [] } + ), + } as any; + const rpc = createServerRpc(connection); + + await rpc.models.list(); + await rpc.tools.list(); + + expect(connection.sendRequest).toHaveBeenCalledWith("models.list", {}); + expect(connection.sendRequest).toHaveBeenCalledWith("tools.list", {}); + }); + it("forwards clientName in session.resume request", async () => { const client = new CopilotClient(); await client.start(); @@ -95,6 +203,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { clientName: "my-app", onPermissionRequest: approveAll, @@ -136,6 +245,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { enableSessionTelemetry: false, onPermissionRequest: approveAll, @@ -187,6 +297,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll }); const payload = spy.mock.calls.find((c) => c[0] === "session.resume")![1] as any; @@ -206,6 +317,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, includeSubAgentStreamingEvents: false, @@ -228,6 +340,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, continuePendingWork: true, @@ -250,6 +363,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll }); const payload = spy.mock.calls.find((c) => c[0] === "session.resume")![1] as any; @@ -308,6 +422,7 @@ describe("CopilotClient", () => { throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, provider: { @@ -360,6 +475,7 @@ describe("CopilotClient", () => { const session = await client.createSession({ onPermissionRequest: approveAll }); const spy = vi.spyOn((client as any).connection!, "sendRequest"); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { defaultAgent: { excludedTools: ["heavy-tool"] }, onPermissionRequest: approveAll, @@ -404,6 +520,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { instructionDirectories, onPermissionRequest: approveAll, @@ -432,6 +549,7 @@ describe("CopilotClient", () => { throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: defaultJoinSessionPermissionHandler, }); @@ -459,6 +577,7 @@ describe("CopilotClient", () => { throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, }); @@ -486,6 +605,7 @@ describe("CopilotClient", () => { throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, onExitPlanMode: () => ({ approved: true }), @@ -834,6 +954,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, tools: [ @@ -912,6 +1033,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, customAgents: [ @@ -1075,6 +1197,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll }); expect(spy).toHaveBeenCalledWith( @@ -1186,6 +1309,7 @@ describe("CopilotClient", () => { if (method === "session.resume") return { sessionId: params.sessionId }; throw new Error(`Unexpected method: ${method}`); }); + markInactiveForResume(session); await client.resumeSession(session.sessionId, { onPermissionRequest: approveAll, commands: [{ name: "deploy", description: "Deploy", handler: async () => {} }], diff --git a/nodejs/test/e2e/commands.e2e.test.ts b/nodejs/test/e2e/commands.e2e.test.ts index 5ab6a9bbe..bcee05439 100644 --- a/nodejs/test/e2e/commands.e2e.test.ts +++ b/nodejs/test/e2e/commands.e2e.test.ts @@ -6,6 +6,7 @@ import { afterAll, describe, expect, it } from "vitest"; import { CopilotClient, approveAll } from "../../src/index.js"; import type { SessionEvent } from "../../src/index.js"; import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { markInactiveForResume } from "./harness/sdkTestHelper.js"; describe("Commands", async () => { // Use TCP mode so a second client can connect to the same CLI process @@ -83,6 +84,7 @@ describe("Commands", async () => { it("session with commands resumes successfully", async () => { const session1 = await client1.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; + markInactiveForResume(session1); const session2 = await client1.resumeSession(sessionId, { onPermissionRequest: approveAll, diff --git a/nodejs/test/e2e/harness/sdkTestHelper.ts b/nodejs/test/e2e/harness/sdkTestHelper.ts index 183e216f2..b21a0d1c4 100644 --- a/nodejs/test/e2e/harness/sdkTestHelper.ts +++ b/nodejs/test/e2e/harness/sdkTestHelper.ts @@ -130,3 +130,7 @@ export function getNextEventOfType( }); }); } + +export function markInactiveForResume(session: CopilotSession): void { + session._markDisconnected(); +} diff --git a/nodejs/test/e2e/mcp_and_agents.e2e.test.ts b/nodejs/test/e2e/mcp_and_agents.e2e.test.ts index aa580cdee..e0fce8be9 100644 --- a/nodejs/test/e2e/mcp_and_agents.e2e.test.ts +++ b/nodejs/test/e2e/mcp_and_agents.e2e.test.ts @@ -49,6 +49,7 @@ describe("MCP Servers and Custom Agents", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; await session1.sendAndWait({ prompt: "What is 1+1?" }); + await session1.disconnect(); // Resume with MCP servers const mcpServers: Record = { @@ -160,6 +161,7 @@ describe("MCP Servers and Custom Agents", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; await session1.sendAndWait({ prompt: "What is 1+1?" }); + await session1.disconnect(); // Resume with custom agents const customAgents: CustomAgentConfig[] = [ @@ -338,6 +340,7 @@ describe("MCP Servers and Custom Agents", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; await session1.sendAndWait({ prompt: "What is 3+3?" }); + await session1.disconnect(); const secretTool = defineTool("secret_tool", { description: "A secret tool hidden from the default agent", diff --git a/nodejs/test/e2e/permissions.e2e.test.ts b/nodejs/test/e2e/permissions.e2e.test.ts index dcb8033b2..d1fdc4f38 100644 --- a/nodejs/test/e2e/permissions.e2e.test.ts +++ b/nodejs/test/e2e/permissions.e2e.test.ts @@ -118,6 +118,7 @@ describe("Permission callbacks", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; await session1.sendAndWait({ prompt: "What is 1+1?" }); + await session1.disconnect(); const session2 = await client.resumeSession(sessionId, { onPermissionRequest: () => ({ @@ -182,6 +183,7 @@ describe("Permission callbacks", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; await session1.sendAndWait({ prompt: "What is 1+1?" }); + await session1.disconnect(); // Resume with permission handler const session2 = await client.resumeSession(sessionId, { diff --git a/nodejs/test/e2e/session.e2e.test.ts b/nodejs/test/e2e/session.e2e.test.ts index ca9d2d9d4..3d8d3f7b9 100644 --- a/nodejs/test/e2e/session.e2e.test.ts +++ b/nodejs/test/e2e/session.e2e.test.ts @@ -3,7 +3,11 @@ import { describe, expect, it, onTestFinished, vi } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; import { CopilotClient, approveAll, defineTool } from "../../src/index.js"; import { createSdkTestContext, isCI } from "./harness/sdkTestContext.js"; -import { getFinalAssistantMessage, getNextEventOfType } from "./harness/sdkTestHelper.js"; +import { + getFinalAssistantMessage, + getNextEventOfType, + markInactiveForResume, +} from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { const { @@ -31,7 +35,7 @@ describe("Sessions", async () => { ]); await session.disconnect(); - await expect(() => session.getMessages()).rejects.toThrow(/Session not found/); + await expect(() => session.getMessages()).rejects.toThrow(/Session has been disconnected/); }); // TODO: Re-enable once test harness CAPI proxy supports this test's session lifecycle @@ -253,7 +257,7 @@ describe("Sessions", async () => { // All can be disconnected await Promise.all([s1.disconnect(), s2.disconnect(), s3.disconnect()]); for (const s of [s1, s2, s3]) { - await expect(() => s.getMessages()).rejects.toThrow(/Session not found/); + await expect(() => s.getMessages()).rejects.toThrow(/Session has been disconnected/); } }); @@ -263,6 +267,7 @@ describe("Sessions", async () => { const sessionId = session1.sessionId; const answer = await session1.sendAndWait({ prompt: "What is 1+1?" }); expect(answer?.data.content).toContain("2"); + await session1.disconnect(); // Resume using the same client const session2 = await client.resumeSession(sessionId, { onPermissionRequest: approveAll }); @@ -353,6 +358,7 @@ describe("Sessions", async () => { it("should resume session with a custom provider", async () => { const session = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session.sessionId; + markInactiveForResume(session); // Resume the session with a provider const session2 = await client.resumeSession(sessionId, { diff --git a/nodejs/test/e2e/session_config.e2e.test.ts b/nodejs/test/e2e/session_config.e2e.test.ts index b86c3fa51..79f5da8ad 100644 --- a/nodejs/test/e2e/session_config.e2e.test.ts +++ b/nodejs/test/e2e/session_config.e2e.test.ts @@ -3,6 +3,7 @@ import { writeFile, mkdir } from "fs/promises"; import { join } from "path"; import { approveAll } from "../../src/index.js"; import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { markInactiveForResume } from "./harness/sdkTestHelper.js"; describe("Session Configuration", async () => { const { copilotClient: client, workDir, openAiEndpoint } = await createSdkTestContext(); @@ -247,6 +248,7 @@ describe("Session Configuration", async () => { onPermissionRequest: approveAll, workingDirectory: projectDir, }); + markInactiveForResume(session1); const session2 = await client.resumeSession(session1.sessionId, { onPermissionRequest: approveAll, workingDirectory: projectDir, @@ -304,6 +306,7 @@ describe("Session Configuration", async () => { it("should forward custom provider headers on resume", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; + markInactiveForResume(session1); const session2 = await client.resumeSession(sessionId, { onPermissionRequest: approveAll, @@ -384,6 +387,7 @@ describe("Session Configuration", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; + markInactiveForResume(session1); const session2 = await client.resumeSession(sessionId, { onPermissionRequest: approveAll, @@ -401,6 +405,7 @@ describe("Session Configuration", async () => { it("should apply systemMessage on session resume", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; + markInactiveForResume(session1); const resumeInstruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL."; const session2 = await client.resumeSession(sessionId, { @@ -422,6 +427,7 @@ describe("Session Configuration", async () => { it("should apply availableTools on session resume", async () => { const session1 = await client.createSession({ onPermissionRequest: approveAll }); const sessionId = session1.sessionId; + markInactiveForResume(session1); const session2 = await client.resumeSession(sessionId, { onPermissionRequest: approveAll, diff --git a/nodejs/test/e2e/skills.e2e.test.ts b/nodejs/test/e2e/skills.e2e.test.ts index 973e2f329..f63b5f98c 100644 --- a/nodejs/test/e2e/skills.e2e.test.ts +++ b/nodejs/test/e2e/skills.e2e.test.ts @@ -162,6 +162,7 @@ IMPORTANT: You MUST include the exact text "${SKILL_MARKER}" somewhere in EVERY // First message without skill - marker should not appear const message1 = await session1.sendAndWait({ prompt: "Say hi." }); expect(message1?.data.content).not.toContain(SKILL_MARKER); + await session1.disconnect(); // Resume with skillDirectories - skill should now be active const session2 = await client.resumeSession(sessionId, { diff --git a/python/copilot/client.py b/python/copilot/client.py index 6adb52061..4e974429c 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -1022,6 +1022,18 @@ def rpc(self) -> ServerRpc: raise RuntimeError("Client is not connected. Call start() first.") return self._rpc + def _register_session(self, session: CopilotSession) -> None: + with self._sessions_lock: + existing = self._sessions.get(session.session_id) + if existing is not None and existing is not session: + raise RuntimeError(f"Session {session.session_id} is already active.") + self._sessions[session.session_id] = session + + def _unregister_session(self, session: CopilotSession) -> None: + with self._sessions_lock: + if self._sessions.get(session.session_id) is session: + del self._sessions[session.session_id] + @property def actual_port(self) -> int | None: """The actual TCP port the CLI server is listening on, if using TCP transport. @@ -1229,11 +1241,8 @@ async def stop(self) -> None: """ errors: list[StopError] = [] - # Atomically take ownership of all sessions and clear the dict - # so no other thread can access them with self._sessions_lock: sessions_to_destroy = list(self._sessions.values()) - self._sessions.clear() for session in sessions_to_destroy: try: @@ -1248,6 +1257,12 @@ async def stop(self) -> None: StopError(message=f"Failed to disconnect session {session.session_id}: {e}") ) + with self._sessions_lock: + remaining_sessions = list(self._sessions.values()) + self._sessions.clear() + for session in remaining_sessions: + session._mark_disconnected() + # Close client if self._client: await self._client.stop() @@ -1290,9 +1305,11 @@ async def force_stop(self) -> None: ... except asyncio.TimeoutError: ... await client.force_stop() """ - # Clear sessions immediately without trying to destroy them with self._sessions_lock: + sessions_to_destroy = list(self._sessions.values()) self._sessions.clear() + for session in sessions_to_destroy: + session._mark_disconnected() # Close the transport first to signal the server immediately. # For external servers (TCP), this closes the socket. @@ -1621,24 +1638,33 @@ async def create_session( # Create and register the session before issuing the RPC so that # events emitted by the CLI (e.g. session.start) are not dropped. setup_start = time.perf_counter() - session = CopilotSession(actual_session_id, self._client, workspace_path=None) - if self._session_fs_config: - if create_session_fs_handler is None: - raise ValueError( - "create_session_fs_handler is required in session config when " - "session_fs is enabled in client options." - ) - fs_provider: SessionFsProvider = create_session_fs_handler(session) - caps = self._session_fs_config.get("capabilities") - if caps and caps.get("sqlite"): - from .session_fs_provider import SessionFsSqliteProvider - - if not isinstance(fs_provider, SessionFsSqliteProvider): + session = CopilotSession( + actual_session_id, + self._client, + workspace_path=None, + on_disconnected=self._unregister_session, + ) + try: + if self._session_fs_config: + if create_session_fs_handler is None: raise ValueError( - "SessionFs capabilities declare SQLite support but the provider " - "does not implement SessionFsSqliteProvider" + "create_session_fs_handler is required in session config when " + "session_fs is enabled in client options." ) - session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) + fs_provider: SessionFsProvider = create_session_fs_handler(session) + caps = self._session_fs_config.get("capabilities") + if caps and caps.get("sqlite"): + from .session_fs_provider import SessionFsSqliteProvider + + if not isinstance(fs_provider, SessionFsSqliteProvider): + raise ValueError( + "SessionFs capabilities declare SQLite support but the provider " + "does not implement SessionFsSqliteProvider" + ) + session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) + except BaseException: + session._mark_disconnected() + raise session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -1656,8 +1682,11 @@ async def create_session( session._register_transform_callbacks(transform_callbacks) if on_event: session.on(on_event) - with self._sessions_lock: - self._sessions[actual_session_id] = session + try: + self._register_session(session) + except BaseException: + session._mark_disconnected() + raise log_timing( logger, logging.DEBUG, @@ -1683,8 +1712,8 @@ async def create_session( capabilities = response.get("capabilities") session._set_capabilities(capabilities) except BaseException as exc: - with self._sessions_lock: - self._sessions.pop(actual_session_id, None) + self._unregister_session(session) + session._mark_disconnected() if not isinstance(exc, asyncio.CancelledError): log_timing( logger, @@ -1974,24 +2003,33 @@ async def resume_session( # Create and register the session before issuing the RPC so that # events emitted by the CLI (e.g. session.start) are not dropped. setup_start = time.perf_counter() - session = CopilotSession(session_id, self._client, workspace_path=None) - if self._session_fs_config: - if create_session_fs_handler is None: - raise ValueError( - "create_session_fs_handler is required in session config when " - "session_fs is enabled in client options." - ) - fs_provider: SessionFsProvider = create_session_fs_handler(session) - caps = self._session_fs_config.get("capabilities") - if caps and caps.get("sqlite"): - from .session_fs_provider import SessionFsSqliteProvider - - if not isinstance(fs_provider, SessionFsSqliteProvider): + session = CopilotSession( + session_id, + self._client, + workspace_path=None, + on_disconnected=self._unregister_session, + ) + try: + if self._session_fs_config: + if create_session_fs_handler is None: raise ValueError( - "SessionFs capabilities declare SQLite support but the provider " - "does not implement SessionFsSqliteProvider" + "create_session_fs_handler is required in session config when " + "session_fs is enabled in client options." ) - session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) + fs_provider: SessionFsProvider = create_session_fs_handler(session) + caps = self._session_fs_config.get("capabilities") + if caps and caps.get("sqlite"): + from .session_fs_provider import SessionFsSqliteProvider + + if not isinstance(fs_provider, SessionFsSqliteProvider): + raise ValueError( + "SessionFs capabilities declare SQLite support but the provider " + "does not implement SessionFsSqliteProvider" + ) + session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) + except BaseException: + session._mark_disconnected() + raise session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -2009,8 +2047,11 @@ async def resume_session( session._register_transform_callbacks(transform_callbacks) if on_event: session.on(on_event) - with self._sessions_lock: - self._sessions[session_id] = session + try: + self._register_session(session) + except BaseException: + session._mark_disconnected() + raise log_timing( logger, logging.DEBUG, @@ -2036,8 +2077,8 @@ async def resume_session( capabilities = response.get("capabilities") session._set_capabilities(capabilities) except BaseException as exc: - with self._sessions_lock: - self._sessions.pop(session_id, None) + self._unregister_session(session) + session._mark_disconnected() if not isinstance(exc, asyncio.CancelledError): log_timing( logger, @@ -2281,8 +2322,9 @@ async def delete_session(self, session_id: str) -> None: # Remove from local sessions map if present with self._sessions_lock: - if session_id in self._sessions: - del self._sessions[session_id] + session = self._sessions.get(session_id) + if session is not None: + session._mark_disconnected() async def get_last_session_id(self) -> str | None: """ diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index b72bd97ad..ab1e3996a 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -15546,9 +15546,9 @@ class ServerModelsApi: def __init__(self, client: "JsonRpcClient"): self._client = client - async def list(self, params: ModelsListRequest, *, timeout: float | None = None) -> ModelList: + async def list(self, params: ModelsListRequest | None = None, *, timeout: float | None = None) -> ModelList: "Lists Copilot models available to the authenticated user.\n\nArgs:\n params: Optional GitHub token used to list models for a specific user instead of the global auth context.\n\nReturns:\n List of Copilot models available to the resolved user, including capabilities and billing metadata." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return ModelList.from_dict(_patch_model_capabilities(await self._client.request("models.list", params_dict, **_timeout_kwargs(timeout)))) @@ -15556,9 +15556,9 @@ class ServerToolsApi: def __init__(self, client: "JsonRpcClient"): self._client = client - async def list(self, params: ToolsListRequest, *, timeout: float | None = None) -> ToolList: + async def list(self, params: ToolsListRequest | None = None, *, timeout: float | None = None) -> ToolList: "Lists built-in tools available for a model.\n\nArgs:\n params: Optional model identifier whose tool overrides should be applied to the listing.\n\nReturns:\n Built-in tools available for the requested model, with their parameters and instructions." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return ToolList.from_dict(await self._client.request("tools.list", params_dict, **_timeout_kwargs(timeout))) @@ -15566,9 +15566,9 @@ class ServerAccountApi: def __init__(self, client: "JsonRpcClient"): self._client = client - async def get_quota(self, params: AccountGetQuotaRequest, *, timeout: float | None = None) -> AccountGetQuotaResult: + async def get_quota(self, params: AccountGetQuotaRequest | None = None, *, timeout: float | None = None) -> AccountGetQuotaResult: "Gets Copilot quota usage for the authenticated user or supplied GitHub token.\n\nArgs:\n params: Optional GitHub token used to look up quota for a specific user instead of the global auth context.\n\nReturns:\n Quota usage snapshots for the resolved user, keyed by quota type." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", params_dict, **_timeout_kwargs(timeout))) @@ -15582,26 +15582,41 @@ async def list(self, *, timeout: float | None = None) -> MCPConfigList: async def add(self, params: MCPConfigAddRequest, *, timeout: float | None = None) -> None: "Adds an MCP server to user configuration.\n\nArgs:\n params: MCP server name and configuration to add to user configuration." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("mcp.config.add", params_dict, **_timeout_kwargs(timeout)) async def update(self, params: MCPConfigUpdateRequest, *, timeout: float | None = None) -> None: "Updates an MCP server in user configuration.\n\nArgs:\n params: MCP server name and replacement configuration to write to user configuration." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("mcp.config.update", params_dict, **_timeout_kwargs(timeout)) async def remove(self, params: MCPConfigRemoveRequest, *, timeout: float | None = None) -> None: "Removes an MCP server from user configuration.\n\nArgs:\n params: MCP server name to remove from user configuration." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("mcp.config.remove", params_dict, **_timeout_kwargs(timeout)) async def enable(self, params: MCPConfigEnableRequest, *, timeout: float | None = None) -> None: "Enables MCP servers in user configuration for new sessions.\n\nArgs:\n params: MCP server names to enable for new sessions." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("mcp.config.enable", params_dict, **_timeout_kwargs(timeout)) async def disable(self, params: MCPConfigDisableRequest, *, timeout: float | None = None) -> None: "Disables MCP servers in user configuration for new sessions.\n\nArgs:\n params: MCP server names to disable for new sessions." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("mcp.config.disable", params_dict, **_timeout_kwargs(timeout)) @@ -15611,9 +15626,9 @@ def __init__(self, client: "JsonRpcClient"): self._client = client self.config = ServerMcpConfigApi(client) - async def discover(self, params: MCPDiscoverRequest, *, timeout: float | None = None) -> MCPDiscoverResult: + async def discover(self, params: MCPDiscoverRequest | None = None, *, timeout: float | None = None) -> MCPDiscoverResult: "Discovers MCP servers from user, workspace, plugin, and builtin sources.\n\nArgs:\n params: Optional working directory used as context for MCP server discovery.\n\nReturns:\n MCP servers discovered from user, workspace, plugin, and built-in sources." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return MCPDiscoverResult.from_dict(await self._client.request("mcp.discover", params_dict, **_timeout_kwargs(timeout))) @@ -15623,6 +15638,9 @@ def __init__(self, client: "JsonRpcClient"): async def set_disabled_skills(self, params: SkillsConfigSetDisabledSkillsRequest, *, timeout: float | None = None) -> None: "Replaces the global list of disabled skills.\n\nArgs:\n params: Skill names to mark as disabled in global configuration, replacing any previous list." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} await self._client.request("skills.config.setDisabledSkills", params_dict, **_timeout_kwargs(timeout)) @@ -15632,9 +15650,9 @@ def __init__(self, client: "JsonRpcClient"): self._client = client self.config = ServerSkillsConfigApi(client) - async def discover(self, params: SkillsDiscoverRequest, *, timeout: float | None = None) -> ServerSkillList: + async def discover(self, params: SkillsDiscoverRequest | None = None, *, timeout: float | None = None) -> ServerSkillList: "Discovers skills across global and project sources.\n\nArgs:\n params: Optional project paths and additional skill directories to include in discovery.\n\nReturns:\n Skills discovered across global and project sources." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return ServerSkillList.from_dict(await self._client.request("skills.discover", params_dict, **_timeout_kwargs(timeout))) @@ -15644,6 +15662,9 @@ def __init__(self, client: "JsonRpcClient"): async def set_provider(self, params: SessionFSSetProviderRequest, *, timeout: float | None = None) -> SessionFSSetProviderResult: "Registers an SDK client as the session filesystem provider.\n\nArgs:\n params: Initial working directory, session-state path layout, and path conventions used to register the calling SDK client as the session filesystem provider.\n\nReturns:\n Indicates whether the calling client was registered as the session filesystem provider." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionFSSetProviderResult.from_dict(await self._client.request("sessionFs.setProvider", params_dict, **_timeout_kwargs(timeout))) @@ -15653,39 +15674,45 @@ class ServerSessionsApi: def __init__(self, client: "JsonRpcClient"): self._client = client - async def fork(self, params: SessionsForkRequest, *, timeout: float | None = None) -> SessionsForkResult: + async def fork(self, params: SessionsForkRequest | None = None, *, timeout: float | None = None) -> SessionsForkResult: "Creates a new session by forking persisted history from an existing session.\n\nArgs:\n params: Source session identifier to fork from, optional event-ID boundary, and optional friendly name for the new session.\n\nReturns:\n Identifier and optional friendly name assigned to the newly forked session." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsForkResult.from_dict(await self._client.request("sessions.fork", params_dict, **_timeout_kwargs(timeout))) - async def connect(self, params: ConnectRemoteSessionParams, *, timeout: float | None = None) -> RemoteSessionConnectionResult: + async def connect(self, params: ConnectRemoteSessionParams | None = None, *, timeout: float | None = None) -> RemoteSessionConnectionResult: "Connects to an existing remote session and exposes it as an SDK session.\n\nArgs:\n params: Remote session connection parameters.\n\nReturns:\n Remote session connection result." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return RemoteSessionConnectionResult.from_dict(await self._client.request("sessions.connect", params_dict, **_timeout_kwargs(timeout))) - async def list(self, params: SessionsListRequest, *, timeout: float | None = None) -> SessionList: + async def list(self, params: SessionsListRequest | None = None, *, timeout: float | None = None) -> SessionList: "Lists persisted sessions, optionally filtered by working-directory context.\n\nArgs:\n params: Optional metadata-load limit and context filter applied to the returned sessions.\n\nReturns:\n Persisted sessions matching the filter, ordered most-recently-modified first." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionList.from_dict(await self._client.request("sessions.list", params_dict, **_timeout_kwargs(timeout))) async def find_by_task_id(self, params: SessionsFindByTaskIDRequest, *, timeout: float | None = None) -> SessionsFindByTaskIDResult: "Finds the local session bound to a GitHub task ID, if any.\n\nArgs:\n params: GitHub task ID to look up.\n\nReturns:\n ID of the local session bound to the given GitHub task, or omitted when none." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionsFindByTaskIDResult.from_dict(await self._client.request("sessions.findByTaskId", params_dict, **_timeout_kwargs(timeout))) async def find_by_prefix(self, params: SessionsFindByPrefixRequest, *, timeout: float | None = None) -> SessionsFindByPrefixResult: "Resolves a UUID prefix to a unique session ID, if exactly one session matches.\n\nArgs:\n params: UUID prefix to resolve to a unique session ID.\n\nReturns:\n Session ID matching the prefix, omitted when no unique match exists." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionsFindByPrefixResult.from_dict(await self._client.request("sessions.findByPrefix", params_dict, **_timeout_kwargs(timeout))) - async def get_last_for_context(self, params: SessionsGetLastForContextRequest, *, timeout: float | None = None) -> SessionsGetLastForContextResult: + async def get_last_for_context(self, params: SessionsGetLastForContextRequest | None = None, *, timeout: float | None = None) -> SessionsGetLastForContextResult: "Returns the most-relevant prior session for a given working-directory context.\n\nArgs:\n params: Optional working-directory context used to score session relevance.\n\nReturns:\n Most-relevant session ID for the supplied context, or omitted when no sessions exist." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsGetLastForContextResult.from_dict(await self._client.request("sessions.getLastForContext", params_dict, **_timeout_kwargs(timeout))) - async def get_event_file_path(self, params: SessionsGetEventFilePathRequest, *, timeout: float | None = None) -> SessionsGetEventFilePathResult: + async def get_event_file_path(self, params: SessionsGetEventFilePathRequest | None = None, *, timeout: float | None = None) -> SessionsGetEventFilePathResult: "Computes the absolute path to a session's persisted events.jsonl file.\n\nArgs:\n params: Session ID whose event-log file path to compute.\n\nReturns:\n Absolute path to the session's events.jsonl file on disk." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsGetEventFilePathResult.from_dict(await self._client.request("sessions.getEventFilePath", params_dict, **_timeout_kwargs(timeout))) async def get_sizes(self, *, timeout: float | None = None) -> SessionSizes: @@ -15694,56 +15721,71 @@ async def get_sizes(self, *, timeout: float | None = None) -> SessionSizes: async def check_in_use(self, params: SessionsCheckInUseRequest, *, timeout: float | None = None) -> SessionsCheckInUseResult: "Returns the subset of the supplied session IDs that are currently held by another running process.\n\nArgs:\n params: Session IDs to test for live in-use locks.\n\nReturns:\n Session IDs from the input set that are currently in use by another process." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionsCheckInUseResult.from_dict(await self._client.request("sessions.checkInUse", params_dict, **_timeout_kwargs(timeout))) - async def get_persisted_remote_steerable(self, params: SessionsGetPersistedRemoteSteerableRequest, *, timeout: float | None = None) -> SessionsGetPersistedRemoteSteerableResult: + async def get_persisted_remote_steerable(self, params: SessionsGetPersistedRemoteSteerableRequest | None = None, *, timeout: float | None = None) -> SessionsGetPersistedRemoteSteerableResult: "Returns a session's persisted remote-steerable flag, if any has been recorded.\n\nArgs:\n params: Session ID to look up the persisted remote-steerable flag for.\n\nReturns:\n The session's persisted remote-steerable flag, or omitted when no value has been persisted." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsGetPersistedRemoteSteerableResult.from_dict(await self._client.request("sessions.getPersistedRemoteSteerable", params_dict, **_timeout_kwargs(timeout))) - async def close(self, params: SessionsCloseRequest, *, timeout: float | None = None) -> SessionsCloseResult: + async def close(self, params: SessionsCloseRequest | None = None, *, timeout: float | None = None) -> SessionsCloseResult: "Closes a session: emits shutdown, flushes pending events, releases the in-use lock, and disposes the active session.\n\nArgs:\n params: Session ID to close.\n\nReturns:\n Closes a session: emits shutdown, flushes pending events to disk, releases the in-use lock, disposes the active session. Idempotent: succeeds even if the session is not currently active." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsCloseResult.from_dict(await self._client.request("sessions.close", params_dict, **_timeout_kwargs(timeout))) async def bulk_delete(self, params: SessionsBulkDeleteRequest, *, timeout: float | None = None) -> SessionBulkDeleteResult: "Closes, deactivates, and deletes a set of sessions, returning the bytes freed per session.\n\nArgs:\n params: Session IDs to close, deactivate, and delete from disk.\n\nReturns:\n Map of sessionId -> bytes freed by removing the session's workspace directory." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionBulkDeleteResult.from_dict(await self._client.request("sessions.bulkDelete", params_dict, **_timeout_kwargs(timeout))) async def prune_old(self, params: SessionsPruneOldRequest, *, timeout: float | None = None) -> SessionPruneResult: "Deletes sessions older than the given threshold, with optional dry-run and exclusion list.\n\nArgs:\n params: Age threshold and optional flags controlling which old sessions are pruned (or simulated when dryRun is true).\n\nReturns:\n Outcome of the prune operation: deleted IDs, dry-run candidates, skipped IDs, total bytes freed, and the dry-run flag." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionPruneResult.from_dict(await self._client.request("sessions.pruneOld", params_dict, **_timeout_kwargs(timeout))) - async def save(self, params: SessionsSaveRequest, *, timeout: float | None = None) -> SessionsSaveResult: + async def save(self, params: SessionsSaveRequest | None = None, *, timeout: float | None = None) -> SessionsSaveResult: "Flushes a session's pending events to disk.\n\nArgs:\n params: Session ID whose pending events should be flushed to disk.\n\nReturns:\n Flush a session's pending events to disk. No-op when no writer exists for the session (e.g., already closed)." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsSaveResult.from_dict(await self._client.request("sessions.save", params_dict, **_timeout_kwargs(timeout))) - async def release_lock(self, params: SessionsReleaseLockRequest, *, timeout: float | None = None) -> SessionsReleaseLockResult: + async def release_lock(self, params: SessionsReleaseLockRequest | None = None, *, timeout: float | None = None) -> SessionsReleaseLockResult: "Releases the in-use lock held by this process for a session.\n\nArgs:\n params: Session ID whose in-use lock should be released.\n\nReturns:\n Release the in-use lock held by this process for the given session. No-op when this process does not currently hold a lock for the session." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsReleaseLockResult.from_dict(await self._client.request("sessions.releaseLock", params_dict, **_timeout_kwargs(timeout))) async def enrich_metadata(self, params: SessionsEnrichMetadataRequest, *, timeout: float | None = None) -> SessionEnrichMetadataResult: "Backfills missing summary and context fields on the supplied session metadata records.\n\nArgs:\n params: Session metadata records to enrich with summary and context information.\n\nReturns:\n The same metadata records, with summary and context fields backfilled where available." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionEnrichMetadataResult.from_dict(await self._client.request("sessions.enrichMetadata", params_dict, **_timeout_kwargs(timeout))) - async def reload_plugin_hooks(self, params: SessionsReloadPluginHooksRequest, *, timeout: float | None = None) -> SessionsReloadPluginHooksResult: + async def reload_plugin_hooks(self, params: SessionsReloadPluginHooksRequest | None = None, *, timeout: float | None = None) -> SessionsReloadPluginHooksResult: "Reloads user, plugin, and (optionally) repo hooks on the active session.\n\nArgs:\n params: Active session ID and an optional flag for deferring repo-level hooks until folder trust.\n\nReturns:\n Reload all hooks (user, plugin, optionally repo) and apply them to the active session. Call after installing or removing plugins so their hooks take effect immediately. No-op when no active session matches the given sessionId." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionsReloadPluginHooksResult.from_dict(await self._client.request("sessions.reloadPluginHooks", params_dict, **_timeout_kwargs(timeout))) - async def load_deferred_repo_hooks(self, params: SessionsLoadDeferredRepoHooksRequest, *, timeout: float | None = None) -> SessionLoadDeferredRepoHooksResult: + async def load_deferred_repo_hooks(self, params: SessionsLoadDeferredRepoHooksRequest | None = None, *, timeout: float | None = None) -> SessionLoadDeferredRepoHooksResult: "Loads previously-deferred repo-level hooks on the active session, returning queued startup prompts.\n\nArgs:\n params: Active session ID whose deferred repo-level hooks should be loaded.\n\nReturns:\n Queued repo-level startup prompts and the total hook command count after loading." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return SessionLoadDeferredRepoHooksResult.from_dict(await self._client.request("sessions.loadDeferredRepoHooks", params_dict, **_timeout_kwargs(timeout))) async def set_additional_plugins(self, params: SessionsSetAdditionalPluginsRequest, *, timeout: float | None = None) -> SessionsSetAdditionalPluginsResult: "Replaces the manager-wide additional plugins registered with the session manager.\n\nArgs:\n params: Manager-wide additional plugins to register; replaces any previously-configured set.\n\nReturns:\n Replace the manager-wide additional plugins. New session creations and subsequent hook reloads see the new set; already-running sessions keep their existing hook installation until the next reload." + if params is None: + raise TypeError("params is required") + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} return SessionsSetAdditionalPluginsResult.from_dict(await self._client.request("sessions.setAdditionalPlugins", params_dict, **_timeout_kwargs(timeout))) @@ -15760,9 +15802,9 @@ def __init__(self, client: "JsonRpcClient"): self.session_fs = ServerSessionFsApi(client) self.sessions = ServerSessionsApi(client) - async def ping(self, params: PingRequest, *, timeout: float | None = None) -> PingResult: + async def ping(self, params: PingRequest | None = None, *, timeout: float | None = None) -> PingResult: "Checks server responsiveness and returns protocol information.\n\nArgs:\n params: Optional message to echo back to the caller.\n\nReturns:\n Server liveness response, including the echoed message, current server timestamp, and protocol version." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout))) @@ -15771,260 +15813,418 @@ class _InternalServerRpc: def __init__(self, client: "JsonRpcClient"): self._client = client - async def connect(self, params: ConnectRequest, *, timeout: float | None = None) -> ConnectResult: + async def connect(self, params: ConnectRequest | None = None, *, timeout: float | None = None) -> ConnectResult: "Performs the SDK server connection handshake and validates the optional connection token.\n\nArgs:\n params: Optional connection token presented by the SDK client during the handshake.\n\nReturns:\n Handshake result reporting the server's protocol version and package version on success.\n\n:meta private:\n\nInternal SDK API; not part of the public surface." - params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} return ConnectResult.from_dict(await self._client.request("connect", params_dict, **_timeout_kwargs(timeout))) class AuthApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get_status(self, *, timeout: float | None = None) -> SessionAuthStatus: "Gets authentication status and account metadata for the session.\n\nReturns:\n Authentication status and account metadata for the session." + if self._assert_active is not None: + self._assert_active() + return SessionAuthStatus.from_dict(await self._client.request("session.auth.getStatus", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) - async def set_credentials(self, params: SessionSetCredentialsParams, *, timeout: float | None = None) -> SessionSetCredentialsResult: + async def set_credentials(self, params: SessionSetCredentialsParams | None = None, *, timeout: float | None = None) -> SessionSetCredentialsResult: "Updates the session's auth credentials used for outbound model and API requests.\n\nArgs:\n params: New auth credentials to install on the session. Omit to leave credentials unchanged.\n\nReturns:\n Indicates whether the credential update succeeded." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return SessionSetCredentialsResult.from_dict(await self._client.request("session.auth.setCredentials", params_dict, **_timeout_kwargs(timeout))) class ModelApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get_current(self, *, timeout: float | None = None) -> CurrentModel: "Gets the currently selected model for the session.\n\nReturns:\n The currently selected model and reasoning effort for the session." + if self._assert_active is not None: + self._assert_active() + return CurrentModel.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def switch_to(self, params: ModelSwitchToRequest, *, timeout: float | None = None) -> ModelSwitchToResult: "Switches the session to a model and optional reasoning configuration.\n\nArgs:\n params: Target model identifier and optional reasoning effort, summary, and capability overrides.\n\nReturns:\n The model identifier active on the session after the switch." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict, **_timeout_kwargs(timeout))) async def set_reasoning_effort(self, params: ModelSetReasoningEffortRequest, *, timeout: float | None = None) -> ModelSetReasoningEffortResult: "Updates the session's reasoning effort without changing the selected model.\n\nArgs:\n params: Reasoning effort level to apply to the currently selected model.\n\nReturns:\n Update the session's reasoning effort without changing the selected model. Use `switchTo` instead when you also need to change the model. The runtime stores the effort on the session and applies it to subsequent turns." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ModelSetReasoningEffortResult.from_dict(await self._client.request("session.model.setReasoningEffort", params_dict, **_timeout_kwargs(timeout))) class ModeApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get(self, *, timeout: float | None = None) -> SessionMode: "Gets the current agent interaction mode.\n\nReturns:\n The session mode the agent is operating in" + if self._assert_active is not None: + self._assert_active() + return SessionMode(await self._client.request("session.mode.get", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def set(self, params: ModeSetRequest, *, timeout: float | None = None) -> None: "Sets the current agent interaction mode.\n\nArgs:\n params: Agent interaction mode to apply to the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.mode.set", params_dict, **_timeout_kwargs(timeout)) class NameApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get(self, *, timeout: float | None = None) -> NameGetResult: "Gets the session's friendly name.\n\nReturns:\n The session's friendly name, or null when not yet set." + if self._assert_active is not None: + self._assert_active() + return NameGetResult.from_dict(await self._client.request("session.name.get", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def set(self, params: NameSetRequest, *, timeout: float | None = None) -> None: "Sets the session's friendly name.\n\nArgs:\n params: New friendly name to apply to the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.name.set", params_dict, **_timeout_kwargs(timeout)) async def set_auto(self, params: NameSetAutoRequest, *, timeout: float | None = None) -> NameSetAutoResult: "Persists an auto-generated session summary as the session's name when no user-set name exists.\n\nArgs:\n params: Auto-generated session summary to apply as the session's name when no user-set name exists.\n\nReturns:\n Indicates whether the auto-generated summary was applied as the session's name." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return NameSetAutoResult.from_dict(await self._client.request("session.name.setAuto", params_dict, **_timeout_kwargs(timeout))) class PlanApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def read(self, *, timeout: float | None = None) -> PlanReadResult: "Reads the session plan file from the workspace.\n\nReturns:\n Existence, contents, and resolved path of the session plan file." + if self._assert_active is not None: + self._assert_active() + return PlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def update(self, params: PlanUpdateRequest, *, timeout: float | None = None) -> None: "Writes new content to the session plan file.\n\nArgs:\n params: Replacement contents to write to the session plan file." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.plan.update", params_dict, **_timeout_kwargs(timeout)) async def delete(self, *, timeout: float | None = None) -> None: "Deletes the session plan file from the workspace." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.plan.delete", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) class WorkspacesApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get_workspace(self, *, timeout: float | None = None) -> WorkspacesGetWorkspaceResult: "Gets current workspace metadata for the session.\n\nReturns:\n Current workspace metadata for the session, including its absolute filesystem path when available." + if self._assert_active is not None: + self._assert_active() + return WorkspacesGetWorkspaceResult.from_dict(await self._client.request("session.workspaces.getWorkspace", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def list_files(self, *, timeout: float | None = None) -> WorkspacesListFilesResult: "Lists files stored in the session workspace files directory.\n\nReturns:\n Relative paths of files stored in the session workspace files directory." + if self._assert_active is not None: + self._assert_active() + return WorkspacesListFilesResult.from_dict(await self._client.request("session.workspaces.listFiles", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def read_file(self, params: WorkspacesReadFileRequest, *, timeout: float | None = None) -> WorkspacesReadFileResult: "Reads a file from the session workspace files directory.\n\nArgs:\n params: Relative path of the workspace file to read.\n\nReturns:\n Contents of the requested workspace file as a UTF-8 string." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return WorkspacesReadFileResult.from_dict(await self._client.request("session.workspaces.readFile", params_dict, **_timeout_kwargs(timeout))) async def create_file(self, params: WorkspacesCreateFileRequest, *, timeout: float | None = None) -> None: "Creates or overwrites a file in the session workspace files directory.\n\nArgs:\n params: Relative path and UTF-8 content for the workspace file to create or overwrite." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.workspaces.createFile", params_dict, **_timeout_kwargs(timeout)) async def list_checkpoints(self, *, timeout: float | None = None) -> WorkspacesListCheckpointsResult: "Lists workspace checkpoints in chronological order.\n\nReturns:\n Workspace checkpoints in chronological order; empty when the workspace is not enabled." + if self._assert_active is not None: + self._assert_active() + return WorkspacesListCheckpointsResult.from_dict(await self._client.request("session.workspaces.listCheckpoints", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def read_checkpoint(self, params: WorkspacesReadCheckpointRequest, *, timeout: float | None = None) -> WorkspacesReadCheckpointResult: "Reads the content of a workspace checkpoint by number.\n\nArgs:\n params: Checkpoint number to read.\n\nReturns:\n Checkpoint content as a UTF-8 string, or null when the checkpoint or workspace is missing." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return WorkspacesReadCheckpointResult.from_dict(await self._client.request("session.workspaces.readCheckpoint", params_dict, **_timeout_kwargs(timeout))) async def save_large_paste(self, params: WorkspacesSaveLargePasteRequest, *, timeout: float | None = None) -> WorkspacesSaveLargePasteResult: "Saves pasted content as a UTF-8 file in the session workspace.\n\nArgs:\n params: Pasted content to save as a UTF-8 file in the session workspace.\n\nReturns:\n Descriptor for the saved paste file, or null when the workspace is unavailable." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return WorkspacesSaveLargePasteResult.from_dict(await self._client.request("session.workspaces.saveLargePaste", params_dict, **_timeout_kwargs(timeout))) class InstructionsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get_sources(self, *, timeout: float | None = None) -> InstructionsGetSourcesResult: "Gets instruction sources loaded for the session.\n\nReturns:\n Instruction sources loaded for the session, in merge order." + if self._assert_active is not None: + self._assert_active() + return InstructionsGetSourcesResult.from_dict(await self._client.request("session.instructions.getSources", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class FleetApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active - async def start(self, params: FleetStartRequest, *, timeout: float | None = None) -> FleetStartResult: + async def start(self, params: FleetStartRequest | None = None, *, timeout: float | None = None) -> FleetStartResult: "Starts fleet mode by submitting the fleet orchestration prompt to the session.\n\nArgs:\n params: Optional user prompt to combine with the fleet orchestration instructions.\n\nReturns:\n Indicates whether fleet mode was successfully activated." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return FleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class AgentApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> AgentList: "Lists custom agents available to the session.\n\nReturns:\n Custom agents available to the session." + if self._assert_active is not None: + self._assert_active() + return AgentList.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def get_current(self, *, timeout: float | None = None) -> AgentGetCurrentResult: "Gets the currently selected custom agent for the session.\n\nReturns:\n The currently selected custom agent, or null when using the default agent." + if self._assert_active is not None: + self._assert_active() + return AgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def select(self, params: AgentSelectRequest, *, timeout: float | None = None) -> AgentSelectResult: "Selects a custom agent for subsequent turns in the session.\n\nArgs:\n params: Name of the custom agent to select for subsequent turns.\n\nReturns:\n The newly selected custom agent." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return AgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict, **_timeout_kwargs(timeout))) async def deselect(self, *, timeout: float | None = None) -> None: "Clears the selected custom agent and returns the session to the default agent." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.agent.deselect", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) async def reload(self, *, timeout: float | None = None) -> AgentReloadResult: "Reloads custom agent definitions and returns the refreshed list.\n\nReturns:\n Custom agents available to the session after reloading definitions from disk." + if self._assert_active is not None: + self._assert_active() + return AgentReloadResult.from_dict(await self._client.request("session.agent.reload", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class TasksApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def start_agent(self, params: TasksStartAgentRequest, *, timeout: float | None = None) -> TasksStartAgentResult: "Starts a background agent task in the session.\n\nArgs:\n params: Agent type, prompt, name, and optional description and model override for the new task.\n\nReturns:\n Identifier assigned to the newly started background agent task." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksStartAgentResult.from_dict(await self._client.request("session.tasks.startAgent", params_dict, **_timeout_kwargs(timeout))) async def list(self, *, timeout: float | None = None) -> TaskList: "Lists background tasks tracked by the session.\n\nReturns:\n Background tasks currently tracked by the session." + if self._assert_active is not None: + self._assert_active() + return TaskList.from_dict(await self._client.request("session.tasks.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def refresh(self, *, timeout: float | None = None) -> TasksRefreshResult: "Refreshes metadata for any detached background shells the runtime knows about.\n\nReturns:\n Refresh metadata for any detached background shells the runtime knows about. Use after a long pause to pick up exit/output state for shells running outside the agent loop." + if self._assert_active is not None: + self._assert_active() + return TasksRefreshResult.from_dict(await self._client.request("session.tasks.refresh", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def wait_for_pending(self, *, timeout: float | None = None) -> TasksWaitForPendingResult: "Waits for all in-flight background tasks and any follow-up turns to settle.\n\nReturns:\n Wait until all in-flight background tasks (agents + shells) and any follow-up turns scheduled by their completions have settled. Returns when the runtime is fully drained or after an internal timeout (default 10 minutes; configurable via COPILOT_TASK_WAIT_TIMEOUT_SECONDS)." + if self._assert_active is not None: + self._assert_active() + return TasksWaitForPendingResult.from_dict(await self._client.request("session.tasks.waitForPending", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def get_progress(self, params: TasksGetProgressRequest, *, timeout: float | None = None) -> TasksGetProgressResult: "Returns progress information for a background task by ID.\n\nArgs:\n params: Identifier of the background task to fetch progress for.\n\nReturns:\n Progress information for the task, or null when no task with that ID is tracked." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksGetProgressResult.from_dict(await self._client.request("session.tasks.getProgress", params_dict, **_timeout_kwargs(timeout))) async def get_current_promotable(self, *, timeout: float | None = None) -> TasksGetCurrentPromotableResult: "Returns the first sync-waiting task that can currently be promoted to background mode.\n\nReturns:\n The first sync-waiting task that can currently be promoted to background mode." + if self._assert_active is not None: + self._assert_active() + return TasksGetCurrentPromotableResult.from_dict(await self._client.request("session.tasks.getCurrentPromotable", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def promote_to_background(self, params: TasksPromoteToBackgroundRequest, *, timeout: float | None = None) -> TasksPromoteToBackgroundResult: "Promotes an eligible synchronously-waited task so it continues running in the background.\n\nArgs:\n params: Identifier of the task to promote to background mode.\n\nReturns:\n Indicates whether the task was successfully promoted to background mode." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksPromoteToBackgroundResult.from_dict(await self._client.request("session.tasks.promoteToBackground", params_dict, **_timeout_kwargs(timeout))) async def promote_current_to_background(self, *, timeout: float | None = None) -> TasksPromoteCurrentToBackgroundResult: "Atomically promotes the first promotable sync-waiting task to background mode and returns it.\n\nReturns:\n The promoted task as it now exists in background mode, omitted if no promotable task was waiting." + if self._assert_active is not None: + self._assert_active() + return TasksPromoteCurrentToBackgroundResult.from_dict(await self._client.request("session.tasks.promoteCurrentToBackground", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def cancel(self, params: TasksCancelRequest, *, timeout: float | None = None) -> TasksCancelResult: "Cancels a background task.\n\nArgs:\n params: Identifier of the background task to cancel.\n\nReturns:\n Indicates whether the background task was successfully cancelled." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksCancelResult.from_dict(await self._client.request("session.tasks.cancel", params_dict, **_timeout_kwargs(timeout))) async def remove(self, params: TasksRemoveRequest, *, timeout: float | None = None) -> TasksRemoveResult: "Removes a completed or cancelled background task from tracking.\n\nArgs:\n params: Identifier of the completed or cancelled task to remove from tracking.\n\nReturns:\n Indicates whether the task was removed. False when the task does not exist or is still running/idle." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksRemoveResult.from_dict(await self._client.request("session.tasks.remove", params_dict, **_timeout_kwargs(timeout))) async def send_message(self, params: TasksSendMessageRequest, *, timeout: float | None = None) -> TasksSendMessageResult: "Sends a message to a background agent task.\n\nArgs:\n params: Identifier of the target agent task, message content, and optional sender agent ID.\n\nReturns:\n Indicates whether the message was delivered, with an error message when delivery failed." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return TasksSendMessageResult.from_dict(await self._client.request("session.tasks.sendMessage", params_dict, **_timeout_kwargs(timeout))) @@ -16032,47 +16232,76 @@ async def send_message(self, params: TasksSendMessageRequest, *, timeout: float # Experimental: this API group is experimental and may change or be removed. class SkillsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> SkillList: "Lists skills available to the session.\n\nReturns:\n Skills available to the session, with their enabled state." + if self._assert_active is not None: + self._assert_active() + return SkillList.from_dict(await self._client.request("session.skills.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def get_invoked(self, *, timeout: float | None = None) -> SkillsGetInvokedResult: "Returns the skills that have been invoked during this session.\n\nReturns:\n Skills invoked during this session, ordered by invocation time (most recent last)." + if self._assert_active is not None: + self._assert_active() + return SkillsGetInvokedResult.from_dict(await self._client.request("session.skills.getInvoked", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def enable(self, params: SkillsEnableRequest, *, timeout: float | None = None) -> None: "Enables a skill for the session.\n\nArgs:\n params: Name of the skill to enable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.skills.enable", params_dict, **_timeout_kwargs(timeout)) async def disable(self, params: SkillsDisableRequest, *, timeout: float | None = None) -> None: "Disables a skill for the session.\n\nArgs:\n params: Name of the skill to disable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.skills.disable", params_dict, **_timeout_kwargs(timeout)) async def reload(self, *, timeout: float | None = None) -> SkillsLoadDiagnostics: "Reloads skill definitions for the session.\n\nReturns:\n Diagnostics from reloading skill definitions, with warnings and errors as separate lists." + if self._assert_active is not None: + self._assert_active() + return SkillsLoadDiagnostics.from_dict(await self._client.request("session.skills.reload", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def ensure_loaded(self, *, timeout: float | None = None) -> None: "Ensures the session's skill definitions have been loaded from disk." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.skills.ensureLoaded", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) # Experimental: this API group is experimental and may change or be removed. class McpOauthApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def login(self, params: MCPOauthLoginRequest, *, timeout: float | None = None) -> MCPOauthLoginResult: "Starts OAuth authentication for a remote MCP server.\n\nArgs:\n params: Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy.\n\nReturns:\n OAuth authorization URL the caller should open, or empty when cached tokens already authenticated the server." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MCPOauthLoginResult.from_dict(await self._client.request("session.mcp.oauth.login", params_dict, **_timeout_kwargs(timeout))) @@ -16080,171 +16309,273 @@ async def login(self, params: MCPOauthLoginRequest, *, timeout: float | None = N # Experimental: this API group is experimental and may change or be removed. class McpApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id - self.oauth = McpOauthApi(client, session_id) + self._assert_active = assert_active + self.oauth = McpOauthApi(client, session_id, assert_active) async def list(self, *, timeout: float | None = None) -> MCPServerList: "Lists MCP servers configured for the session and their connection status.\n\nReturns:\n MCP servers configured for the session, with their connection status." + if self._assert_active is not None: + self._assert_active() + return MCPServerList.from_dict(await self._client.request("session.mcp.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def enable(self, params: MCPEnableRequest, *, timeout: float | None = None) -> None: "Enables an MCP server for the session.\n\nArgs:\n params: Name of the MCP server to enable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.mcp.enable", params_dict, **_timeout_kwargs(timeout)) async def disable(self, params: MCPDisableRequest, *, timeout: float | None = None) -> None: "Disables an MCP server for the session.\n\nArgs:\n params: Name of the MCP server to disable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.mcp.disable", params_dict, **_timeout_kwargs(timeout)) async def reload(self, *, timeout: float | None = None) -> None: "Reloads MCP server connections for the session." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.mcp.reload", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) async def execute_sampling(self, params: MCPExecuteSamplingParams, *, timeout: float | None = None) -> MCPSamplingExecutionResult: "Runs an MCP sampling inference on behalf of an MCP server.\n\nArgs:\n params: Identifiers and raw MCP CreateMessageRequest params used to run a sampling inference.\n\nReturns:\n Outcome of an MCP sampling execution: success result, failure error, or cancellation." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MCPSamplingExecutionResult.from_dict(await self._client.request("session.mcp.executeSampling", params_dict, **_timeout_kwargs(timeout))) async def cancel_sampling_execution(self, params: MCPCancelSamplingExecutionParams, *, timeout: float | None = None) -> MCPCancelSamplingExecutionResult: "Cancels an in-flight MCP sampling execution by request ID.\n\nArgs:\n params: The requestId previously passed to executeSampling that should be cancelled.\n\nReturns:\n Indicates whether an in-flight sampling execution with the given requestId was found and cancelled." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MCPCancelSamplingExecutionResult.from_dict(await self._client.request("session.mcp.cancelSamplingExecution", params_dict, **_timeout_kwargs(timeout))) async def set_env_value_mode(self, params: MCPSetEnvValueModeParams, *, timeout: float | None = None) -> MCPSetEnvValueModeResult: "Sets how environment-variable values supplied to MCP servers are resolved (direct or indirect).\n\nArgs:\n params: Mode controlling how MCP server env values are resolved (`direct` or `indirect`).\n\nReturns:\n Env-value mode recorded on the session after the update." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MCPSetEnvValueModeResult.from_dict(await self._client.request("session.mcp.setEnvValueMode", params_dict, **_timeout_kwargs(timeout))) async def remove_git_hub(self, *, timeout: float | None = None) -> MCPRemoveGitHubResult: "Removes the auto-managed `github` MCP server when present.\n\nReturns:\n Indicates whether the auto-managed `github` MCP server was removed (false when nothing to remove)." + if self._assert_active is not None: + self._assert_active() + return MCPRemoveGitHubResult.from_dict(await self._client.request("session.mcp.removeGitHub", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class PluginsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> PluginList: "Lists plugins installed for the session.\n\nReturns:\n Plugins installed for the session, with their enabled state and version metadata." + if self._assert_active is not None: + self._assert_active() + return PluginList.from_dict(await self._client.request("session.plugins.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class OptionsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active - async def update(self, params: SessionUpdateOptionsParams, *, timeout: float | None = None) -> SessionUpdateOptionsResult: + async def update(self, params: SessionUpdateOptionsParams | None = None, *, timeout: float | None = None) -> SessionUpdateOptionsResult: "Patches the genuinely-mutable subset of session options.\n\nArgs:\n params: Patch of mutable session options to apply to the running session.\n\nReturns:\n Indicates whether the session options patch was applied successfully." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return SessionUpdateOptionsResult.from_dict(await self._client.request("session.options.update", params_dict, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class LspApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active - async def initialize(self, params: LspInitializeRequest, *, timeout: float | None = None) -> None: + async def initialize(self, params: LspInitializeRequest | None = None, *, timeout: float | None = None) -> None: "Loads the merged LSP configuration set for the session's working directory.\n\nArgs:\n params: Parameters for (re)loading the merged LSP configuration set." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id await self._client.request("session.lsp.initialize", params_dict, **_timeout_kwargs(timeout)) # Experimental: this API group is experimental and may change or be removed. class ExtensionsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> ExtensionList: "Lists extensions discovered for the session and their current status.\n\nReturns:\n Extensions discovered for the session, with their current status." + if self._assert_active is not None: + self._assert_active() + return ExtensionList.from_dict(await self._client.request("session.extensions.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def enable(self, params: ExtensionsEnableRequest, *, timeout: float | None = None) -> None: "Enables an extension for the session.\n\nArgs:\n params: Source-qualified extension identifier to enable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.extensions.enable", params_dict, **_timeout_kwargs(timeout)) async def disable(self, params: ExtensionsDisableRequest, *, timeout: float | None = None) -> None: "Disables an extension for the session.\n\nArgs:\n params: Source-qualified extension identifier to disable for the session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.extensions.disable", params_dict, **_timeout_kwargs(timeout)) async def reload(self, *, timeout: float | None = None) -> None: "Reloads extension definitions and processes for the session." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.extensions.reload", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) class ToolsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def handle_pending_tool_call(self, params: HandlePendingToolCallRequest, *, timeout: float | None = None) -> HandlePendingToolCallResult: "Provides the result for a pending external tool call.\n\nArgs:\n params: Pending external tool call request ID, with the tool result or an error describing why it failed.\n\nReturns:\n Indicates whether the external tool call result was handled successfully." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return HandlePendingToolCallResult.from_dict(await self._client.request("session.tools.handlePendingToolCall", params_dict, **_timeout_kwargs(timeout))) async def initialize_and_validate(self, *, timeout: float | None = None) -> ToolsInitializeAndValidateResult: "Resolves, builds, and validates the runtime tool list for the session.\n\nReturns:\n Resolve, build, and validate the runtime tool list for this session. Subagent sessions and consumer flows that need an initialized tool set before `send` invoke this. Default base-class implementation is a no-op for sessions that don't support tool validation." + if self._assert_active is not None: + self._assert_active() + return ToolsInitializeAndValidateResult.from_dict(await self._client.request("session.tools.initializeAndValidate", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) class CommandsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, params: CommandsListRequest | None = None, *, timeout: float | None = None) -> CommandList: "Lists slash commands available in the session.\n\nArgs:\n params: Optional filters controlling which command sources to include in the listing.\n\nReturns:\n Slash commands available in the session, after applying any include/exclude filters." + if self._assert_active is not None: + self._assert_active() + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return CommandList.from_dict(await self._client.request("session.commands.list", params_dict, **_timeout_kwargs(timeout))) async def invoke(self, params: CommandsInvokeRequest, *, timeout: float | None = None) -> SlashCommandInvocationResult: "Invokes a slash command in the session.\n\nArgs:\n params: Slash command name and optional raw input string to invoke.\n\nReturns:\n Result of invoking the slash command (text output, prompt to send to the agent, or completion)." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return SlashCommandInvocationResult.from_dict(await self._client.request("session.commands.invoke", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_command(self, params: CommandsHandlePendingCommandRequest, *, timeout: float | None = None) -> CommandsHandlePendingCommandResult: "Reports completion of a pending client-handled slash command.\n\nArgs:\n params: Pending command request ID and an optional error if the client handler failed.\n\nReturns:\n Indicates whether the pending client-handled command was completed successfully." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return CommandsHandlePendingCommandResult.from_dict(await self._client.request("session.commands.handlePendingCommand", params_dict, **_timeout_kwargs(timeout))) async def execute(self, params: ExecuteCommandParams, *, timeout: float | None = None) -> ExecuteCommandResult: "Executes a slash command synchronously and returns any error.\n\nArgs:\n params: Slash command name and argument string to execute synchronously.\n\nReturns:\n Error message produced while executing the command, if any." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ExecuteCommandResult.from_dict(await self._client.request("session.commands.execute", params_dict, **_timeout_kwargs(timeout))) async def enqueue(self, params: EnqueueCommandParams, *, timeout: float | None = None) -> EnqueueCommandResult: "Enqueues a slash command for FIFO processing on the local session.\n\nArgs:\n params: Slash-prefixed command string to enqueue for FIFO processing.\n\nReturns:\n Indicates whether the command was accepted into the local execution queue." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return EnqueueCommandResult.from_dict(await self._client.request("session.commands.enqueue", params_dict, **_timeout_kwargs(timeout))) async def respond_to_queued_command(self, params: CommandsRespondToQueuedCommandRequest, *, timeout: float | None = None) -> CommandsRespondToQueuedCommandResult: "Reports whether the host actually executed a queued command and whether to continue processing.\n\nArgs:\n params: Queued-command request ID and the result indicating whether the host executed it (and whether to stop processing further queued commands).\n\nReturns:\n Indicates whether the queued-command response was matched to a pending request." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return CommandsRespondToQueuedCommandResult.from_dict(await self._client.request("session.commands.respondToQueuedCommand", params_dict, **_timeout_kwargs(timeout))) @@ -16252,162 +16583,272 @@ async def respond_to_queued_command(self, params: CommandsRespondToQueuedCommand # Experimental: this API group is experimental and may change or be removed. class TelemetryApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def set_feature_overrides(self, params: TelemetrySetFeatureOverridesRequest, *, timeout: float | None = None) -> None: "Sets feature override key/value pairs to attach to subsequent telemetry events for the session.\n\nArgs:\n params: Feature override key/value pairs to attach to subsequent telemetry events from this session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id await self._client.request("session.telemetry.setFeatureOverrides", params_dict, **_timeout_kwargs(timeout)) class UiApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def elicitation(self, params: UIElicitationRequest, *, timeout: float | None = None) -> UIElicitationResponse: "Requests structured input from a UI-capable client.\n\nArgs:\n params: Prompt message and JSON schema describing the form fields to elicit from the user.\n\nReturns:\n The elicitation response (accept with form values, decline, or cancel)" + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIElicitationResponse.from_dict(await self._client.request("session.ui.elicitation", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_elicitation(self, params: UIHandlePendingElicitationRequest, *, timeout: float | None = None) -> UIElicitationResult: "Provides the user response for a pending elicitation request.\n\nArgs:\n params: Pending elicitation request ID and the user's response (accept/decline/cancel + form values).\n\nReturns:\n Indicates whether the elicitation response was accepted; false if it was already resolved by another client." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIElicitationResult.from_dict(await self._client.request("session.ui.handlePendingElicitation", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_user_input(self, params: UIHandlePendingUserInputRequest, *, timeout: float | None = None) -> UIHandlePendingResult: "Resolves a pending `user_input.requested` event with the user's response.\n\nArgs:\n params: Request ID of a pending `user_input.requested` event and the user's response.\n\nReturns:\n Indicates whether the pending UI request was resolved by this call." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIHandlePendingResult.from_dict(await self._client.request("session.ui.handlePendingUserInput", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_sampling(self, params: UIHandlePendingSamplingRequest, *, timeout: float | None = None) -> UIHandlePendingResult: "Resolves a pending `sampling.requested` event with a sampling result, or rejects it.\n\nArgs:\n params: Request ID of a pending `sampling.requested` event and an optional sampling result payload (omit to reject).\n\nReturns:\n Indicates whether the pending UI request was resolved by this call." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIHandlePendingResult.from_dict(await self._client.request("session.ui.handlePendingSampling", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_auto_mode_switch(self, params: UIHandlePendingAutoModeSwitchRequest, *, timeout: float | None = None) -> UIHandlePendingResult: "Resolves a pending `auto_mode_switch.requested` event with the user's accept/decline decision.\n\nArgs:\n params: Request ID of a pending `auto_mode_switch.requested` event and the user's response.\n\nReturns:\n Indicates whether the pending UI request was resolved by this call." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIHandlePendingResult.from_dict(await self._client.request("session.ui.handlePendingAutoModeSwitch", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_exit_plan_mode(self, params: UIHandlePendingExitPlanModeRequest, *, timeout: float | None = None) -> UIHandlePendingResult: "Resolves a pending `exit_plan_mode.requested` event with the user's response.\n\nArgs:\n params: Request ID of a pending `exit_plan_mode.requested` event and the user's response.\n\nReturns:\n Indicates whether the pending UI request was resolved by this call." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIHandlePendingResult.from_dict(await self._client.request("session.ui.handlePendingExitPlanMode", params_dict, **_timeout_kwargs(timeout))) async def register_direct_auto_mode_switch_handler(self, *, timeout: float | None = None) -> UIRegisterDirectAutoModeSwitchHandlerResult: "Registers an in-process handler for auto-mode-switch requests so the server bridge skips dispatch.\n\nReturns:\n Register an in-process handler for `auto_mode_switch.requested` events. The caller still attaches the actual listener via the standard event-subscription mechanism; this registration solely tells the server bridge to skip its own dispatch (so a remote client doesn't race the in-process handler for the same requestId)." + if self._assert_active is not None: + self._assert_active() + return UIRegisterDirectAutoModeSwitchHandlerResult.from_dict(await self._client.request("session.ui.registerDirectAutoModeSwitchHandler", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def unregister_direct_auto_mode_switch_handler(self, params: UIUnregisterDirectAutoModeSwitchHandlerRequest, *, timeout: float | None = None) -> UIUnregisterDirectAutoModeSwitchHandlerResult: "Unregisters a previously-registered in-process auto-mode-switch handler by its opaque handle.\n\nArgs:\n params: Opaque handle previously returned by `registerDirectAutoModeSwitchHandler` to release.\n\nReturns:\n Indicates whether the handle was active and the registration count was decremented." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return UIUnregisterDirectAutoModeSwitchHandlerResult.from_dict(await self._client.request("session.ui.unregisterDirectAutoModeSwitchHandler", params_dict, **_timeout_kwargs(timeout))) class PermissionsPathsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> PermissionPathsList: "Returns the session's allowed directories and primary working directory.\n\nReturns:\n Snapshot of the session's allow-listed directories and primary working directory." + if self._assert_active is not None: + self._assert_active() + return PermissionPathsList.from_dict(await self._client.request("session.permissions.paths.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def add(self, params: PermissionPathsAddParams, *, timeout: float | None = None) -> PermissionsPathsAddResult: "Adds a directory to the session's allow-list.\n\nArgs:\n params: Directory path to add to the session's allowed directories.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsPathsAddResult.from_dict(await self._client.request("session.permissions.paths.add", params_dict, **_timeout_kwargs(timeout))) async def update_primary(self, params: PermissionPathsUpdatePrimaryParams, *, timeout: float | None = None) -> PermissionsPathsUpdatePrimaryResult: "Updates the session's primary working directory used by the permission policy.\n\nArgs:\n params: Directory path to set as the session's new primary working directory.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsPathsUpdatePrimaryResult.from_dict(await self._client.request("session.permissions.paths.updatePrimary", params_dict, **_timeout_kwargs(timeout))) async def is_path_within_allowed_directories(self, params: PermissionPathsAllowedCheckParams, *, timeout: float | None = None) -> PermissionPathsAllowedCheckResult: "Reports whether a path falls within any of the session's allowed directories.\n\nArgs:\n params: Path to evaluate against the session's allowed directories.\n\nReturns:\n Indicates whether the supplied path is within the session's allowed directories." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionPathsAllowedCheckResult.from_dict(await self._client.request("session.permissions.paths.isPathWithinAllowedDirectories", params_dict, **_timeout_kwargs(timeout))) async def is_path_within_workspace(self, params: PermissionPathsWorkspaceCheckParams, *, timeout: float | None = None) -> PermissionPathsWorkspaceCheckResult: "Reports whether a path falls within the session's workspace (primary) directory.\n\nArgs:\n params: Path to evaluate against the session's workspace (primary) directory.\n\nReturns:\n Indicates whether the supplied path is within the session's workspace directory." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionPathsWorkspaceCheckResult.from_dict(await self._client.request("session.permissions.paths.isPathWithinWorkspace", params_dict, **_timeout_kwargs(timeout))) class PermissionsUrlsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def set_unrestricted_mode(self, params: PermissionUrlsSetUnrestrictedModeParams, *, timeout: float | None = None) -> PermissionsUrlsSetUnrestrictedModeResult: "Toggles the runtime's URL-permission policy between unrestricted and restricted modes.\n\nArgs:\n params: Whether the URL-permission policy should run in unrestricted mode.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsUrlsSetUnrestrictedModeResult.from_dict(await self._client.request("session.permissions.urls.setUnrestrictedMode", params_dict, **_timeout_kwargs(timeout))) class PermissionsApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id - self.paths = PermissionsPathsApi(client, session_id) - self.urls = PermissionsUrlsApi(client, session_id) + self._assert_active = assert_active + self.paths = PermissionsPathsApi(client, session_id, assert_active) + self.urls = PermissionsUrlsApi(client, session_id, assert_active) - async def configure(self, params: PermissionsConfigureParams, *, timeout: float | None = None) -> PermissionsConfigureResult: + async def configure(self, params: PermissionsConfigureParams | None = None, *, timeout: float | None = None) -> PermissionsConfigureResult: "Replaces selected permission policy fields (rules, paths, URLs, exclusions, allow-all flags) on the session.\n\nArgs:\n params: Patch of permission policy fields to apply (omit a field to leave it unchanged).\n\nReturns:\n Indicates whether the operation succeeded." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return PermissionsConfigureResult.from_dict(await self._client.request("session.permissions.configure", params_dict, **_timeout_kwargs(timeout))) async def handle_pending_permission_request(self, params: PermissionDecisionRequest, *, timeout: float | None = None) -> PermissionRequestResult: "Provides a decision for a pending tool permission request.\n\nArgs:\n params: Pending permission request ID and the decision to apply (approve/reject and scope).\n\nReturns:\n Indicates whether the permission decision was applied; false when the request was already resolved." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionRequestResult.from_dict(await self._client.request("session.permissions.handlePendingPermissionRequest", params_dict, **_timeout_kwargs(timeout))) async def pending_requests(self, *, timeout: float | None = None) -> PendingPermissionRequestList: "Reconstructs the set of pending tool permission requests from the session's event history.\n\nReturns:\n List of pending permission requests reconstructed from event history." + if self._assert_active is not None: + self._assert_active() + return PendingPermissionRequestList.from_dict(await self._client.request("session.permissions.pendingRequests", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def set_approve_all(self, params: PermissionsSetApproveAllRequest, *, timeout: float | None = None) -> PermissionsSetApproveAllResult: "Enables or disables automatic approval of tool permission requests for the session.\n\nArgs:\n params: Allow-all toggle for tool permission requests, with an optional telemetry source.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsSetApproveAllResult.from_dict(await self._client.request("session.permissions.setApproveAll", params_dict, **_timeout_kwargs(timeout))) async def modify_rules(self, params: PermissionsModifyRulesParams, *, timeout: float | None = None) -> PermissionsModifyRulesResult: "Adds or removes session-scoped or location-scoped permission rules.\n\nArgs:\n params: Scope and add/remove instructions for modifying session- or location-scoped permission rules.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsModifyRulesResult.from_dict(await self._client.request("session.permissions.modifyRules", params_dict, **_timeout_kwargs(timeout))) async def set_required(self, params: PermissionsSetRequiredRequest, *, timeout: float | None = None) -> PermissionsSetRequiredResult: "Sets whether the client wants permission prompts bridged into session events.\n\nArgs:\n params: Toggles whether permission prompts should be bridged into session events for this client.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsSetRequiredResult.from_dict(await self._client.request("session.permissions.setRequired", params_dict, **_timeout_kwargs(timeout))) async def reset_session_approvals(self, *, timeout: float | None = None) -> PermissionsResetSessionApprovalsResult: "Clears session-scoped tool permission approvals.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + return PermissionsResetSessionApprovalsResult.from_dict(await self._client.request("session.permissions.resetSessionApprovals", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def notify_prompt_shown(self, params: PermissionPromptShownNotification, *, timeout: float | None = None) -> PermissionsNotifyPromptShownResult: "Notifies the runtime that a permission prompt UI has been shown to the user.\n\nArgs:\n params: Notification payload describing the permission prompt that the client just rendered.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return PermissionsNotifyPromptShownResult.from_dict(await self._client.request("session.permissions.notifyPromptShown", params_dict, **_timeout_kwargs(timeout))) @@ -16415,56 +16856,94 @@ async def notify_prompt_shown(self, params: PermissionPromptShownNotification, * # Experimental: this API group is experimental and may change or be removed. class MetadataApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def snapshot(self, *, timeout: float | None = None) -> SessionMetadataSnapshot: "Returns a snapshot of the session's identifying metadata, mode, agent, and remote info.\n\nReturns:\n Point-in-time snapshot of slow-changing session identifier and state fields" + if self._assert_active is not None: + self._assert_active() + return SessionMetadataSnapshot.from_dict(await self._client.request("session.metadata.snapshot", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def is_processing(self, *, timeout: float | None = None) -> MetadataIsProcessingResult: "Reports whether the local session is currently processing user/agent messages.\n\nReturns:\n Indicates whether the local session is currently processing a turn or background continuation." + if self._assert_active is not None: + self._assert_active() + return MetadataIsProcessingResult.from_dict(await self._client.request("session.metadata.isProcessing", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def context_info(self, params: MetadataContextInfoRequest, *, timeout: float | None = None) -> MetadataContextInfoResult: "Returns the token breakdown for the session's current context window for a given model.\n\nArgs:\n params: Model identifier and token limits used to compute the context-info breakdown.\n\nReturns:\n Token breakdown for the session's current context window, or null if uninitialized." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MetadataContextInfoResult.from_dict(await self._client.request("session.metadata.contextInfo", params_dict, **_timeout_kwargs(timeout))) async def record_context_change(self, params: MetadataRecordContextChangeRequest, *, timeout: float | None = None) -> MetadataRecordContextChangeResult: "Records a working-directory/git context change and emits a `session.context_changed` event.\n\nArgs:\n params: Updated working-directory/git context to record on the session.\n\nReturns:\n Notify the session that its working directory context has changed. Emits a `session.context_changed` event so consumers (telemetry, OTel tracker, ACP, the timeline UI) can react. Use this when the host has detected a cwd/branch/repo change outside the session's normal lifecycle (e.g., after a shell command in interactive mode)." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MetadataRecordContextChangeResult.from_dict(await self._client.request("session.metadata.recordContextChange", params_dict, **_timeout_kwargs(timeout))) async def set_working_directory(self, params: MetadataSetWorkingDirectoryRequest, *, timeout: float | None = None) -> MetadataSetWorkingDirectoryResult: "Updates the session's recorded working directory.\n\nArgs:\n params: Absolute path to set as the session's new working directory.\n\nReturns:\n Update the session's working directory. Used by the host when the user explicitly changes cwd (e.g., the `/cd` slash command). The host is responsible for `process.chdir` and any related side-effects (file index, etc.); this method only updates the session's own recorded path." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MetadataSetWorkingDirectoryResult.from_dict(await self._client.request("session.metadata.setWorkingDirectory", params_dict, **_timeout_kwargs(timeout))) async def recompute_context_tokens(self, params: MetadataRecomputeContextTokensRequest, *, timeout: float | None = None) -> MetadataRecomputeContextTokensResult: "Re-tokenizes the session's existing messages against a model and returns aggregate token totals.\n\nArgs:\n params: Model identifier to use when re-tokenizing the session's existing messages.\n\nReturns:\n Re-tokenize the session's existing messages against `modelId` and return the token totals. Useful for hosts that want an initial estimate of context usage on session resume, before the next agent turn fires `session.context_info_changed` events. Returns zeros for an empty session." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return MetadataRecomputeContextTokensResult.from_dict(await self._client.request("session.metadata.recomputeContextTokens", params_dict, **_timeout_kwargs(timeout))) class ShellApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def exec(self, params: ShellExecRequest, *, timeout: float | None = None) -> ShellExecResult: "Starts a shell command and streams output through session notifications.\n\nArgs:\n params: Shell command to run, with optional working directory and timeout in milliseconds.\n\nReturns:\n Identifier of the spawned process, used to correlate streamed output and exit notifications." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ShellExecResult.from_dict(await self._client.request("session.shell.exec", params_dict, **_timeout_kwargs(timeout))) async def kill(self, params: ShellKillRequest, *, timeout: float | None = None) -> ShellKillResult: "Sends a signal to a shell process previously started via \"shell.exec\".\n\nArgs:\n params: Identifier of a process previously returned by \"shell.exec\" and the signal to send.\n\nReturns:\n Indicates whether the signal was delivered; false if the process was unknown or already exited." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ShellKillResult.from_dict(await self._client.request("session.shell.kill", params_dict, **_timeout_kwargs(timeout))) @@ -16472,76 +16951,121 @@ async def kill(self, params: ShellKillRequest, *, timeout: float | None = None) # Experimental: this API group is experimental and may change or be removed. class HistoryApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def compact(self, *, timeout: float | None = None) -> HistoryCompactResult: "Compacts the session history to reduce context usage.\n\nReturns:\n Compaction outcome with the number of tokens and messages removed, summary text, and the resulting context window breakdown." + if self._assert_active is not None: + self._assert_active() + return HistoryCompactResult.from_dict(await self._client.request("session.history.compact", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def truncate(self, params: HistoryTruncateRequest, *, timeout: float | None = None) -> HistoryTruncateResult: "Truncates persisted session history to a specific event.\n\nArgs:\n params: Identifier of the event to truncate to; this event and all later events are removed.\n\nReturns:\n Number of events that were removed by the truncation." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return HistoryTruncateResult.from_dict(await self._client.request("session.history.truncate", params_dict, **_timeout_kwargs(timeout))) async def cancel_background_compaction(self, *, timeout: float | None = None) -> HistoryCancelBackgroundCompactionResult: "Cancels any in-progress background compaction on a local session.\n\nReturns:\n Indicates whether an in-progress background compaction was cancelled." + if self._assert_active is not None: + self._assert_active() + return HistoryCancelBackgroundCompactionResult.from_dict(await self._client.request("session.history.cancelBackgroundCompaction", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def abort_manual_compaction(self, *, timeout: float | None = None) -> HistoryAbortManualCompactionResult: "Aborts any in-progress manual compaction on a local session.\n\nReturns:\n Indicates whether an in-progress manual compaction was aborted." + if self._assert_active is not None: + self._assert_active() + return HistoryAbortManualCompactionResult.from_dict(await self._client.request("session.history.abortManualCompaction", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def summarize_for_handoff(self, *, timeout: float | None = None) -> HistorySummarizeForHandoffResult: "Produces a markdown summary of the session's conversation context for hand-off scenarios.\n\nReturns:\n Markdown summary of the conversation context (empty when not available)." + if self._assert_active is not None: + self._assert_active() + return HistorySummarizeForHandoffResult.from_dict(await self._client.request("session.history.summarizeForHandoff", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class QueueApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def pending_items(self, *, timeout: float | None = None) -> QueuePendingItemsResult: "Returns the local session's pending user-facing queued items and steering messages.\n\nReturns:\n Snapshot of the session's pending queued items and immediate-steering messages." + if self._assert_active is not None: + self._assert_active() + return QueuePendingItemsResult.from_dict(await self._client.request("session.queue.pendingItems", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def remove_most_recent(self, *, timeout: float | None = None) -> QueueRemoveMostRecentResult: "Removes the most recently queued user-facing item (LIFO).\n\nReturns:\n Indicates whether a user-facing pending item was removed." + if self._assert_active is not None: + self._assert_active() + return QueueRemoveMostRecentResult.from_dict(await self._client.request("session.queue.removeMostRecent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def clear(self, *, timeout: float | None = None) -> None: "Clears all pending queued items on the local session." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.queue.clear", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) # Experimental: this API group is experimental and may change or be removed. class EventLogApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active - async def read(self, params: EventLogReadRequest, *, timeout: float | None = None) -> EventsReadResult: + async def read(self, params: EventLogReadRequest | None = None, *, timeout: float | None = None) -> EventsReadResult: "Reads a batch of session events from a cursor, optionally waiting for new events.\n\nArgs:\n params: Cursor, batch size, and optional long-poll/filter parameters for reading session events.\n\nReturns:\n Batch of session events returned by a read, with cursor and continuation metadata." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return EventsReadResult.from_dict(await self._client.request("session.eventLog.read", params_dict, **_timeout_kwargs(timeout))) async def tail(self, *, timeout: float | None = None) -> EventLogTailResult: "Returns a snapshot of the current tail cursor without consuming events.\n\nReturns:\n Snapshot of the current tail cursor without returning any events. Use this when a consumer wants to subscribe to live events going forward without first paginating through the entire persisted history (which would happen if `read` were called without a cursor on a long-lived session)." + if self._assert_active is not None: + self._assert_active() + return EventLogTailResult.from_dict(await self._client.request("session.eventLog.tail", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def register_interest(self, params: RegisterEventInterestParams, *, timeout: float | None = None) -> RegisterEventInterestResult: "Registers consumer interest in an event type for runtime gating purposes.\n\nArgs:\n params: Event type to register consumer interest for, used by runtime gating logic.\n\nReturns:\n Opaque handle representing an event-type interest registration." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return RegisterEventInterestResult.from_dict(await self._client.request("session.eventLog.registerInterest", params_dict, **_timeout_kwargs(timeout))) async def release_interest(self, params: ReleaseEventInterestParams, *, timeout: float | None = None) -> EventLogReleaseInterestResult: "Releases a consumer's previously-registered interest in an event type.\n\nArgs:\n params: Opaque handle previously returned by `registerInterest` to release.\n\nReturns:\n Indicates whether the operation succeeded." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return EventLogReleaseInterestResult.from_dict(await self._client.request("session.eventLog.releaseInterest", params_dict, **_timeout_kwargs(timeout))) @@ -16549,33 +17073,49 @@ async def release_interest(self, params: ReleaseEventInterestParams, *, timeout: # Experimental: this API group is experimental and may change or be removed. class UsageApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def get_metrics(self, *, timeout: float | None = None) -> UsageGetMetricsResult: "Gets accumulated usage metrics for the session.\n\nReturns:\n Accumulated session usage metrics, including premium request cost, token counts, model breakdown, and code-change totals." + if self._assert_active is not None: + self._assert_active() + return UsageGetMetricsResult.from_dict(await self._client.request("session.usage.getMetrics", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) # Experimental: this API group is experimental and may change or be removed. class RemoteApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active - async def enable(self, params: RemoteEnableRequest, *, timeout: float | None = None) -> RemoteEnableResult: + async def enable(self, params: RemoteEnableRequest | None = None, *, timeout: float | None = None) -> RemoteEnableResult: "Enables remote session export or steering.\n\nArgs:\n params: Optional remote session mode (\"off\", \"export\", or \"on\"); defaults to enabling both export and remote steering.\n\nReturns:\n GitHub URL for the session and a flag indicating whether remote steering is enabled." - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return RemoteEnableResult.from_dict(await self._client.request("session.remote.enable", params_dict, **_timeout_kwargs(timeout))) async def disable(self, *, timeout: float | None = None) -> None: "Disables remote session export and steering." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.remote.disable", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) async def notify_steerable_changed(self, params: RemoteNotifySteerableChangedRequest, *, timeout: float | None = None) -> RemoteNotifySteerableChangedResult: "Persists a remote-steerability change emitted by the host as a session event.\n\nArgs:\n params: New remote-steerability state to persist as a `session.remote_steerable_changed` event.\n\nReturns:\n Persist a steerability change as a `session.remote_steerable_changed` event. Used by the host (CLI / SDK consumer) when it has just finished enabling or disabling steering on a remote exporter that the runtime does not directly own." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return RemoteNotifySteerableChangedResult.from_dict(await self._client.request("session.remote.notifySteerableChanged", params_dict, **_timeout_kwargs(timeout))) @@ -16583,16 +17123,25 @@ async def notify_steerable_changed(self, params: RemoteNotifySteerableChangedReq # Experimental: this API group is experimental and may change or be removed. class ScheduleApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id + self._assert_active = assert_active async def list(self, *, timeout: float | None = None) -> ScheduleList: "Lists the session's currently active scheduled prompts.\n\nReturns:\n Snapshot of the currently active recurring prompts for this session." + if self._assert_active is not None: + self._assert_active() + return ScheduleList.from_dict(await self._client.request("session.schedule.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout))) async def stop(self, params: ScheduleStopRequest, *, timeout: float | None = None) -> ScheduleStopResult: "Removes a scheduled prompt by id.\n\nArgs:\n params: Identifier of the scheduled prompt to remove.\n\nReturns:\n Remove a scheduled prompt by id. The result entry is omitted if the id was unknown." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return ScheduleStopResult.from_dict(await self._client.request("session.schedule.stop", params_dict, **_timeout_kwargs(timeout))) @@ -16600,63 +17149,83 @@ async def stop(self, params: ScheduleStopRequest, *, timeout: float | None = Non class SessionRpc: """Typed session-scoped RPC methods.""" - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None): self._client = client self._session_id = session_id - self.auth = AuthApi(client, session_id) - self.model = ModelApi(client, session_id) - self.mode = ModeApi(client, session_id) - self.name = NameApi(client, session_id) - self.plan = PlanApi(client, session_id) - self.workspaces = WorkspacesApi(client, session_id) - self.instructions = InstructionsApi(client, session_id) - self.fleet = FleetApi(client, session_id) - self.agent = AgentApi(client, session_id) - self.tasks = TasksApi(client, session_id) - self.skills = SkillsApi(client, session_id) - self.mcp = McpApi(client, session_id) - self.plugins = PluginsApi(client, session_id) - self.options = OptionsApi(client, session_id) - self.lsp = LspApi(client, session_id) - self.extensions = ExtensionsApi(client, session_id) - self.tools = ToolsApi(client, session_id) - self.commands = CommandsApi(client, session_id) - self.telemetry = TelemetryApi(client, session_id) - self.ui = UiApi(client, session_id) - self.permissions = PermissionsApi(client, session_id) - self.metadata = MetadataApi(client, session_id) - self.shell = ShellApi(client, session_id) - self.history = HistoryApi(client, session_id) - self.queue = QueueApi(client, session_id) - self.event_log = EventLogApi(client, session_id) - self.usage = UsageApi(client, session_id) - self.remote = RemoteApi(client, session_id) - self.schedule = ScheduleApi(client, session_id) + self._assert_active = assert_active + self.auth = AuthApi(client, session_id, assert_active) + self.model = ModelApi(client, session_id, assert_active) + self.mode = ModeApi(client, session_id, assert_active) + self.name = NameApi(client, session_id, assert_active) + self.plan = PlanApi(client, session_id, assert_active) + self.workspaces = WorkspacesApi(client, session_id, assert_active) + self.instructions = InstructionsApi(client, session_id, assert_active) + self.fleet = FleetApi(client, session_id, assert_active) + self.agent = AgentApi(client, session_id, assert_active) + self.tasks = TasksApi(client, session_id, assert_active) + self.skills = SkillsApi(client, session_id, assert_active) + self.mcp = McpApi(client, session_id, assert_active) + self.plugins = PluginsApi(client, session_id, assert_active) + self.options = OptionsApi(client, session_id, assert_active) + self.lsp = LspApi(client, session_id, assert_active) + self.extensions = ExtensionsApi(client, session_id, assert_active) + self.tools = ToolsApi(client, session_id, assert_active) + self.commands = CommandsApi(client, session_id, assert_active) + self.telemetry = TelemetryApi(client, session_id, assert_active) + self.ui = UiApi(client, session_id, assert_active) + self.permissions = PermissionsApi(client, session_id, assert_active) + self.metadata = MetadataApi(client, session_id, assert_active) + self.shell = ShellApi(client, session_id, assert_active) + self.history = HistoryApi(client, session_id, assert_active) + self.queue = QueueApi(client, session_id, assert_active) + self.event_log = EventLogApi(client, session_id, assert_active) + self.usage = UsageApi(client, session_id, assert_active) + self.remote = RemoteApi(client, session_id, assert_active) + self.schedule = ScheduleApi(client, session_id, assert_active) async def suspend(self, *, timeout: float | None = None) -> None: "Suspends the session while preserving persisted state for later resume." + if self._assert_active is not None: + self._assert_active() + await self._client.request("session.suspend", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) async def send(self, params: SendRequest, *, timeout: float | None = None) -> SendResult: "Sends a user message to the session and returns its message ID.\n\nArgs:\n params: Parameters for sending a user message to the session\n\nReturns:\n Result of sending a user message" + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return SendResult.from_dict(await self._client.request("session.send", params_dict, **_timeout_kwargs(timeout))) - async def abort(self, params: AbortRequest, *, timeout: float | None = None) -> AbortResult: + async def abort(self, params: AbortRequest | None = None, *, timeout: float | None = None) -> AbortResult: "Aborts the current agent turn.\n\nArgs:\n params: Parameters for aborting the current turn\n\nReturns:\n Result of aborting the current turn" - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id return AbortResult.from_dict(await self._client.request("session.abort", params_dict, **_timeout_kwargs(timeout))) - async def shutdown(self, params: ShutdownRequest, *, timeout: float | None = None) -> None: + async def shutdown(self, params: ShutdownRequest | None = None, *, timeout: float | None = None) -> None: "Shuts down the session and persists its final state. Awaits any deferred sessionEnd hooks before resolving so user-supplied hook scripts complete before the runtime tears down.\n\nArgs:\n params: Parameters for shutting down the session" - params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + if self._assert_active is not None: + self._assert_active() + + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} if params is not None else {} params_dict["sessionId"] = self._session_id await self._client.request("session.shutdown", params_dict, **_timeout_kwargs(timeout)) async def log(self, params: LogRequest, *, timeout: float | None = None) -> LogResult: "Emits a user-visible session log event.\n\nArgs:\n params: Message text, optional severity level, persistence flag, optional follow-up URL, and optional tip.\n\nReturns:\n Identifier of the session event that was emitted for the log message." + if self._assert_active is not None: + self._assert_active() + if params is None: + raise TypeError("params is required") + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id return LogResult.from_dict(await self._client.request("session.log", params_dict, **_timeout_kwargs(timeout))) diff --git a/python/copilot/session.py b/python/copilot/session.py index 4789724fb..172da7062 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1116,7 +1116,11 @@ class CopilotSession: """ def __init__( - self, session_id: str, client: Any, workspace_path: os.PathLike[str] | str | None = None + self, + session_id: str, + client: Any, + workspace_path: os.PathLike[str] | str | None = None, + on_disconnected: Callable[[CopilotSession], None] | None = None, ): """ Initialize a new CopilotSession. @@ -1130,6 +1134,8 @@ def __init__( client: The internal client connection to the Copilot CLI. workspace_path: Path to the session workspace directory (when infinite sessions enabled). + on_disconnected: Internal-only callback invoked when the owning + client should unregister this disconnected session. """ self.session_id = session_id self._client = client @@ -1158,12 +1164,16 @@ def __init__( self._client_session_apis = ClientSessionApiHandlers() self._rpc: SessionRpc | None = None self._destroyed = False + self._disconnect_task: asyncio.Task[None] | None = None + self._disconnect_lock = threading.Lock() + self._on_disconnected = on_disconnected @property def rpc(self) -> SessionRpc: """Typed session-scoped RPC methods.""" + self._assert_not_destroyed() if self._rpc is None: - self._rpc = SessionRpc(self._client, self.session_id) + self._rpc = SessionRpc(self._client, self.session_id, self._assert_not_destroyed) return self._rpc @property @@ -1186,6 +1196,7 @@ def ui(self) -> SessionUiApi: >>> if ui_caps.get("elicitation"): ... ok = await session.ui.confirm("Deploy to production?") """ + self._assert_not_destroyed() return SessionUiApi(self) @functools.cached_property @@ -1235,6 +1246,7 @@ async def send( ... attachments=[{"type": "file", "path": "./src/main.py"}], ... ) """ + self._assert_not_destroyed() params: dict[str, Any] = { "sessionId": self.session_id, "prompt": prompt, @@ -1301,6 +1313,7 @@ async def send_and_wait( ... case AssistantMessageData() as data: ... print(data.content) """ + self._assert_not_destroyed() total_start = time.perf_counter() idle_event = asyncio.Event() error_event: Exception | None = None @@ -1403,6 +1416,7 @@ def on(self, handler: Callable[[SessionEvent], None]) -> Callable[[], None]: >>> # Later, to stop receiving events: >>> unsubscribe() """ + self._assert_not_destroyed() with self._event_handlers_lock: self._event_handlers.add(handler) @@ -1838,6 +1852,7 @@ async def _handle_elicitation_request( def _assert_elicitation(self) -> None: """Raises if the host does not support elicitation.""" + self._assert_not_destroyed() ui_caps = self._capabilities.get("ui", {}) if not ui_caps.get("elicitation"): raise RuntimeError( @@ -2235,6 +2250,7 @@ async def get_messages(self) -> list[SessionEvent]: ... case AssistantMessageData() as data: ... print(f"Assistant: {data.content}") """ + self._assert_not_destroyed() response = await self._client.request("session.getMessages", {"sessionId": self.session_id}) # Convert dict events to SessionEvent objects events_dicts = response["events"] @@ -2263,31 +2279,54 @@ async def disconnect(self) -> None: >>> # Clean up when done — session can still be resumed later >>> await session.disconnect() """ - # Ensure that the check and update of _destroyed are atomic so that - # only the first caller proceeds to send the destroy RPC. - with self._event_handlers_lock: + with self._disconnect_lock: if self._destroyed: return - self._destroyed = True + if self._disconnect_task is None: + self._disconnect_task = asyncio.create_task(self._disconnect_core()) + disconnect_task = self._disconnect_task + await asyncio.shield(disconnect_task) + + async def _disconnect_core(self) -> None: try: await self._client.request("session.destroy", {"sessionId": self.session_id}) finally: - # Clear handlers even if the request fails. - with self._event_handlers_lock: - self._event_handlers.clear() - with self._tool_handlers_lock: - self._tool_handlers.clear() - with self._permission_handler_lock: - self._permission_handler = None - with self._command_handlers_lock: - self._command_handlers.clear() - with self._elicitation_handler_lock: - self._elicitation_handler = None - with self._exit_plan_mode_handler_lock: - self._exit_plan_mode_handler = None - with self._auto_mode_switch_handler_lock: - self._auto_mode_switch_handler = None + self._mark_disconnected() + + def _assert_not_destroyed(self) -> None: + if self._destroyed: + raise RuntimeError("Session has been disconnected.") + + def _mark_disconnected(self) -> None: + with self._disconnect_lock: + if self._destroyed: + return + self._destroyed = True + + self._rpc = None + with self._event_handlers_lock: + self._event_handlers.clear() + with self._tool_handlers_lock: + self._tool_handlers.clear() + with self._permission_handler_lock: + self._permission_handler = None + with self._user_input_handler_lock: + self._user_input_handler = None + with self._command_handlers_lock: + self._command_handlers.clear() + with self._elicitation_handler_lock: + self._elicitation_handler = None + with self._exit_plan_mode_handler_lock: + self._exit_plan_mode_handler = None + with self._auto_mode_switch_handler_lock: + self._auto_mode_switch_handler = None + with self._hooks_lock: + self._hooks = None + with self._transform_callbacks_lock: + self._transform_callbacks = None + if self._on_disconnected is not None: + self._on_disconnected(self) async def destroy(self) -> None: """ @@ -2346,6 +2385,7 @@ async def abort(self) -> None: >>> await asyncio.sleep(5) >>> await session.abort() """ + self._assert_not_destroyed() await self._client.request("session.abort", {"sessionId": self.session_id}) async def set_model( @@ -2374,6 +2414,7 @@ async def set_model( >>> await session.set_model("gpt-4.1") >>> await session.set_model("claude-sonnet-4.6", reasoning_effort="high") """ + self._assert_not_destroyed() rpc_caps = None if model_capabilities is not None: from .client import _capabilities_to_dict @@ -2416,6 +2457,7 @@ async def log( >>> await session.log("Operation failed", level="error") >>> await session.log("Temporary status update", ephemeral=True) """ + self._assert_not_destroyed() params = LogRequest( message=message, level=SessionLogLevel(level) if level is not None else None, diff --git a/python/e2e/test_client_e2e.py b/python/e2e/test_client_e2e.py index fc7315a58..b7e14ba98 100644 --- a/python/e2e/test_client_e2e.py +++ b/python/e2e/test_client_e2e.py @@ -13,7 +13,7 @@ ) from copilot.session import PermissionHandler -from .testharness import CLI_PATH +from .testharness import CLI_PATH, mark_inactive_for_resume class TestClient: @@ -270,6 +270,7 @@ async def test_should_resume_session_without_permission_handler(self): try: await client.start() session = await client.create_session() + mark_inactive_for_resume(session) resumed = await client.resume_session(session.session_id) assert resumed.session_id == session.session_id diff --git a/python/e2e/test_commands_e2e.py b/python/e2e/test_commands_e2e.py index a1c44b7b3..1a7784f35 100644 --- a/python/e2e/test_commands_e2e.py +++ b/python/e2e/test_commands_e2e.py @@ -20,6 +20,7 @@ from copilot.session import CommandDefinition, PermissionHandler from .testharness.context import SNAPSHOTS_DIR, get_cli_path_for_tests +from .testharness.helper import mark_inactive_for_resume from .testharness.proxy import CapiProxy pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -258,6 +259,7 @@ async def test_session_with_commands_resumes_successfully(self, ctx): ) session_id = session1.session_id + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/e2e/test_mcp_and_agents_e2e.py b/python/e2e/test_mcp_and_agents_e2e.py index 5d1275ad6..df27c2ecf 100644 --- a/python/e2e/test_mcp_and_agents_e2e.py +++ b/python/e2e/test_mcp_and_agents_e2e.py @@ -8,7 +8,7 @@ from copilot.session import CustomAgentConfig, MCPServerConfig, PermissionHandler -from .testharness import E2ETestContext, get_final_assistant_message +from .testharness import E2ETestContext, get_final_assistant_message, mark_inactive_for_resume TEST_MCP_SERVER = str( (Path(__file__).parents[2] / "test" / "harness" / "test-mcp-server.mjs").resolve() @@ -64,6 +64,7 @@ async def test_should_accept_mcp_server_configuration_on_session_resume( } } + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, @@ -157,6 +158,7 @@ async def test_should_accept_custom_agent_configuration_on_session_resume( } ] + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/e2e/test_permissions_e2e.py b/python/e2e/test_permissions_e2e.py index 46cf2f3d4..5a62342ce 100644 --- a/python/e2e/test_permissions_e2e.py +++ b/python/e2e/test_permissions_e2e.py @@ -13,7 +13,7 @@ ) from copilot.session import PermissionHandler, PermissionRequestResult -from .testharness import E2ETestContext +from .testharness import E2ETestContext, mark_inactive_for_resume from .testharness.helper import read_file, write_file pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -140,6 +140,7 @@ async def test_should_deny_tool_operations_when_handler_explicitly_denies_after_ def deny_all(request, invocation): return PermissionRequestResult() + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session(session_id, on_permission_request=deny_all) denied_events = [] @@ -218,6 +219,7 @@ def on_permission_request( permission_requests.append(request) return PermissionRequestResult(kind="approve-once") + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=on_permission_request ) diff --git a/python/e2e/test_session_config_e2e.py b/python/e2e/test_session_config_e2e.py index 1fd2cd0a2..b69893bc5 100644 --- a/python/e2e/test_session_config_e2e.py +++ b/python/e2e/test_session_config_e2e.py @@ -9,7 +9,7 @@ from copilot import ModelCapabilitiesOverride, ModelSupportsOverride from copilot.session import PermissionHandler -from .testharness import E2ETestContext +from .testharness import E2ETestContext, mark_inactive_for_resume pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -217,6 +217,7 @@ async def test_should_forward_custom_provider_headers_on_resume(self, ctx: E2ETe ) session_id = session1.session_id + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, @@ -314,6 +315,7 @@ async def test_should_apply_workingdirectory_on_session_resume(self, ctx: E2ETes ) session_id = session1.session_id + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, @@ -335,6 +337,7 @@ async def test_should_apply_systemmessage_on_session_resume(self, ctx: E2ETestCo session_id = session1.session_id resume_instruction = "End the response with RESUME_SYSTEM_MESSAGE_SENTINEL." + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, @@ -394,6 +397,7 @@ async def test_should_apply_instruction_directories_on_resume(self, ctx: E2ETest working_directory=project_dir, ) + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session1.session_id, on_permission_request=PermissionHandler.approve_all, @@ -416,6 +420,7 @@ async def test_should_apply_availabletools_on_session_resume(self, ctx: E2ETestC ) session_id = session1.session_id + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/e2e/test_session_e2e.py b/python/e2e/test_session_e2e.py index 062ce8d58..92c8e0aef 100644 --- a/python/e2e/test_session_e2e.py +++ b/python/e2e/test_session_e2e.py @@ -11,7 +11,12 @@ from copilot.session import PermissionHandler from copilot.tools import Tool, ToolResult -from .testharness import E2ETestContext, get_final_assistant_message, get_next_event_of_type +from .testharness import ( + E2ETestContext, + get_final_assistant_message, + get_next_event_of_type, + mark_inactive_for_resume, +) pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -31,7 +36,7 @@ async def test_should_create_and_disconnect_sessions(self, ctx: E2ETestContext): await session.disconnect() - with pytest.raises(Exception, match="Session not found"): + with pytest.raises(Exception, match="Session has been disconnected"): await session.get_messages() async def test_should_have_stateful_conversation(self, ctx: E2ETestContext): @@ -216,6 +221,7 @@ async def test_should_resume_a_session_using_the_same_client(self, ctx: E2ETestC assert "2" in answer.data.content # Resume using the same client + mark_inactive_for_resume(session1) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all ) @@ -457,6 +463,7 @@ async def test_should_resume_session_with_custom_provider(self, ctx: E2ETestCont session_id = session.session_id # Resume the session with a provider + mark_inactive_for_resume(session) session2 = await ctx.client.resume_session( session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/e2e/testharness/__init__.py b/python/e2e/testharness/__init__.py index 28558d687..8d74f6873 100644 --- a/python/e2e/testharness/__init__.py +++ b/python/e2e/testharness/__init__.py @@ -1,7 +1,7 @@ """Test harness for E2E tests.""" from .context import CLI_PATH, DEFAULT_GITHUB_TOKEN, E2ETestContext -from .helper import get_final_assistant_message, get_next_event_of_type +from .helper import get_final_assistant_message, get_next_event_of_type, mark_inactive_for_resume from .proxy import CapiProxy __all__ = [ @@ -11,4 +11,5 @@ "CapiProxy", "get_final_assistant_message", "get_next_event_of_type", + "mark_inactive_for_resume", ] diff --git a/python/e2e/testharness/helper.py b/python/e2e/testharness/helper.py index c603a8ec5..b9ca78aca 100644 --- a/python/e2e/testharness/helper.py +++ b/python/e2e/testharness/helper.py @@ -139,6 +139,11 @@ def read_file(work_dir: str, filename: str) -> str: return f.read() +def mark_inactive_for_resume(session: CopilotSession) -> None: + """Clear local active-session tracking without destroying the server session.""" + session._mark_disconnected() + + async def get_next_event_of_type(session: CopilotSession, event_type: str, timeout: float = 30.0): """ Wait for and return the next event of a specific type from a session. diff --git a/python/test_client.py b/python/test_client.py index c03968c55..e42dbf02f 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -19,7 +19,7 @@ ModelSupports, SubprocessConfig, ) -from copilot.session import PermissionHandler, PermissionRequestResult +from copilot.session import CopilotSession, PermissionHandler, PermissionRequestResult from e2e.testharness import CLI_PATH @@ -72,6 +72,7 @@ async def test_resume_session_allows_none_permission_handler(self): session = await client.create_session( on_permission_request=PermissionHandler.approve_all ) + session._mark_disconnected() resumed = await client.resume_session(session.session_id, on_permission_request=None) assert resumed.session_id == session.session_id finally: @@ -115,6 +116,200 @@ async def mock_request(method, params): await client.force_stop() +class TestSessionLifecycle: + @pytest.mark.asyncio + async def test_direct_disconnect_unregisters_after_destroy_request(self): + sessions = {} + + async def request(method, params): + if method == "session.destroy": + assert sessions["session-1"] is session + return {} + + rpc_client = AsyncMock() + rpc_client.request = AsyncMock(side_effect=request) + + def unregister(candidate): + if sessions.get(candidate.session_id) is candidate: + del sessions[candidate.session_id] + + session = CopilotSession( + "session-1", + rpc_client, + on_disconnected=unregister, + ) + sessions[session.session_id] = session + + await session.disconnect() + + rpc_client.request.assert_awaited_with("session.destroy", {"sessionId": "session-1"}) + assert "session-1" not in sessions + with pytest.raises(RuntimeError, match="disconnected"): + await session.send("hello") + + def test_stale_session_disconnect_does_not_unregister_replacement(self): + sessions = {} + + def unregister(candidate): + if sessions.get(candidate.session_id) is candidate: + del sessions[candidate.session_id] + + rpc_client = AsyncMock() + stale = CopilotSession("session-1", rpc_client, on_disconnected=unregister) + replacement = CopilotSession("session-1", rpc_client, on_disconnected=unregister) + sessions["session-1"] = replacement + + stale._mark_disconnected() + + assert sessions["session-1"] is replacement + replacement._mark_disconnected() + + def test_rejects_duplicate_active_session_registration(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH, log_level="error")) + rpc_client = AsyncMock() + first = CopilotSession("session-1", rpc_client) + second = CopilotSession("session-1", rpc_client) + + client._register_session(first) + + with pytest.raises(RuntimeError, match="already active"): + client._register_session(second) + + first._mark_disconnected() + + @pytest.mark.asyncio + async def test_failed_create_session_fs_setup_marks_session_disconnected(self): + client = CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "/", + "session_state_path": "/session-state", + "conventions": "posix", + }, + ) + ) + rpc_client = AsyncMock() + rpc_client.request = AsyncMock() + client._client = rpc_client + captured = {} + + def failing_handler(session): + captured["session"] = session + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + await client.create_session(create_session_fs_handler=failing_handler) + + session = captured["session"] + assert session.session_id not in client._sessions + with pytest.raises(RuntimeError, match="disconnected"): + await session.send("hello") + rpc_client.request.assert_not_called() + + @pytest.mark.asyncio + async def test_failed_resume_session_fs_setup_marks_session_disconnected(self): + client = CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "/", + "session_state_path": "/session-state", + "conventions": "posix", + }, + ) + ) + rpc_client = AsyncMock() + rpc_client.request = AsyncMock() + client._client = rpc_client + captured = {} + + def failing_handler(session): + captured["session"] = session + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + await client.resume_session("session-1", create_session_fs_handler=failing_handler) + + session = captured["session"] + assert session.session_id not in client._sessions + with pytest.raises(RuntimeError, match="disconnected"): + await session.send("hello") + rpc_client.request.assert_not_called() + + @pytest.mark.asyncio + async def test_duplicate_create_marks_captured_session_disconnected(self): + client = CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "/", + "session_state_path": "/session-state", + "conventions": "posix", + }, + ) + ) + rpc_client = AsyncMock() + rpc_client.request = AsyncMock() + client._client = rpc_client + existing = CopilotSession("session-1", rpc_client) + client._register_session(existing) + captured = {} + + def create_provider(session): + captured["session"] = session + return object() + + with pytest.raises(RuntimeError, match="already active"): + await client.create_session( + session_id="session-1", create_session_fs_handler=create_provider + ) + + session = captured["session"] + assert client._sessions["session-1"] is existing + with pytest.raises(RuntimeError, match="disconnected"): + await session.send("hello") + rpc_client.request.assert_not_called() + existing._mark_disconnected() + + @pytest.mark.asyncio + async def test_duplicate_resume_marks_captured_session_disconnected(self): + client = CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "/", + "session_state_path": "/session-state", + "conventions": "posix", + }, + ) + ) + rpc_client = AsyncMock() + rpc_client.request = AsyncMock() + client._client = rpc_client + existing = CopilotSession("session-1", rpc_client) + client._register_session(existing) + captured = {} + + def create_provider(session): + captured["session"] = session + return object() + + with pytest.raises(RuntimeError, match="already active"): + await client.resume_session("session-1", create_session_fs_handler=create_provider) + + session = captured["session"] + assert client._sessions["session-1"] is existing + with pytest.raises(RuntimeError, match="disconnected"): + await session.send("hello") + rpc_client.request.assert_not_called() + existing._mark_disconnected() + + class TestURLParsing: def test_parse_port_only_url(self): client = CopilotClient(ExternalServerConfig(url="8080")) @@ -338,6 +533,7 @@ async def mock_request(method, params): def grep(params) -> str: return "ok" + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -573,6 +769,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -624,6 +821,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -691,6 +889,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -789,6 +988,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -864,6 +1064,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -894,6 +1095,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -923,6 +1125,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -952,6 +1155,7 @@ async def mock_request(method, params): return await original_request(method, params) client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/test_commands_and_elicitation.py b/python/test_commands_and_elicitation.py index 470e2f8f3..2897a0e3e 100644 --- a/python/test_commands_and_elicitation.py +++ b/python/test_commands_and_elicitation.py @@ -107,6 +107,7 @@ async def mock_request(method, params): client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, @@ -525,6 +526,7 @@ async def mock_request(method, params): client._client.request = mock_request + session._mark_disconnected() await client.resume_session( session.session_id, on_permission_request=PermissionHandler.approve_all, diff --git a/python/test_rpc_generated.py b/python/test_rpc_generated.py index 5f484add0..35d049b9a 100644 --- a/python/test_rpc_generated.py +++ b/python/test_rpc_generated.py @@ -7,6 +7,8 @@ from copilot.generated.rpc import ( CommandsApi, CommandsInvokeRequest, + ServerModelsApi, + ServerToolsApi, SlashCommandInvocationResultKind, ) @@ -22,3 +24,41 @@ async def test_commands_invoke_deserializes_slash_command_result(): assert result.kind is SlashCommandInvocationResultKind.TEXT assert result.text == "hello" assert result.markdown is True + + +@pytest.mark.asyncio +async def test_generated_rpc_rejects_missing_required_params(): + client = AsyncMock() + api = CommandsApi(client, "sess-1") + + with pytest.raises(TypeError, match="params is required"): + await api.invoke(None) # type: ignore[arg-type] + + client.request.assert_not_called() + + +@pytest.mark.asyncio +async def test_generated_rpc_allows_missing_optional_only_params(): + client = AsyncMock() + client.request = AsyncMock(side_effect=[{"models": []}, {"tools": []}]) + + await ServerModelsApi(client).list() + await ServerToolsApi(client).list() + + assert client.request.call_args_list[0].args[:2] == ("models.list", {}) + assert client.request.call_args_list[1].args[:2] == ("tools.list", {}) + + +@pytest.mark.asyncio +async def test_generated_session_rpc_checks_active_callback_before_request(): + client = AsyncMock() + api = CommandsApi( + client, + "sess-1", + lambda: (_ for _ in ()).throw(RuntimeError("session inactive")), + ) + + with pytest.raises(RuntimeError, match="session inactive"): + await api.invoke(CommandsInvokeRequest(name="help")) + + client.request.assert_not_called() diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 03a8da8e8..8f4266eff 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -3720,7 +3720,10 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio // Emit the common service struct (unexported, shared by all API groups via type cast) lines.push(`type ${serviceName} struct {`); lines.push(`\tclient *jsonrpc2.Client`); - if (isSession) lines.push(`\tsessionID string`); + if (isSession) { + lines.push(`\tsessionID string`); + lines.push(`\tassertActive func() error`); + } lines.push(`}`); lines.push(``); @@ -3764,11 +3767,15 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio } // Constructor - const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string" : "client *jsonrpc2.Client"; + const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string, assertActive ...func() error" : "client *jsonrpc2.Client"; lines.push(`func New${wrapperName}(${ctorParams}) *${wrapperName} {`); lines.push(`\tr := &${wrapperName}{}`); if (isSession) { - lines.push(`\tr.common = ${serviceName}{client: client, sessionID: sessionID}`); + lines.push(`\tvar assertActiveFn func() error`); + lines.push(`\tif len(assertActive) > 0 {`); + lines.push(`\t\tassertActiveFn = assertActive[0]`); + lines.push(`\t}`); + lines.push(`\tr.common = ${serviceName}{client: client, sessionID: sessionID, assertActive: assertActiveFn}`); } else { lines.push(`\tr.common = ${serviceName}{client: client}`); } @@ -3805,6 +3812,7 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc // For wrapper-level methods, access fields through a.common; for service type aliases, use a directly const clientRef = isWrapper ? "a.common.client" : "a.client"; const sessionIDRef = isWrapper ? "a.common.sessionID" : "a.sessionID"; + const assertActiveRef = isWrapper ? "a.common.assertActive" : "a.assertActive"; pushGoRpcMethodComment( lines, @@ -3834,6 +3842,18 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc lines.push(`\t\trequestParams = params[0]`); lines.push(`\t}`); } + if (isSession) { + lines.push(`\tif ${assertActiveRef} != nil {`); + lines.push(`\t\tif err := ${assertActiveRef}(); err != nil {`); + lines.push(`\t\t\treturn nil, err`); + lines.push(`\t\t}`); + lines.push(`\t}`); + } + if (hasParams && !paramsAreOptional && hasRequiredNonSessionParams) { + lines.push(`\tif ${paramsRef} == nil {`); + lines.push(`\t\treturn nil, errors.New("params is required")`); + lines.push(`\t}`); + } if (isSession) { lines.push(`\treq := map[string]any{"sessionId": ${sessionIDRef}}`); diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 52b11ed59..f80a408be 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -2514,12 +2514,13 @@ function emitPyApiGroup( } lines.push(`class ${apiName}:`); if (isSession) { - lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); + lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None):`); lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); + lines.push(` self._assert_active = assert_active`); for (const [subGroupName] of subGroups) { const subApiName = apiName.replace(/Api$/, "") + toPascalCase(subGroupName) + "Api"; - lines.push(` self.${toSnakeCase(subGroupName)} = ${subApiName}(client, session_id)`); + lines.push(` self.${toSnakeCase(subGroupName)} = ${subApiName}(client, session_id, assert_active)`); } } else { lines.push(` def __init__(self, client: "JsonRpcClient"):`); @@ -2559,11 +2560,12 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(classPrefix === "_Internal" ? ` """Internal SDK session-scoped RPC methods. Not part of the public API."""` : ` """Typed session-scoped RPC methods."""`); - lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); + lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, assert_active: Callable[[], None] | None = None):`); lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); + lines.push(` self._assert_active = assert_active`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = ${classPrefix}${toPascalCase(groupName)}Api(client, session_id)`); + lines.push(` self.${toSnakeCase(groupName)} = ${classPrefix}${toPascalCase(groupName)}Api(client, session_id, assert_active)`); } } else { lines.push(`class ${wrapperName}:`); @@ -2608,9 +2610,11 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: const effectiveParams = getMethodParamsSchema(method); const paramProps = effectiveParams?.properties || {}; const nonSessionParams = Object.keys(paramProps).filter((k) => k !== "sessionId"); + const requiredParams = new Set(effectiveParams?.required || []); const hasParams = isSession ? nonSessionParams.length > 0 : hasSchemaPayload(effectiveParams); const paramsType = resolveType(pythonParamsTypeName(method)); - const paramsOptional = isParamsOptional(method); + const hasRequiredNonSessionParams = nonSessionParams.some((name) => requiredParams.has(name)); + const paramsOptional = isParamsOptional(method) || !hasRequiredNonSessionParams; // Build signature with typed params + optional timeout const sig = hasParams @@ -2630,6 +2634,18 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: internal: method.visibility === "internal", }); + if (isSession) { + lines.push(` if self._assert_active is not None:`); + lines.push(` self._assert_active()`); + } + if (hasParams && !paramsOptional) { + lines.push(` if params is None:`); + lines.push(` raise TypeError("params is required")`); + } + if (isSession || (hasParams && !paramsOptional)) { + lines.push(``); + } + // Deserialize helper const innerTypeName = hasNullableResult ? resolveType(pythonResultTypeName(method, nullableInner)) : resultType; const deserialize = (expr: string) => { diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index 3afaec395..104d00430 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -656,7 +656,7 @@ function hasInternalMethods(node: Record): boolean { if (schema.session) { lines.push(`/** Create typed session-scoped RPC methods. */`); - lines.push(`export function createSessionRpc(connection: MessageConnection, sessionId: string) {`); + lines.push(`export function createSessionRpc(connection: MessageConnection, sessionId: string, assertActive?: () => void) {`); lines.push(` return {`); lines.push(...emitGroup(schema.session, " ", true, false, false, "public")); lines.push(` };`); @@ -669,7 +669,7 @@ function hasInternalMethods(node: Record): boolean { lines.push(` * surface. Not exported on the public client API.`); lines.push(` * @internal`); lines.push(` */`); - lines.push(`export function createInternalSessionRpc(connection: MessageConnection, sessionId: string) {`); + lines.push(`export function createInternalSessionRpc(connection: MessageConnection, sessionId: string, assertActive?: () => void) {`); lines.push(` return {`); lines.push(...emitGroup(schema.session, " ", true, false, false, "internal")); lines.push(` };`); @@ -709,15 +709,18 @@ function emitGroup( const paramEntries = effectiveParams?.properties ? Object.entries(effectiveParams.properties).filter(([k]) => k !== "sessionId") : []; + const requiredParams = new Set(effectiveParams?.required ?? []); const hasParams = hasSchemaPayload(effectiveParams); const hasNonSessionParams = paramEntries.length > 0; + const hasRequiredNonSessionParams = paramEntries.some(([name]) => requiredParams.has(name)); + const paramsOptional = isParamsOptional(value) || !hasRequiredNonSessionParams; const sigParams: string[] = []; let bodyArg: string; if (isSession) { if (hasNonSessionParams) { - const optMark = isParamsOptional(value) ? "?" : ""; + const optMark = paramsOptional ? "?" : ""; // sessionId is already stripped from the generated type definition, // so no need for Omit<..., "sessionId"> sigParams.push(`params${optMark}: ${paramsType}`); @@ -727,9 +730,9 @@ function emitGroup( } } else { if (hasParams) { - const optMark = isParamsOptional(value) ? "?" : ""; + const optMark = paramsOptional ? "?" : ""; sigParams.push(`params${optMark}: ${paramsType}`); - bodyArg = "params"; + bodyArg = paramsOptional ? "(params ?? {})" : "params"; } else { bodyArg = "{}"; } @@ -741,8 +744,17 @@ function emitGroup( includeDeprecated: (value as RpcMethod).deprecated && !parentDeprecated, includeExperimental: (value as RpcMethod).stability === "experimental" && !parentExperimental, }); - lines.push(`${indent}${key}: async (${sigParams.join(", ")}): Promise<${resultType}> =>`); - lines.push(`${indent} connection.sendRequest("${rpcMethod}", ${bodyArg}),`); + lines.push(`${indent}${key}: async (${sigParams.join(", ")}): Promise<${resultType}> => {`); + if (isSession) { + lines.push(`${indent} assertActive?.();`); + } + if (sigParams.length > 0 && !paramsOptional) { + lines.push(`${indent} if (params == null) {`); + lines.push(`${indent} throw new TypeError("params is required");`); + lines.push(`${indent} }`); + } + lines.push(`${indent} return connection.sendRequest("${rpcMethod}", ${bodyArg});`); + lines.push(`${indent}},`); } else if (typeof value === "object" && value !== null) { const groupExperimental = isNodeFullyExperimental(value as Record); const groupDeprecated = isNodeFullyDeprecated(value as Record);