diff --git a/go/README.md b/go/README.md index bbed46f0f..6c304d075 100644 --- a/go/README.md +++ b/go/README.md @@ -110,6 +110,8 @@ That's it! When your application calls `copilot.NewClient` without a `CLIPath` n - `SetForegroundSessionID(ctx context.Context, sessionID string) error` - Request TUI to display a specific session (TUI+server mode only) - `On(handler SessionLifecycleHandler) func()` - Subscribe to all lifecycle events; returns unsubscribe function - `OnEventType(eventType SessionLifecycleEventType, handler SessionLifecycleHandler) func()` - Subscribe to specific lifecycle event type +- `CreateCloudSession(options *CloudSessionOptions) (*CloudSession, error)` - Create a sandbox-backed cloud session through Mission Control +- `ConnectCloudSession(taskOrSessionID string, options *CloudConnectOptions) (*CloudSession, error)` - Attach to an existing Mission Control cloud task **Session Lifecycle Events:** @@ -183,6 +185,40 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `UI() *SessionUI` - Interactive UI API for elicitation dialogs - `Capabilities() SessionCapabilities` - Host capabilities (e.g. elicitation support) +### CloudSession + +A `CloudSession` is a remote-control handle to a cloud sandbox running through Mission Control. It does not spawn a local CLI process. + +- `Connect() error` - Fetch initial events and start background polling (called internally by `CreateCloudSession`/`ConnectCloudSession`) +- `On(handler CloudSessionEventHandler) func()` - Subscribe to events (returns unsubscribe function) +- `Send(options MessageOptions) error` - Send a user message through the steer API +- `SendAndWait(options MessageOptions, timeout time.Duration) (*CloudSessionEvent, error)` - Send and wait for session.idle +- `Abort() error` - Abort the currently running agent work +- `SubmitRemoteCommand(cmdType MissionControlCommandType, content string) error` - Send a raw steering command +- `RespondToPermission(payload CloudPermissionResponsePayload) error` - Respond to a permission request +- `RespondToAskUser(payload CloudAskUserResponsePayload) error` - Respond to an ask-user request +- `RespondToElicitation(payload CloudElicitationResponsePayload) error` - Respond to an elicitation +- `RespondToPlanApproval(payload CloudPlanApprovalResponsePayload) error` - Respond to a plan approval +- `SwitchMode(payload CloudModeSwitchPayload) error` - Switch the session mode +- `GetMessages() []CloudSessionEvent` - Get all received events +- `Disconnect()` - Stop polling and release resources + +**CloudSessionOptions:** + +- `Owner` (string): Billing/authorization owner. Required when `Repository` is omitted. +- `Repository` (*CloudRepository): Repository context (`Owner`, `Name`, optional `Branch`). +- `MissionControlBaseURL` (string): Override Mission Control base URL (default: `$COPILOT_MC_BASE_URL` or `$COPILOT_API_BASE_URL/agents`). +- `FrontendBaseURL` (string): Override GitHub frontend URL (default: `$COPILOT_MC_FRONTEND_URL` or `https://github.com`). +- `AuthToken` (string): Override auth token (default: `$COPILOT_MC_ACCESS_TOKEN` or `GitHubToken`). +- `IntegrationID` (string): Override `Copilot-Integration-Id` header (default: `copilot-cli`). +- `PollIntervalMs` (int): Event poll interval in ms (default: 5000). +- `InitialEventTimeoutMs` (*int): How long to wait for the first event in ms (default: 10000). Use `Int(0)` to skip waiting. +- `OnProgress` (func(CloudProgressEvent)): Progress callback. +- `OnCloudTaskCreated` (func(MissionControlTask)): Called after task creation. +- `OnEventPollError` (func(error)): Called on poll failures. + +**Steering command types:** `CommandUserMessage`, `CommandAskUserResponse`, `CommandPlanApproval`, `CommandPermissionResponse`, `CommandElicitation`, `CommandAbort`, `CommandModeSwitch` + ### Helper Functions - `Bool(v bool) *bool` - Helper to create bool pointers for `AutoStart` option @@ -856,6 +892,78 @@ Communicates with CLI via TCP socket. Useful for distributed scenarios. ## Environment Variables - `COPILOT_CLI_PATH` - Path to the Copilot CLI executable +- `COPILOT_MC_BASE_URL` - Mission Control API base URL (cloud sessions) +- `COPILOT_API_BASE_URL` - Copilot API base URL (cloud sessions fallback) +- `COPILOT_MC_FRONTEND_URL` - GitHub frontend base URL (cloud sessions) +- `COPILOT_MC_ACCESS_TOKEN` - Mission Control auth token (cloud sessions) + +## Cloud Sessions + +Cloud sessions run agents inside provisioned cloud sandboxes managed by +Mission Control. The SDK acts as a remote-control client — it creates or +attaches to a cloud task, polls for events, and steers the agent. + +### Create a cloud session (with repository) + +```go +client := copilot.NewClient(&copilot.ClientOptions{ + AutoStart: copilot.Bool(false), + GitHubToken: "ghp_...", +}) + +session, err := client.CreateCloudSession(&copilot.CloudSessionOptions{ + Repository: &copilot.CloudRepository{ + Owner: "github", + Name: "copilot-sdk", + }, +}) +if err != nil { + log.Fatal(err) +} +defer session.Disconnect() + +session.On(func(event copilot.CloudSessionEvent) { + fmt.Println(event.Type) +}) + +session.Send(copilot.MessageOptions{Prompt: "Hello cloud!"}) +``` + +### Create a repo-less cloud session + +```go +session, err := client.CreateCloudSession(&copilot.CloudSessionOptions{ + Owner: "github", +}) +``` + +### Attach to an existing cloud task + +```go +session, err := client.ConnectCloudSession("task-id", nil) +``` + +### Steering commands + +```go +// Send a message +session.Send(copilot.MessageOptions{Prompt: "Fix the bug"}) + +// Abort current work +session.Abort() + +// Respond to permission / ask-user / elicitation / plan approval +session.RespondToPermission(copilot.CloudPermissionResponsePayload{ + PromptID: "p1", Approved: true, Scope: "once", +}) +session.RespondToAskUser(copilot.CloudAskUserResponsePayload{ + PromptID: "p2", Answer: "yes", WasFreeform: false, +}) +session.SwitchMode(copilot.CloudModeSwitchPayload{Mode: "autopilot"}) + +// Raw steering +session.SubmitRemoteCommand(copilot.CommandUserMessage, "raw content") +``` ## License diff --git a/go/cloud_client.go b/go/cloud_client.go new file mode 100644 index 000000000..5e6643f2f --- /dev/null +++ b/go/cloud_client.go @@ -0,0 +1,270 @@ +package copilot + +import ( + "errors" + "os" + "strings" + "time" +) + +// CreateCloudSession creates a sandbox-backed cloud session through Mission +// Control and attaches to it as a remote-control client. +// +// This does not create a local runtime session. The agent runs inside the +// provisioned cloud sandbox; this SDK instance polls Mission Control for +// events and sends user actions through the task steer API. +// +// Either options.Repository or options.Owner must be provided; when Repository +// is omitted, Owner is required for billing and authorization. +// +// Example: +// +// session, err := client.CreateCloudSession(&copilot.CloudSessionOptions{ +// Repository: &copilot.CloudRepository{Owner: "github", Name: "copilot-sdk"}, +// }) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Disconnect() +// +// session.On(func(event copilot.CloudSessionEvent) { +// fmt.Println(event.Type) +// }) +// session.Send(copilot.MessageOptions{Prompt: "Hello from the cloud!"}) +func (c *Client) CreateCloudSession(options *CloudSessionOptions) (*CloudSession, error) { + if options == nil { + options = &CloudSessionOptions{} + } + startedAt := time.Now() + + mcClient := c.buildMissionControlClient( + options.MissionControlBaseURL, + options.CopilotAPIBaseURL, + options.FrontendBaseURL, + options.AuthToken, + options.IntegrationID, + ) + + owner := strings.TrimSpace(options.Owner) + repo := options.Repository + + if repo == nil && owner == "" { + return nil, errors.New("CloudSessionOptions.Owner is required when Repository is omitted") + } + + if options.OnProgress != nil { + options.OnProgress(CloudProgressEvent{Phase: CloudProgressCreatingTask, ElapsedMs: 0}) + options.OnProgress(CloudProgressEvent{ + Phase: CloudProgressProvisioningSandbox, + ElapsedMs: time.Since(startedAt).Milliseconds(), + }) + } + + var taskRepo *CloudRepository + if repo != nil { + taskRepo = &CloudRepository{Owner: repo.Owner, Name: repo.Name} + } + task, err := mcClient.createCloudTask(owner, taskRepo) + if err != nil { + return nil, err + } + if options.OnCloudTaskCreated != nil { + options.OnCloudTaskCreated(*task) + } + + if options.OnProgress != nil { + options.OnProgress(CloudProgressEvent{ + Phase: CloudProgressWaitingForSession, + ElapsedMs: time.Since(startedAt).Milliseconds(), + TaskID: task.ID, + }) + } + + session := newCloudSession(cloudSessionConfig{ + client: mcClient, + metadata: c.buildCloudSessionMetadata(task, mcClient, repo, owner), + pollIntervalMs: options.PollIntervalMs, + initialEventTimeoutMs: options.InitialEventTimeoutMs, + initialEventPollIntervalMs: options.InitialEventPollIntervalMs, + onEventPollError: options.OnEventPollError, + }) + + if err := session.Connect(); err != nil { + return nil, err + } + + if options.OnProgress != nil { + options.OnProgress(CloudProgressEvent{ + Phase: CloudProgressConnected, + ElapsedMs: time.Since(startedAt).Milliseconds(), + TaskID: task.ID, + }) + } + + return session, nil +} + +// ConnectCloudSession attaches to an existing Mission Control cloud task as a +// remote-control client. +// +// The taskOrSessionID is treated as a Mission Control task ID. If Mission +// Control returns task metadata, it is used to populate the session metadata; +// otherwise the SDK still attaches by polling task events for the provided ID. +// +// Example: +// +// session, err := client.ConnectCloudSession("task-1", nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Disconnect() +func (c *Client) ConnectCloudSession(taskOrSessionID string, options *CloudConnectOptions) (*CloudSession, error) { + if options == nil { + options = &CloudConnectOptions{} + } + startedAt := time.Now() + + mcClient := c.buildMissionControlClient( + options.MissionControlBaseURL, + options.CopilotAPIBaseURL, + options.FrontendBaseURL, + options.AuthToken, + options.IntegrationID, + ) + + if options.OnProgress != nil { + options.OnProgress(CloudProgressEvent{ + Phase: CloudProgressWaitingForSession, + ElapsedMs: 0, + TaskID: taskOrSessionID, + }) + } + + owner := strings.TrimSpace(options.Owner) + task, err := mcClient.getTask(taskOrSessionID) + if err != nil { + return nil, err + } + + var metadata CloudSessionMetadata + if task != nil { + metadata = c.buildCloudSessionMetadata(task, mcClient, options.Repository, owner) + } else { + metadata = c.buildFallbackCloudSessionMetadata(taskOrSessionID, mcClient, options.Repository, owner) + } + + session := newCloudSession(cloudSessionConfig{ + client: mcClient, + metadata: metadata, + pollIntervalMs: options.PollIntervalMs, + initialEventTimeoutMs: options.InitialEventTimeoutMs, + initialEventPollIntervalMs: options.InitialEventPollIntervalMs, + onEventPollError: options.OnEventPollError, + }) + + if err := session.Connect(); err != nil { + return nil, err + } + + if options.OnProgress != nil { + options.OnProgress(CloudProgressEvent{ + Phase: CloudProgressConnected, + ElapsedMs: time.Since(startedAt).Milliseconds(), + TaskID: metadata.TaskID, + }) + } + + return session, nil +} + +// ── helpers ───────────────────────────────────────────────────────────── + +func (c *Client) buildMissionControlClient( + mcBaseURL, copilotAPIBaseURL, frontendBaseURL, authToken, integrationID string, +) *missionControlClient { + copilotAPI := firstNonEmpty( + copilotAPIBaseURL, + os.Getenv("COPILOT_API_BASE_URL"), + os.Getenv("COPILOT_API_URL"), + "https://api.githubcopilot.com", + ) + copilotAPI = strings.TrimRight(copilotAPI, "/") + + baseURL := firstNonEmpty( + mcBaseURL, + os.Getenv("COPILOT_MC_BASE_URL"), + copilotAPI+"/agents", + ) + + token := firstNonEmpty( + strings.TrimSpace(authToken), + strings.TrimSpace(os.Getenv("COPILOT_MC_ACCESS_TOKEN")), + c.options.GitHubToken, + ) + + frontend := firstNonEmpty( + frontendBaseURL, + os.Getenv("COPILOT_MC_FRONTEND_URL"), + "https://github.com", + ) + + return newMissionControlClient(missionControlClientConfig{ + BaseURL: baseURL, + AuthToken: token, + IntegrationID: integrationID, + FrontendBaseURL: frontend, + }) +} + +func (c *Client) buildCloudSessionMetadata( + task *MissionControlTask, + mcClient *missionControlClient, + repository *CloudRepository, + owner string, +) CloudSessionMetadata { + var mcSessionID string + if len(task.Sessions) > 0 { + mcSessionID = task.Sessions[len(task.Sessions)-1].ID + } + + createdAt, _ := time.Parse(time.RFC3339, task.CreatedAt) + updatedAt, _ := time.Parse(time.RFC3339, task.UpdatedAt) + + return CloudSessionMetadata{ + TaskID: task.ID, + MissionControlSessionID: mcSessionID, + FrontendURL: mcClient.getFrontendURL(task.ID), + Owner: owner, + Repository: repository, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + State: task.State, + Status: task.Status, + } +} + +func (c *Client) buildFallbackCloudSessionMetadata( + taskID string, + mcClient *missionControlClient, + repository *CloudRepository, + owner string, +) CloudSessionMetadata { + now := time.Now() + return CloudSessionMetadata{ + TaskID: taskID, + FrontendURL: mcClient.getFrontendURL(taskID), + Owner: owner, + Repository: repository, + CreatedAt: now, + UpdatedAt: now, + } +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if v != "" { + return v + } + } + return "" +} diff --git a/go/cloud_mission_control_client.go b/go/cloud_mission_control_client.go new file mode 100644 index 000000000..511391d17 --- /dev/null +++ b/go/cloud_mission_control_client.go @@ -0,0 +1,316 @@ +package copilot + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// CloudSandboxAgentSlug is the agent slug sent when creating cloud sandbox tasks. +const CloudSandboxAgentSlug = "copilot-developer-sandbox" + +const ( + defaultRequestTimeoutMs = 10_000 + defaultCreateCloudTaskTimeoutMs = 10 * 60 * 1000 +) + +// missionControlClientConfig holds the options for constructing a missionControlClient. +type missionControlClientConfig struct { + BaseURL string + AuthToken string + IntegrationID string + FrontendBaseURL string + RequestTimeoutMs int + CreateCloudTaskTimeoutMs int +} + +// missionControlClient talks to the Mission Control HTTP API. +type missionControlClient struct { + baseURL string + authToken string + integrationID string + frontendBaseURL string + requestTimeout time.Duration + createCloudTaskTimeout time.Duration + httpClient *http.Client +} + +func newMissionControlClient(cfg missionControlClientConfig) *missionControlClient { + baseURL := strings.TrimRight(cfg.BaseURL, "/") + authToken := strings.TrimSpace(cfg.AuthToken) + integrationID := cfg.IntegrationID + if integrationID == "" { + integrationID = "copilot-cli" + } + frontendBaseURL := strings.TrimRight(cfg.FrontendBaseURL, "/") + + reqTimeout := time.Duration(cfg.RequestTimeoutMs) * time.Millisecond + if reqTimeout <= 0 { + reqTimeout = time.Duration(defaultRequestTimeoutMs) * time.Millisecond + } + createTimeout := time.Duration(cfg.CreateCloudTaskTimeoutMs) * time.Millisecond + if createTimeout <= 0 { + createTimeout = time.Duration(defaultCreateCloudTaskTimeoutMs) * time.Millisecond + } + + return &missionControlClient{ + baseURL: baseURL, + authToken: authToken, + integrationID: integrationID, + frontendBaseURL: frontendBaseURL, + requestTimeout: reqTimeout, + createCloudTaskTimeout: createTimeout, + httpClient: &http.Client{}, + } +} + +func (mc *missionControlClient) createCloudTask(owner string, repo *CloudRepository) (*MissionControlTask, error) { + body := make(map[string]any) + if owner != "" { + body["owner"] = owner + } + if repo != nil { + body["repositories"] = []map[string]string{ + {"owner": repo.Owner, "name": repo.Name}, + } + } + + var task MissionControlTask + if err := mc.requestJSON( + http.MethodPost, + mc.baseURL+"/tasks", + body, + mc.createCloudTaskTimeout, + map[string]string{"X-Copilot-Agent-Slug": CloudSandboxAgentSlug}, + &task, + ); err != nil { + return nil, err + } + return &task, nil +} + +func (mc *missionControlClient) listTaskEvents(taskID string) ([]CloudSessionEvent, error) { + var wrapper struct { + Events []CloudSessionEvent `json:"events"` + } + if err := mc.requestJSON( + http.MethodGet, + mc.baseURL+"/tasks/"+url.PathEscape(taskID)+"/events", + nil, + mc.requestTimeout, + nil, + &wrapper, + ); err != nil { + return nil, err + } + // Filter events that have valid id, timestamp, and type fields. + valid := make([]CloudSessionEvent, 0, len(wrapper.Events)) + for _, e := range wrapper.Events { + if e.ID != "" && e.Timestamp != "" && e.Type != "" { + valid = append(valid, e) + } + } + return valid, nil +} + +// steerRequest is the JSON body for the steer API. +type steerRequest struct { + Type MissionControlCommandType `json:"type"` + Content string `json:"content,omitempty"` +} + +func (mc *missionControlClient) steerTask(taskID string, cmdType MissionControlCommandType, content string) error { + body := steerRequest{Type: cmdType} + if content != "" { + body.Content = content + } + return mc.requestOKOnly( + http.MethodPost, + mc.baseURL+"/tasks/"+url.PathEscape(taskID)+"/steer", + body, + mc.requestTimeout, + nil, + ) +} + +func (mc *missionControlClient) getTask(taskID string) (*MissionControlTask, error) { + var task MissionControlTask + if err := mc.requestJSON( + http.MethodGet, + mc.baseURL+"/tasks/"+url.PathEscape(taskID), + nil, + mc.requestTimeout, + nil, + &task, + ); err != nil { + var csErr *CloudSessionError + if ok := errorAs(err, &csErr); ok && csErr.Status == 404 { + return nil, nil + } + return nil, err + } + return &task, nil +} + +func (mc *missionControlClient) getFrontendURL(taskID string) string { + return mc.frontendBaseURL + "/copilot/tasks/" + url.PathEscape(taskID) +} + +// ── HTTP helpers ──────────────────────────────────────────────────────── + +func (mc *missionControlClient) headers(extra map[string]string) map[string]string { + h := map[string]string{ + "Content-Type": "application/json", + "Copilot-Integration-Id": mc.integrationID, + } + if mc.authToken != "" { + h["Authorization"] = "Bearer " + mc.authToken + } + for k, v := range extra { + h[k] = v + } + return h +} + +func (mc *missionControlClient) requestJSON(method, reqURL string, reqBody any, timeout time.Duration, extra map[string]string, out any) error { + resp, err := mc.doRequest(method, reqURL, reqBody, timeout, extra) + if err != nil { + return err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return &CloudSessionError{Message: fmt.Sprintf("failed to read response body: %s", err), Reason: CloudFailureServer} + } + if len(data) == 0 { + return nil + } + if err := json.Unmarshal(data, out); err != nil { + return &CloudSessionError{ + Message: fmt.Sprintf("Mission Control returned invalid JSON: %s", err), + Reason: CloudFailureServer, + } + } + return nil +} + +func (mc *missionControlClient) requestOKOnly(method, reqURL string, reqBody any, timeout time.Duration, extra map[string]string) error { + resp, err := mc.doRequest(method, reqURL, reqBody, timeout, extra) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (mc *missionControlClient) doRequest(method, reqURL string, reqBody any, timeout time.Duration, extra map[string]string) (*http.Response, error) { + var bodyReader io.Reader + if reqBody != nil { + data, err := json.Marshal(reqBody) + if err != nil { + return nil, &CloudSessionError{Message: fmt.Sprintf("failed to marshal request: %s", err), Reason: CloudFailureServer} + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequest(method, reqURL, bodyReader) + if err != nil { + return nil, &CloudSessionError{Message: fmt.Sprintf("failed to create request: %s", err), Reason: CloudFailureNetwork} + } + for k, v := range mc.headers(extra) { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: timeout} + resp, err := client.Do(req) + if err != nil { + if isTimeoutError(err) { + return nil, &CloudSessionError{Message: "Mission Control request timed out", Reason: CloudFailureTimeout} + } + return nil, &CloudSessionError{ + Message: fmt.Sprintf("Mission Control request failed: %s", err), + Reason: CloudFailureNetwork, + } + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + msg := extractMissionControlMessage(body) + if msg == "" { + msg = fmt.Sprintf("Mission Control request failed with HTTP %d", resp.StatusCode) + } + return nil, &CloudSessionError{ + Message: msg, + Reason: reasonForStatus(resp.StatusCode), + Status: resp.StatusCode, + } + } + + return resp, nil +} + +func reasonForStatus(status int) CloudSessionFailureReason { + if status == 403 { + return CloudFailurePolicyBlocked + } + if status == 400 || status == 422 { + return CloudFailureValidation + } + return CloudFailureServer +} + +func extractMissionControlMessage(body []byte) string { + if len(body) == 0 { + return "" + } + var parsed struct { + Message string `json:"message"` + } + if err := json.Unmarshal(body, &parsed); err == nil && parsed.Message != "" { + return parsed.Message + } + return string(body) +} + +func isTimeoutError(err error) bool { + // net/http wraps timeout errors with a *url.Error whose Timeout() returns true. + type timeouter interface{ Timeout() bool } + if te, ok := err.(timeouter); ok { + return te.Timeout() + } + return false +} + +// errorAs is a thin wrapper around errors.As to support type-parameterized targets. +func errorAs(err error, target any) bool { + type asIface interface { + As(any) bool + } + switch t := target.(type) { + case **CloudSessionError: + for err != nil { + if csErr, ok := err.(*CloudSessionError); ok { + *t = csErr + return true + } + u, ok := err.(interface{ Unwrap() error }) + if !ok { + return false + } + err = u.Unwrap() + } + return false + default: + if a, ok := err.(asIface); ok { + return a.As(target) + } + return false + } +} diff --git a/go/cloud_session.go b/go/cloud_session.go new file mode 100644 index 000000000..b7dfd1df5 --- /dev/null +++ b/go/cloud_session.go @@ -0,0 +1,474 @@ +package copilot + +import ( + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + "time" +) + +const ( + defaultPollIntervalMs = 5_000 + defaultInitialEventTimeoutMs = 10_000 + defaultInitialEventPollIntervalMs = 500 +) + +// cloudSessionConfig is the internal configuration for creating a CloudSession. +type cloudSessionConfig struct { + client *missionControlClient + metadata CloudSessionMetadata + pollIntervalMs int + initialEventTimeoutMs *int // nil = use default + initialEventPollIntervalMs int + onEventPollError func(error) +} + +// CloudSession represents a remote cloud sandbox session controlled through +// the Mission Control API. It polls for task events, dispatches them to +// registered handlers, and provides methods for steering the remote agent. +// +// A CloudSession does not run a local CLI process. All agent work happens in +// the cloud sandbox; this object is the remote-control client. +type CloudSession struct { + client *missionControlClient + pollInterval time.Duration + initialEventTimeout time.Duration + initialEventPollInterval time.Duration + onEventPollError func(error) + + mu sync.Mutex + handlers []cloudHandler + nextHandlerID uint64 + events []CloudSessionEvent + seenEventIDs map[string]bool + seenIDsAtLastTimestamp map[string]bool + lastSeenTimestamp string + stopPolling chan struct{} + isPolling bool + isDisconnected bool + remoteSteerable bool + + // SessionID is the Mission Control session identifier. + SessionID string + + // Metadata holds the full metadata for this cloud session. + Metadata CloudSessionMetadata +} + +type cloudHandler struct { + id uint64 + fn CloudSessionEventHandler +} + +func newCloudSession(cfg cloudSessionConfig) *CloudSession { + pollMs := cfg.pollIntervalMs + if pollMs <= 0 { + pollMs = defaultPollIntervalMs + } + initTimeoutMs := defaultInitialEventTimeoutMs + if cfg.initialEventTimeoutMs != nil { + initTimeoutMs = *cfg.initialEventTimeoutMs + if initTimeoutMs < 0 { + initTimeoutMs = 0 + } + } + initPollMs := cfg.initialEventPollIntervalMs + if initPollMs <= 0 { + initPollMs = defaultInitialEventPollIntervalMs + } + + sessionID := cfg.metadata.MissionControlSessionID + if sessionID == "" { + sessionID = cfg.metadata.TaskID + } + + return &CloudSession{ + client: cfg.client, + pollInterval: time.Duration(pollMs) * time.Millisecond, + initialEventTimeout: time.Duration(initTimeoutMs) * time.Millisecond, + initialEventPollInterval: time.Duration(initPollMs) * time.Millisecond, + onEventPollError: cfg.onEventPollError, + handlers: make([]cloudHandler, 0), + events: make([]CloudSessionEvent, 0), + seenEventIDs: make(map[string]bool), + seenIDsAtLastTimestamp: make(map[string]bool), + remoteSteerable: true, + SessionID: sessionID, + Metadata: cfg.metadata, + } +} + +// Connect fetches initial events and starts background polling. +// It must be called before sending messages or subscribing to events. +func (cs *CloudSession) Connect() error { + initial, err := cs.waitForInitialEvents() + if err != nil { + return err + } + cs.recordEvents(initial) + cs.startEventPolling() + return nil +} + +// On registers a handler that is called for every cloud session event. +// Returns an unsubscribe function. +func (cs *CloudSession) On(handler CloudSessionEventHandler) func() { + cs.mu.Lock() + defer cs.mu.Unlock() + + id := cs.nextHandlerID + cs.nextHandlerID++ + cs.handlers = append(cs.handlers, cloudHandler{id: id, fn: handler}) + + return func() { + cs.mu.Lock() + defer cs.mu.Unlock() + for i, h := range cs.handlers { + if h.id == id { + cs.handlers = append(cs.handlers[:i], cs.handlers[i+1:]...) + break + } + } + } +} + +// Send sends a user message to the cloud session through the steer API. +func (cs *CloudSession) Send(options MessageOptions) error { + cs.mu.Lock() + if cs.isDisconnected { + cs.mu.Unlock() + return errors.New("cloud session is disconnected") + } + cs.mu.Unlock() + return cs.SubmitRemoteCommand(CommandUserMessage, options.Prompt) +} + +// SendAndWait sends a user message and blocks until the session becomes idle +// or the timeout elapses. Returns the last assistant.message event received, +// or nil if none. The default timeout is 60 seconds. +func (cs *CloudSession) SendAndWait(options MessageOptions, timeout time.Duration) (*CloudSessionEvent, error) { + if timeout <= 0 { + timeout = 60 * time.Second + } + + var lastAssistant *CloudSessionEvent + var mu sync.Mutex + idleCh := make(chan struct{}, 1) + errCh := make(chan error, 1) + + unsubscribe := cs.On(func(event CloudSessionEvent) { + switch event.Type { + case "assistant.message": + mu.Lock() + copied := event + lastAssistant = &copied + mu.Unlock() + case "session.idle": + select { + case idleCh <- struct{}{}: + default: + } + case "session.error": + var msg string + var data struct{ Message string } + if err := json.Unmarshal(event.Data, &data); err == nil { + msg = data.Message + } + select { + case errCh <- fmt.Errorf("session error: %s", msg): + default: + } + } + }) + defer unsubscribe() + + if err := cs.Send(options); err != nil { + return nil, err + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-idleCh: + mu.Lock() + result := lastAssistant + mu.Unlock() + return result, nil + case err := <-errCh: + return nil, err + case <-timer.C: + return nil, fmt.Errorf("timeout after %s waiting for session.idle", timeout) + } +} + +// Abort aborts the currently running agent work. +func (cs *CloudSession) Abort() error { + cs.mu.Lock() + if cs.isDisconnected { + cs.mu.Unlock() + return errors.New("cloud session is disconnected") + } + cs.mu.Unlock() + return cs.SubmitRemoteCommand(CommandAbort, "") +} + +// SubmitRemoteCommand sends a steering command to the cloud session. +func (cs *CloudSession) SubmitRemoteCommand(cmdType MissionControlCommandType, content string) error { + cs.mu.Lock() + if cs.isDisconnected { + cs.mu.Unlock() + return errors.New("cloud session is disconnected") + } + if !cs.remoteSteerable { + cs.mu.Unlock() + return errors.New("this session is read-only — remote steering is not enabled") + } + cs.mu.Unlock() + return cs.client.steerTask(cs.Metadata.TaskID, cmdType, content) +} + +// RespondToPermission sends a permission response to the cloud session. +func (cs *CloudSession) RespondToPermission(payload CloudPermissionResponsePayload) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return cs.SubmitRemoteCommand(CommandPermissionResponse, string(data)) +} + +// RespondToAskUser sends an ask-user response to the cloud session. +func (cs *CloudSession) RespondToAskUser(payload CloudAskUserResponsePayload) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return cs.SubmitRemoteCommand(CommandAskUserResponse, string(data)) +} + +// RespondToElicitation sends an elicitation response to the cloud session. +func (cs *CloudSession) RespondToElicitation(payload CloudElicitationResponsePayload) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return cs.SubmitRemoteCommand(CommandElicitation, string(data)) +} + +// RespondToPlanApproval sends a plan-approval response to the cloud session. +func (cs *CloudSession) RespondToPlanApproval(payload CloudPlanApprovalResponsePayload) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return cs.SubmitRemoteCommand(CommandPlanApproval, string(data)) +} + +// SwitchMode sends a mode-switch command to the cloud session. +func (cs *CloudSession) SwitchMode(payload CloudModeSwitchPayload) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + return cs.SubmitRemoteCommand(CommandModeSwitch, string(data)) +} + +// GetMessages returns a copy of all events received so far, in chronological order. +func (cs *CloudSession) GetMessages() []CloudSessionEvent { + cs.mu.Lock() + defer cs.mu.Unlock() + out := make([]CloudSessionEvent, len(cs.events)) + copy(out, cs.events) + return out +} + +// Disconnect stops event polling and clears all handlers. The session cannot +// be used after disconnecting. +func (cs *CloudSession) Disconnect() { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.isDisconnected { + return + } + cs.isDisconnected = true + cs.stopEventPolling() + cs.handlers = nil +} + +// ── internal polling ──────────────────────────────────────────────────── + +func (cs *CloudSession) startEventPolling() { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.stopPolling != nil || cs.isDisconnected { + return + } + + stop := make(chan struct{}) + cs.stopPolling = stop + + go func() { + ticker := time.NewTicker(cs.pollInterval) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + cs.pollEvents() + } + } + }() +} + +// stopEventPolling must be called with cs.mu held. +func (cs *CloudSession) stopEventPolling() { + if cs.stopPolling != nil { + close(cs.stopPolling) + cs.stopPolling = nil + } +} + +func (cs *CloudSession) waitForInitialEvents() ([]CloudSessionEvent, error) { + deadline := time.Now().Add(cs.initialEventTimeout) + for { + events, err := cs.client.listTaskEvents(cs.Metadata.TaskID) + if err != nil { + return nil, err + } + if len(events) > 0 { + return sortEventsChronologically(events), nil + } + if cs.initialEventTimeout <= 0 || time.Now().After(deadline) { + return nil, nil + } + time.Sleep(cs.initialEventPollInterval) + } +} + +func (cs *CloudSession) pollEvents() { + cs.mu.Lock() + if cs.isPolling || cs.isDisconnected { + cs.mu.Unlock() + return + } + cs.isPolling = true + cs.mu.Unlock() + + defer func() { + cs.mu.Lock() + cs.isPolling = false + cs.mu.Unlock() + }() + + events, err := cs.client.listTaskEvents(cs.Metadata.TaskID) + if err != nil { + if cs.onEventPollError != nil { + cs.onEventPollError(err) + } + return + } + + newEvents := cs.collectNewEvents(events) + cs.recordEvents(newEvents) +} + +func (cs *CloudSession) collectNewEvents(events []CloudSessionEvent) []CloudSessionEvent { + cs.mu.Lock() + defer cs.mu.Unlock() + + var newEvents []CloudSessionEvent + for _, e := range events { + if cs.seenEventIDs[e.ID] { + continue + } + if cs.lastSeenTimestamp == "" { + newEvents = append(newEvents, e) + continue + } + cmp := compareTimestamps(e.Timestamp, cs.lastSeenTimestamp) + if cmp > 0 { + newEvents = append(newEvents, e) + } else if cmp == 0 && !cs.seenIDsAtLastTimestamp[e.ID] { + newEvents = append(newEvents, e) + } + } + + return sortEventsChronologically(newEvents) +} + +func (cs *CloudSession) recordEvents(events []CloudSessionEvent) { + cs.mu.Lock() + defer cs.mu.Unlock() + + for _, e := range sortEventsChronologically(events) { + if cs.seenEventIDs[e.ID] { + continue + } + cs.seenEventIDs[e.ID] = true + cs.events = append(cs.events, e) + cs.markEventTimestamp(e) + cs.updateRemoteSteerable(e) + cs.dispatchEvent(e) + } +} + +func (cs *CloudSession) markEventTimestamp(e CloudSessionEvent) { + if cs.lastSeenTimestamp != e.Timestamp { + cs.lastSeenTimestamp = e.Timestamp + cs.seenIDsAtLastTimestamp = make(map[string]bool) + } + cs.seenIDsAtLastTimestamp[e.ID] = true +} + +func (cs *CloudSession) updateRemoteSteerable(e CloudSessionEvent) { + if e.Type == "session.remote_steerable_changed" { + var data struct { + RemoteSteerable bool `json:"remoteSteerable"` + } + if err := json.Unmarshal(e.Data, &data); err == nil { + cs.remoteSteerable = data.RemoteSteerable + } + } +} + +// dispatchEvent must be called with cs.mu held. +func (cs *CloudSession) dispatchEvent(e CloudSessionEvent) { + // Copy handlers to allow safe iteration; callers may unsubscribe inside handlers. + handlers := make([]cloudHandler, len(cs.handlers)) + copy(handlers, cs.handlers) + + for _, h := range handlers { + func() { + defer func() { recover() }() // keep one failing handler from breaking polling + h.fn(e) + }() + } +} + +func sortEventsChronologically(events []CloudSessionEvent) []CloudSessionEvent { + sorted := make([]CloudSessionEvent, len(events)) + copy(sorted, events) + sort.Slice(sorted, func(i, j int) bool { + cmp := compareTimestamps(sorted[i].Timestamp, sorted[j].Timestamp) + if cmp != 0 { + return cmp < 0 + } + return sorted[i].ID < sorted[j].ID + }) + return sorted +} + +func compareTimestamps(a, b string) int { + if a < b { + return -1 + } + if a > b { + return 1 + } + return 0 +} diff --git a/go/cloud_session_test.go b/go/cloud_session_test.go new file mode 100644 index 000000000..6837a6a63 --- /dev/null +++ b/go/cloud_session_test.go @@ -0,0 +1,661 @@ +package copilot + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// ── helpers ───────────────────────────────────────────────────────────── + +func newTestMCServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + return httptest.NewServer(handler) +} + +func jsonBytes(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +var testTask = MissionControlTask{ + ID: "task-1", + Name: "Cloud task", + State: "running", + Status: "ready", + CreatorID: 1, + OwnerID: 2, + RepoID: Int(3), + SessionCount: 1, + CreatedAt: "2026-05-11T10:00:00.000Z", + UpdatedAt: "2026-05-11T10:01:00.000Z", + Sessions: []MissionControlTaskSession{ + { + ID: "mc-session-1", + TaskID: "task-1", + State: "running", + CreatedAt: "2026-05-11T10:00:30.000Z", + UpdatedAt: "2026-05-11T10:00:30.000Z", + OwnerID: 2, + RepoID: Int(3), + }, + }, +} + +var requestedEvent = CloudSessionEvent{ + ID: "event-1", + Timestamp: "2026-05-11T10:00:00.000Z", + Type: "session.requested", +} + +var idleEvent = CloudSessionEvent{ + ID: "event-2", + Timestamp: "2026-05-11T10:00:01.000Z", + Type: "session.idle", + Data: json.RawMessage(`{}`), +} + +// ── tests ─────────────────────────────────────────────────────────────── + +func TestCreateCloudSession_WithRepository(t *testing.T) { + var capturedRequests []capturedRequest + var mu sync.Mutex + + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + capturedRequests = append(capturedRequests, capture(r)) + mu.Unlock() + + switch call { + case 0: // createCloudTask + w.Header().Set("Content-Type", "application/json") + w.Write(jsonBytes(testTask)) + case 1: // listTaskEvents + w.Header().Set("Content-Type", "application/json") + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{requestedEvent}})) + default: + w.WriteHeader(500) + } + call++ + }) + defer server.Close() + + progress := []CloudProgressPhase{} + + client := NewClient(&ClientOptions{AutoStart: Bool(false), GitHubToken: "token-1"}) + session, err := client.CreateCloudSession(&CloudSessionOptions{ + Repository: &CloudRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + MissionControlBaseURL: server.URL, + FrontendBaseURL: "https://github.test", + InitialEventTimeoutMs: Int(0), + OnProgress: func(e CloudProgressEvent) { + progress = append(progress, e.Phase) + }, + }) + if err != nil { + t.Fatalf("CreateCloudSession failed: %v", err) + } + defer session.Disconnect() + + // Verify metadata + if session.Metadata.TaskID != "task-1" { + t.Errorf("expected taskId task-1, got %s", session.Metadata.TaskID) + } + if session.Metadata.MissionControlSessionID != "mc-session-1" { + t.Errorf("expected MC session mc-session-1, got %s", session.Metadata.MissionControlSessionID) + } + if session.Metadata.FrontendURL != "https://github.test/copilot/tasks/task-1" { + t.Errorf("unexpected frontendUrl: %s", session.Metadata.FrontendURL) + } + if session.Metadata.Repository == nil || session.Metadata.Repository.Owner != "github" { + t.Error("expected repository metadata to be set") + } + if session.Metadata.State != "running" || session.Metadata.Status != "ready" { + t.Errorf("unexpected state/status: %s/%s", session.Metadata.State, session.Metadata.Status) + } + + // Verify events + msgs := session.GetMessages() + if len(msgs) != 1 || msgs[0].ID != "event-1" { + t.Errorf("expected 1 event, got %d", len(msgs)) + } + + // Verify progress + expectedProgress := []CloudProgressPhase{ + CloudProgressCreatingTask, + CloudProgressProvisioningSandbox, + CloudProgressWaitingForSession, + CloudProgressConnected, + } + if len(progress) != len(expectedProgress) { + t.Fatalf("expected %d progress events, got %d: %v", len(expectedProgress), len(progress), progress) + } + for i, p := range expectedProgress { + if progress[i] != p { + t.Errorf("progress[%d]: expected %s, got %s", i, p, progress[i]) + } + } + + // Verify create task request + mu.Lock() + createReq := capturedRequests[0] + mu.Unlock() + if createReq.method != "POST" { + t.Errorf("expected POST, got %s", createReq.method) + } + if createReq.headers["X-Copilot-Agent-Slug"] != CloudSandboxAgentSlug { + t.Errorf("missing agent slug header") + } + if createReq.headers["Authorization"] != "Bearer token-1" { + t.Errorf("unexpected auth header: %s", createReq.headers["Authorization"]) + } + + var body map[string]any + json.Unmarshal(createReq.body, &body) + repos, ok := body["repositories"] + if !ok { + t.Fatal("expected repositories in request body") + } + repoList, ok := repos.([]any) + if !ok || len(repoList) != 1 { + t.Fatal("expected 1 repository in request body") + } + repoMap := repoList[0].(map[string]any) + if repoMap["owner"] != "github" || repoMap["name"] != "copilot-sdk" { + t.Errorf("unexpected repo: %v", repoMap) + } +} + +func TestCreateCloudSession_RepoLessWithOwner(t *testing.T) { + call := 0 + var capturedBody []byte + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: + b := make([]byte, r.ContentLength) + r.Body.Read(b) + capturedBody = b + w.Write(jsonBytes(testTask)) + case 1: + w.Write(jsonBytes(map[string]any{"events": []any{}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.CreateCloudSession(&CloudSessionOptions{ + Owner: "github", + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err != nil { + t.Fatalf("CreateCloudSession failed: %v", err) + } + defer session.Disconnect() + + if session.Metadata.Owner != "github" { + t.Errorf("expected owner github, got %s", session.Metadata.Owner) + } + + var body map[string]any + json.Unmarshal(capturedBody, &body) + if body["owner"] != "github" { + t.Errorf("expected owner in body, got %v", body) + } + if _, hasRepos := body["repositories"]; hasRepos { + t.Error("repo-less request should not have repositories key") + } +} + +func TestCreateCloudSession_RequiresOwnerWhenRepoOmitted(t *testing.T) { + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + + _, err := client.CreateCloudSession(&CloudSessionOptions{ + InitialEventTimeoutMs: Int(0), + }) + if err == nil { + t.Fatal("expected error when both owner and repository are omitted") + } + if err.Error() != "CloudSessionOptions.Owner is required when Repository is omitted" { + t.Errorf("unexpected error message: %s", err) + } +} + +func TestConnectCloudSession_SteerAPI(t *testing.T) { + call := 0 + var steerBody []byte + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: // getTask → 404 + w.WriteHeader(404) + w.Write([]byte(`"not found"`)) + case 1: // listTaskEvents + w.Write(jsonBytes(map[string]any{"events": []any{}})) + case 2: // steer + b := make([]byte, r.ContentLength) + r.Body.Read(b) + steerBody = b + w.WriteHeader(202) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + defer session.Disconnect() + + if err := session.Send(MessageOptions{Prompt: "hello cloud"}); err != nil { + t.Fatalf("Send failed: %v", err) + } + + var steer steerRequest + if err := json.Unmarshal(steerBody, &steer); err != nil { + t.Fatalf("failed to parse steer body: %v", err) + } + if steer.Type != CommandUserMessage { + t.Errorf("expected user_message, got %s", steer.Type) + } + if steer.Content != "hello cloud" { + t.Errorf("expected content 'hello cloud', got %s", steer.Content) + } +} + +func TestCloudSession_SortsAndDeduplicatesEvents(t *testing.T) { + polledEvent := CloudSessionEvent{ + ID: "event-3", + Timestamp: "2026-05-11T10:00:02.000Z", + Type: "session.idle", + Data: json.RawMessage(`{}`), + } + + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: // getTask + w.Write(jsonBytes(testTask)) + case 1: // listTaskEvents (initial) - note reversed order + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{idleEvent, requestedEvent}})) + default: // listTaskEvents (poll) - includes old + new + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{idleEvent, requestedEvent, polledEvent}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + PollIntervalMs: 10, + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + defer session.Disconnect() + + // Initial events should be sorted chronologically + msgs := session.GetMessages() + if len(msgs) != 2 { + t.Fatalf("expected 2 initial events, got %d", len(msgs)) + } + if msgs[0].ID != "event-1" || msgs[1].ID != "event-2" { + t.Errorf("initial events not sorted: %s, %s", msgs[0].ID, msgs[1].ID) + } + + // Wait for a poll cycle + seenCh := make(chan string, 10) + session.On(func(event CloudSessionEvent) { + seenCh <- event.ID + }) + + // Wait for the polled event + timeout := time.After(2 * time.Second) + select { + case id := <-seenCh: + if id != "event-3" { + t.Errorf("expected event-3, got %s", id) + } + case <-timeout: + t.Fatal("timed out waiting for polled event") + } + + // Total events should be 3 with no duplicates + allMsgs := session.GetMessages() + if len(allMsgs) != 3 { + t.Fatalf("expected 3 total events, got %d", len(allMsgs)) + } + expectedIDs := []string{"event-1", "event-2", "event-3"} + for i, expected := range expectedIDs { + if allMsgs[i].ID != expected { + t.Errorf("event[%d]: expected %s, got %s", i, expected, allMsgs[i].ID) + } + } +} + +func TestCloudSession_MissionControlErrorResponse(t *testing.T) { + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(403) + w.Write(jsonBytes(map[string]string{"message": "blocked"})) + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + _, err := client.CreateCloudSession(&CloudSessionOptions{ + Repository: &CloudRepository{Owner: "github", Name: "copilot-sdk"}, + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err == nil { + t.Fatal("expected error") + } + + csErr, ok := err.(*CloudSessionError) + if !ok { + t.Fatalf("expected CloudSessionError, got %T: %v", err, err) + } + if csErr.Message != "blocked" { + t.Errorf("expected message 'blocked', got %s", csErr.Message) + } + if csErr.Reason != CloudFailurePolicyBlocked { + t.Errorf("expected reason policy_blocked, got %s", csErr.Reason) + } + if csErr.Status != 403 { + t.Errorf("expected status 403, got %d", csErr.Status) + } +} + +func TestCloudSession_DisconnectPreventsSteer(t *testing.T) { + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: + w.WriteHeader(404) + w.Write([]byte(`"not found"`)) + case 1: + w.Write(jsonBytes(map[string]any{"events": []any{}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + + session.Disconnect() + + if err := session.Send(MessageOptions{Prompt: "should fail"}); err == nil { + t.Fatal("expected error after disconnect") + } +} + +func TestCloudSession_RemoteSteerableChanged(t *testing.T) { + steerableEvent := CloudSessionEvent{ + ID: "event-steerable", + Timestamp: "2026-05-11T10:00:03.000Z", + Type: "session.remote_steerable_changed", + Data: json.RawMessage(`{"remoteSteerable": false}`), + } + + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: + w.WriteHeader(404) + w.Write([]byte(`"not found"`)) + case 1: + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{requestedEvent, steerableEvent}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + defer session.Disconnect() + + err = session.Send(MessageOptions{Prompt: "should fail"}) + if err == nil { + t.Fatal("expected error when not steerable") + } + if err.Error() != "this session is read-only — remote steering is not enabled" { + t.Errorf("unexpected error: %s", err) + } +} + +// ── capture helper ────────────────────────────────────────────────────── + +type capturedRequest struct { + method string + path string + headers map[string]string + body []byte +} + +func capture(r *http.Request) capturedRequest { + body := make([]byte, 0) + if r.Body != nil { + body = make([]byte, r.ContentLength) + r.Body.Read(body) + } + headers := make(map[string]string) + for k := range r.Header { + headers[k] = r.Header.Get(k) + } + return capturedRequest{ + method: r.Method, + path: r.URL.Path, + headers: headers, + body: body, + } +} + +func TestCloudSession_OnEventHandler(t *testing.T) { + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: + w.WriteHeader(404) + w.Write([]byte(`"not found"`)) + case 1: + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{requestedEvent, idleEvent}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + defer session.Disconnect() + + // Handler registered after connect should not receive replayed initial events + // but getMessages should still return them. + var handlerEvents []string + session.On(func(event CloudSessionEvent) { + handlerEvents = append(handlerEvents, event.ID) + }) + + msgs := session.GetMessages() + if len(msgs) != 2 { + t.Fatalf("expected 2 initial events, got %d", len(msgs)) + } + if msgs[0].ID != "event-1" || msgs[1].ID != "event-2" { + t.Errorf("unexpected event order: %s, %s", msgs[0].ID, msgs[1].ID) + } + + // No handler events since handler was registered after connect + if len(handlerEvents) != 0 { + t.Errorf("handler should not have received replay events, got %v", handlerEvents) + } +} + +func TestCloudSession_UnsubscribeHandler(t *testing.T) { + polledEvent := CloudSessionEvent{ + ID: "event-3", + Timestamp: "2026-05-11T10:00:02.000Z", + Type: "session.idle", + Data: json.RawMessage(`{}`), + } + + call := 0 + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch call { + case 0: + w.WriteHeader(404) + w.Write([]byte(`"not found"`)) + case 1: + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{requestedEvent}})) + default: + w.Write(jsonBytes(map[string]any{"events": []CloudSessionEvent{requestedEvent, polledEvent}})) + } + call++ + }) + defer server.Close() + + client := NewClient(&ClientOptions{AutoStart: Bool(false)}) + session, err := client.ConnectCloudSession("task-1", &CloudConnectOptions{ + MissionControlBaseURL: server.URL, + InitialEventTimeoutMs: Int(0), + PollIntervalMs: 10, + }) + if err != nil { + t.Fatalf("ConnectCloudSession failed: %v", err) + } + defer session.Disconnect() + + callCount := 0 + var countMu sync.Mutex + unsubscribe := session.On(func(event CloudSessionEvent) { + countMu.Lock() + callCount++ + countMu.Unlock() + }) + + // Immediately unsubscribe — the handler should never fire for polled events + unsubscribe() + + time.Sleep(50 * time.Millisecond) + + countMu.Lock() + count := callCount + countMu.Unlock() + if count != 0 { + t.Errorf("handler was called %d times after unsubscribe", count) + } +} + +func TestMissionControlClient_ErrorExtraction(t *testing.T) { + tests := []struct { + name string + status int + body string + expectedReason CloudSessionFailureReason + expectedMsg string + }{ + { + name: "403 with JSON message", + status: 403, + body: `{"message": "policy blocked"}`, + expectedReason: CloudFailurePolicyBlocked, + expectedMsg: "policy blocked", + }, + { + name: "400 validation error", + status: 400, + body: `{"message": "invalid request"}`, + expectedReason: CloudFailureValidation, + expectedMsg: "invalid request", + }, + { + name: "422 validation error", + status: 422, + body: `{"message": "unprocessable"}`, + expectedReason: CloudFailureValidation, + expectedMsg: "unprocessable", + }, + { + name: "500 server error with no body", + status: 500, + body: "", + expectedReason: CloudFailureServer, + expectedMsg: fmt.Sprintf("Mission Control request failed with HTTP %d", 500), + }, + { + name: "500 with plain text body", + status: 500, + body: "internal error", + expectedReason: CloudFailureServer, + expectedMsg: "internal error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newTestMCServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.status) + w.Write([]byte(tt.body)) + }) + defer server.Close() + + mc := newMissionControlClient(missionControlClientConfig{ + BaseURL: server.URL, + FrontendBaseURL: "https://github.test", + }) + + _, err := mc.createCloudTask("owner", nil) + if err == nil { + t.Fatal("expected error") + } + csErr, ok := err.(*CloudSessionError) + if !ok { + t.Fatalf("expected CloudSessionError, got %T", err) + } + if csErr.Reason != tt.expectedReason { + t.Errorf("expected reason %s, got %s", tt.expectedReason, csErr.Reason) + } + if csErr.Message != tt.expectedMsg { + t.Errorf("expected message %q, got %q", tt.expectedMsg, csErr.Message) + } + if csErr.Status != tt.status { + t.Errorf("expected status %d, got %d", tt.status, csErr.Status) + } + }) + } +} diff --git a/go/cloud_types.go b/go/cloud_types.go new file mode 100644 index 000000000..7b80250aa --- /dev/null +++ b/go/cloud_types.go @@ -0,0 +1,245 @@ +package copilot + +import ( + "encoding/json" + "time" +) + +// CloudRepository describes the repository context used when creating a cloud +// sandbox task. Branch is optional. +type CloudRepository struct { + Owner string `json:"owner"` + Name string `json:"name"` + Branch string `json:"branch,omitempty"` +} + +// CloudProgressPhase represents a progress phase emitted while creating or +// attaching to a cloud sandbox session. +type CloudProgressPhase string + +const ( + CloudProgressCreatingTask CloudProgressPhase = "creating_task" + CloudProgressProvisioningSandbox CloudProgressPhase = "provisioning_sandbox" + CloudProgressWaitingForSession CloudProgressPhase = "waiting_for_session" + CloudProgressConnected CloudProgressPhase = "connected" +) + +// CloudProgressEvent is emitted during cloud session creation to report +// progress to the caller. +type CloudProgressEvent struct { + Phase CloudProgressPhase `json:"phase"` + ElapsedMs int64 `json:"elapsedMs,omitempty"` + TaskID string `json:"taskId,omitempty"` +} + +// CloudSessionFailureReason categorises the cause of a cloud session error. +type CloudSessionFailureReason string + +const ( + CloudFailurePolicyBlocked CloudSessionFailureReason = "policy_blocked" + CloudFailureValidation CloudSessionFailureReason = "validation" + CloudFailureTimeout CloudSessionFailureReason = "timeout" + CloudFailureNetwork CloudSessionFailureReason = "network" + CloudFailureServer CloudSessionFailureReason = "server" +) + +// MissionControlTaskSession describes a single session within a Mission Control task. +type MissionControlTaskSession struct { + ID string `json:"id"` + TaskID string `json:"task_id"` + AgentTaskID *string `json:"agent_task_id,omitempty"` + State string `json:"state"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Name *string `json:"name,omitempty"` + OwnerID int `json:"owner_id"` + RepoID *int `json:"repo_id,omitempty"` +} + +// MissionControlTask is the top-level task object returned by Mission Control. +type MissionControlTask struct { + ID string `json:"id"` + Name string `json:"name"` + State string `json:"state"` + Status string `json:"status"` + CreatorID int `json:"creator_id"` + OwnerID int `json:"owner_id"` + RepoID *int `json:"repo_id,omitempty"` + SessionCount int `json:"session_count"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Sessions []MissionControlTaskSession `json:"sessions,omitempty"` +} + +// CloudSessionMetadata carries the metadata associated with a cloud session. +type CloudSessionMetadata struct { + TaskID string `json:"taskId"` + MissionControlSessionID string `json:"missionControlSessionId,omitempty"` + FrontendURL string `json:"frontendUrl"` + Owner string `json:"owner,omitempty"` + Repository *CloudRepository `json:"repository,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + State string `json:"state,omitempty"` + Status string `json:"status,omitempty"` +} + +// CloudSessionEvent represents a single event received from Mission Control's +// task event stream. The Data field holds arbitrary event payload data. +type CloudSessionEvent struct { + ID string `json:"id"` + ParentID *string `json:"parentId"` + Timestamp string `json:"timestamp"` + Type string `json:"type"` + Data json.RawMessage `json:"data,omitempty"` + Ephemeral *bool `json:"ephemeral,omitempty"` +} + +// MissionControlCommandType enumerates the steering command types accepted +// by the Mission Control steer API. +type MissionControlCommandType string + +const ( + CommandUserMessage MissionControlCommandType = "user_message" + CommandAskUserResponse MissionControlCommandType = "ask_user_response" + CommandPlanApproval MissionControlCommandType = "plan_approval_response" + CommandPermissionResponse MissionControlCommandType = "permission_response" + CommandElicitation MissionControlCommandType = "elicitation_response" + CommandAbort MissionControlCommandType = "abort" + CommandModeSwitch MissionControlCommandType = "mode_switch" +) + +// CloudAskUserResponsePayload is the payload for an ask-user steering response. +type CloudAskUserResponsePayload struct { + PromptID string `json:"promptId"` + Answer string `json:"answer"` + WasFreeform bool `json:"wasFreeform"` + Dismissed bool `json:"dismissed,omitempty"` +} + +// CloudPlanApprovalResponsePayload is the payload for a plan-approval steering response. +type CloudPlanApprovalResponsePayload struct { + PromptID string `json:"promptId"` + Approved bool `json:"approved"` + SelectedAction string `json:"selectedAction,omitempty"` + AutoApproveEdits bool `json:"autoApproveEdits,omitempty"` + Feedback string `json:"feedback,omitempty"` +} + +// CloudPermissionResponsePayload is the payload for a permission steering response. +type CloudPermissionResponsePayload struct { + PromptID string `json:"promptId"` + Approved bool `json:"approved"` + Scope string `json:"scope"` // "once" or "session" +} + +// CloudElicitationResponsePayload is the payload for an elicitation steering response. +type CloudElicitationResponsePayload struct { + PromptID string `json:"promptId"` + Action string `json:"action"` // "accept", "decline", or "cancel" + Content map[string]any `json:"content,omitempty"` +} + +// CloudModeSwitchPayload is the payload for a mode-switch steering command. +type CloudModeSwitchPayload struct { + Mode string `json:"mode"` // "interactive", "plan", or "autopilot" +} + +// CloudSessionError is returned when a Mission Control API request fails. +type CloudSessionError struct { + Message string + Reason CloudSessionFailureReason + Status int +} + +func (e *CloudSessionError) Error() string { + return e.Message +} + +// CloudSessionEventHandler is a callback invoked for every cloud session event. +type CloudSessionEventHandler func(CloudSessionEvent) + +// CloudSessionOptions configures the creation of a new cloud session via +// [Client.CreateCloudSession]. +type CloudSessionOptions struct { + // Owner is the billing/authorization owner for repo-less cloud sandboxes. + // Required when Repository is omitted. + Owner string + + // Repository provides the repository context for the cloud sandbox. + Repository *CloudRepository + + // MissionControlBaseURL overrides the Mission Control API base URL. + MissionControlBaseURL string + + // CopilotAPIBaseURL overrides the Copilot API base URL used to derive + // the Mission Control URL when MissionControlBaseURL is empty. + CopilotAPIBaseURL string + + // FrontendBaseURL overrides the GitHub frontend base URL (default: https://github.com). + FrontendBaseURL string + + // AuthToken overrides the authentication token sent to Mission Control. + AuthToken string + + // IntegrationID overrides the Copilot-Integration-Id header (default: "copilot-cli"). + IntegrationID string + + // PollIntervalMs is the interval between event polls in milliseconds (default: 5000). + PollIntervalMs int + + // InitialEventTimeoutMs is how long to wait for the first event in milliseconds (default: 10000). + InitialEventTimeoutMs *int + + // InitialEventPollIntervalMs is the poll interval during the initial wait in milliseconds (default: 500). + InitialEventPollIntervalMs int + + // OnProgress is called to report progress during session creation. + OnProgress func(CloudProgressEvent) + + // OnCloudTaskCreated is called after the Mission Control task is created. + OnCloudTaskCreated func(MissionControlTask) + + // OnEventPollError is called when an event poll fails. + OnEventPollError func(error) +} + +// CloudConnectOptions configures connecting to an existing cloud session via +// [Client.ConnectCloudSession]. +type CloudConnectOptions struct { + // Owner is the billing/authorization owner. + Owner string + + // Repository provides optional repository context. + Repository *CloudRepository + + // MissionControlBaseURL overrides the Mission Control API base URL. + MissionControlBaseURL string + + // CopilotAPIBaseURL overrides the Copilot API base URL. + CopilotAPIBaseURL string + + // FrontendBaseURL overrides the GitHub frontend base URL. + FrontendBaseURL string + + // AuthToken overrides the authentication token. + AuthToken string + + // IntegrationID overrides the Copilot-Integration-Id header. + IntegrationID string + + // PollIntervalMs is the interval between event polls in milliseconds (default: 5000). + PollIntervalMs int + + // InitialEventTimeoutMs is how long to wait for the first event in milliseconds (default: 10000). + InitialEventTimeoutMs *int + + // InitialEventPollIntervalMs is the poll interval during the initial wait in milliseconds (default: 500). + InitialEventPollIntervalMs int + + // OnProgress is called to report progress during connection. + OnProgress func(CloudProgressEvent) + + // OnEventPollError is called when an event poll fails. + OnEventPollError func(error) +}