Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/core/src/webAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ export interface WebAgentOptions {
tabstackApiUrl?: string;
/** Callback for interactive mode: called when the agent needs user data for form fields. Presence enables interactive mode. */
onUserDataRequired?: UserDataCallback;
/** Correlation ID for this task, propagated to logs and traces. */
taskId?: string;
}

export interface ExecuteOptions {
Expand Down Expand Up @@ -214,6 +216,7 @@ export class WebAgent {
private readonly tabstackApiKey: string | undefined;
private readonly tabstackApiUrl: string | undefined;
private readonly onUserDataRequired: UserDataCallback | undefined;
private readonly taskId: string | undefined;

constructor(
private browser: AriaBrowser,
Expand All @@ -237,6 +240,7 @@ export class WebAgent {
this.tabstackApiKey = options.tabstackApiKey;
this.tabstackApiUrl = options.tabstackApiUrl;
this.onUserDataRequired = options.onUserDataRequired;
this.taskId = options.taskId;

if (this.searchProvider === "parallel-api" && !this.searchApiKey) {
throw new Error("parallel_api_key is required when search_provider is 'parallel-api'");
Expand Down Expand Up @@ -286,6 +290,7 @@ export class WebAgent {
attributes: {
"pilo.task": task,
...(options.startingUrl && { "pilo.url": options.startingUrl }),
...(this.taskId && { "pilo.task.id": this.taskId }),
},
},
async (span) => {
Expand Down
90 changes: 90 additions & 0 deletions packages/server/src/routes/pilo.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ vi.mock("../StreamLogger.js", () => ({
StreamLogger: vi.fn().mockImplementation(() => ({})),
}));

const UUID_RE = /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;

describe("Pilo Routes", () => {
let app: Hono;

Expand Down Expand Up @@ -196,6 +198,94 @@ describe("Pilo Routes", () => {
expect(res.headers.get("Connection")).toBe("keep-alive");
});

it("should include taskId in validation-error response", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({}),
});

expect(res.status).toBe(400);
const data = await res.json();
expect(data.error.taskId).toMatch(UUID_RE);
});

it("should include taskId in setup-error response for malformed JSON", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: "invalid json",
});

expect(res.status).toBe(500);
const data = await res.json();
expect(data.error.taskId).toMatch(UUID_RE);
});

it("should return x-pilo-task-id header on successful request", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ task: "test task" }),
});

expect(res.status).toBe(200);
expect(res.headers.get("x-pilo-task-id")).toMatch(UUID_RE);
});

it("should return x-pilo-task-id header on validation error", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({}),
});

expect(res.status).toBe(400);
expect(res.headers.get("x-pilo-task-id")).toMatch(UUID_RE);
});

it("should return x-pilo-task-id header on setup error", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: "invalid json",
});

expect(res.status).toBe(500);
expect(res.headers.get("x-pilo-task-id")).toMatch(UUID_RE);
});

it("should use the same taskId in header and body", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({}),
});

const data = await res.json();
expect(res.headers.get("x-pilo-task-id")).toBe(data.error.taskId);
});

it("should include taskId in SSE start event", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ task: "test task" }),
});

expect(res.status).toBe(200);
const reader = res.body!.getReader();
const { value } = await reader.read();
const chunk = new TextDecoder().decode(value);
expect(chunk).toMatch(/^event: start\ndata: /);

const dataLine = chunk.split("\n").find((l) => l.startsWith("data: "))!;
const data = JSON.parse(dataLine.slice("data: ".length));
expect(data.taskId).toMatch(UUID_RE);

reader.releaseLock();
});

it("should stream SSE events with proper format", async () => {
const res = await app.request("/pilo/run", {
method: "POST",
Expand Down
18 changes: 14 additions & 4 deletions packages/server/src/routes/pilo.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { randomUUID } from "node:crypto";
import { Hono } from "hono";
import { streamSSE } from "hono/streaming";
import { runTask, validateTaskRequest, createErrorResponse, errorToString } from "../taskRunner.js";
Expand All @@ -7,12 +8,20 @@ const pilo = new Hono();

// POST /pilo/run - Execute a Pilo task with real-time SSE streaming (non-interactive)
pilo.post("/run", async (c) => {
const taskId = randomUUID();
c.header("x-pilo-task-id", taskId);
try {
const body = (await c.req.json()) as PiloTaskRequest;

const validationError = validateTaskRequest(body);
if (validationError) {
return c.json(validationError.response, validationError.status as any);
return c.json(
{
...validationError.response,
error: { ...validationError.response.error, taskId },
},
validationError.status as any,
);
}

return streamSSE(c, async (stream) => {
Expand All @@ -26,11 +35,12 @@ pilo.post("/run", async (c) => {
try {
await stream.writeSSE({
event: "start",
data: JSON.stringify({ task: body.task, url: body.url }),
data: JSON.stringify({ taskId, task: body.task, url: body.url }),
});

const result = await runTask({
body,
taskId,
sendEvent: async (event, data) => {
await stream.writeSSE({ event, data: JSON.stringify(data) });
},
Expand All @@ -56,15 +66,15 @@ pilo.post("/run", async (c) => {
await stream.writeSSE({
event: "error",
data: JSON.stringify(
createErrorResponse(errorToString(error), "TASK_EXECUTION_FAILED"),
createErrorResponse(errorToString(error), "TASK_EXECUTION_FAILED", taskId),
),
});
}
}
});
} catch (error) {
console.error("Pilo task setup failed:", error);
return c.json(createErrorResponse(errorToString(error), "TASK_SETUP_FAILED"), 500);
return c.json(createErrorResponse(errorToString(error), "TASK_SETUP_FAILED", taskId), 500);
}
});

Expand Down
61 changes: 59 additions & 2 deletions packages/server/src/routes/piloWs.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,23 @@ const mockValidateTaskRequest = vi.fn().mockReturnValue(null);
vi.mock("../taskRunner.js", () => ({
runTask: (...args: any[]) => mockRunTask(...args),
validateTaskRequest: (...args: any[]) => mockValidateTaskRequest(...args),
createErrorResponse: (message: string, code: string) => ({
createErrorResponse: (message: string, code: string, taskId?: string) => ({
success: false,
error: { message, code, timestamp: new Date().toISOString() },
error: {
message,
code,
timestamp: new Date().toISOString(),
...(taskId !== undefined && { taskId }),
},
}),
errorToString: (error: unknown) => (error instanceof Error ? error.name : "Unknown error"),
}));

import { createPiloWsRoute } from "./piloWs.js";
import type { UpgradeWebSocket, WSContext } from "hono/ws";

const UUID_RE = /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;

/**
* Helper to extract the WebSocket event handlers from createPiloWsRoute.
*
Expand Down Expand Up @@ -296,6 +303,56 @@ describe("piloWs", () => {
});
});

describe("taskId", () => {
it("should emit task:accepted event with taskId after validation passes", async () => {
const h = createTestHarness();
h.sendMessage({ event: "task:details", data: { task: "test" } });
await vi.runAllTimersAsync();

const accepted = h.sentMessages.find((m) => m.event === "task:accepted");
expect(accepted).toBeDefined();
expect(accepted!.data.taskId).toMatch(UUID_RE);
});

it("should pass taskId to runTask", async () => {
const h = createTestHarness();
h.sendMessage({ event: "task:details", data: { task: "test" } });
await vi.runAllTimersAsync();

expect(mockRunTask).toHaveBeenCalledWith(
expect.objectContaining({ taskId: expect.stringMatching(UUID_RE) }),
);
});

it("should include taskId in validation-error response", () => {
mockValidateTaskRequest.mockReturnValue({
status: 400,
response: {
success: false,
error: { message: "Task is required", code: "MISSING_TASK", timestamp: "" },
},
});

const h = createTestHarness();
h.sendMessage({ event: "task:details", data: { task: "" } });

expect(h.sentMessages[0].event).toBe("error");
expect(h.sentMessages[0].data.error.taskId).toMatch(UUID_RE);
});

it("should include taskId in task-execution error response", async () => {
mockRunTask.mockRejectedValue(new TypeError("something broke"));

const h = createTestHarness();
h.sendMessage({ event: "task:details", data: { task: "test" } });
await vi.runAllTimersAsync();

const errorMsg = h.sentMessages.find((m) => m.event === "error");
expect(errorMsg).toBeDefined();
expect(errorMsg!.data.error.taskId).toMatch(UUID_RE);
});
});

describe("user_data_response", () => {
it("should reject when requestId is missing", () => {
const h = createTestHarness();
Expand Down
13 changes: 10 additions & 3 deletions packages/server/src/routes/piloWs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* { "event": "error", "data": ErrorResponse }
*/

import { randomUUID } from "node:crypto";
import { Hono } from "hono";
import type { UpgradeWebSocket, WSContext } from "hono/ws";
import type { UserDataCallback, UserDataResponse } from "pilo-core";
Expand Down Expand Up @@ -83,18 +84,23 @@ export function createPiloWsRoute(upgradeWebSocket: UpgradeWebSocket): Hono {
}

const body = msg.data as PiloTaskRequest;
const taskId = randomUUID();
if (!body) {
send(ws, "error", createErrorResponse("Missing data", "MISSING_DATA"));
send(ws, "error", createErrorResponse("Missing data", "MISSING_DATA", taskId));
return;
}

const validationError = validateTaskRequest(body);
if (validationError) {
send(ws, "error", validationError.response);
send(ws, "error", {
...validationError.response,
error: { ...validationError.response.error, taskId },
});
return;
}

taskRunning = true;
send(ws, "task:accepted", { taskId });

// The callback just blocks until the client responds.
// The event (interactive:form_data:request or interactive:form_data:error)
Expand All @@ -117,6 +123,7 @@ export function createPiloWsRoute(upgradeWebSocket: UpgradeWebSocket): Hono {
const result = await withRemoteContext(traceHeaders, () =>
runTask({
body,
taskId,
sendEvent: async (event, data) => {
send(ws, event, data);
},
Expand All @@ -131,7 +138,7 @@ export function createPiloWsRoute(upgradeWebSocket: UpgradeWebSocket): Hono {
await send(
ws,
"error",
createErrorResponse(errorToString(error), "TASK_EXECUTION_FAILED"),
createErrorResponse(errorToString(error), "TASK_EXECUTION_FAILED", taskId),
);
}
} finally {
Expand Down
23 changes: 23 additions & 0 deletions packages/server/src/taskRunner.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ describe("taskRunner", () => {
const parsed = new Date(res.error.timestamp);
expect(parsed.getTime()).not.toBeNaN();
});

it("should include taskId when provided", () => {
const res = createErrorResponse("msg", "CODE", "task-abc-123");
expect(res.error.taskId).toBe("task-abc-123");
});

it("should omit taskId when not provided", () => {
const res = createErrorResponse("msg", "CODE");
expect(res.error.taskId).toBeUndefined();
});
});

describe("validateTaskRequest", () => {
Expand Down Expand Up @@ -214,6 +224,19 @@ describe("taskRunner", () => {
expect(mockClose).toHaveBeenCalled();
});

it("should pass taskId to WebAgent constructor when provided", async () => {
await runTask({
body: { task: "test" },
sendEvent: vi.fn(),
abortSignal: new AbortController().signal,
taskId: "task-abc-123",
});

expect(mockConstructorSpy).toHaveBeenCalledWith(
expect.objectContaining({ taskId: "task-abc-123" }),
);
});

it("should not throw when agent.close fails", async () => {
mockClose = vi.fn().mockRejectedValue(new Error("close failed"));

Expand Down
Loading
Loading