diff --git a/apps/wh/migrations/0002_delivery-sent-tracking.sql b/apps/wh/migrations/0002_delivery-sent-tracking.sql new file mode 100644 index 00000000..a8345b8f --- /dev/null +++ b/apps/wh/migrations/0002_delivery-sent-tracking.sql @@ -0,0 +1,3 @@ +DROP INDEX `delivery_tunnel_delivery_idx`;--> statement-breakpoint +ALTER TABLE `delivery` ADD `sent_at` integer;--> statement-breakpoint +CREATE INDEX `delivery_tunnel_delivery_idx` ON `delivery` (`tunnel_id`,`delivered_at`,`sent_at`,`failed_at`,`received_at`); \ No newline at end of file diff --git a/apps/wh/migrations/meta/0002_snapshot.json b/apps/wh/migrations/meta/0002_snapshot.json new file mode 100644 index 00000000..cdecef51 --- /dev/null +++ b/apps/wh/migrations/meta/0002_snapshot.json @@ -0,0 +1,230 @@ +{ + "version": "6", + "dialect": "sqlite", + "id": "2fedf88b-a3db-4453-9423-6665b7180078", + "prevId": "617459ed-59ef-4877-b22e-cad7fe84b2f4", + "tables": { + "delivery": { + "name": "delivery", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "tunnel_id": { + "name": "tunnel_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "method": { + "name": "method", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "headers": { + "name": "headers", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "body": { + "name": "body", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "received_at": { + "name": "received_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "sent_at": { + "name": "sent_at", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "delivered_at": { + "name": "delivered_at", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "failed_at": { + "name": "failed_at", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "error": { + "name": "error", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + } + }, + "indexes": { + "delivery_tunnel_delivery_idx": { + "name": "delivery_tunnel_delivery_idx", + "columns": [ + "tunnel_id", + "delivered_at", + "sent_at", + "failed_at", + "received_at" + ], + "isUnique": false + } + }, + "foreignKeys": { + "delivery_tunnel_id_tunnel_id_fk": { + "name": "delivery_tunnel_id_tunnel_id_fk", + "tableFrom": "delivery", + "tableTo": "tunnel", + "columnsFrom": [ + "tunnel_id" + ], + "columnsTo": [ + "id" + ], + "onDelete": "cascade", + "onUpdate": "no action" + } + }, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "tunnel": { + "name": "tunnel", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "device_token_hash": { + "name": "device_token_hash", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "provider_id": { + "name": "provider_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "environment": { + "name": "environment", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "provider_account_id": { + "name": "provider_account_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "provider_webhook_endpoint_id": { + "name": "provider_webhook_endpoint_id", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "status": { + "name": "status", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "'active'" + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "last_seen_at": { + "name": "last_seen_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "disabled_at": { + "name": "disabled_at", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + } + }, + "indexes": { + "tunnel_device_provider_unique": { + "name": "tunnel_device_provider_unique", + "columns": [ + "device_token_hash", + "provider_id", + "environment", + "provider_account_id" + ], + "isUnique": true + }, + "tunnel_device_idx": { + "name": "tunnel_device_idx", + "columns": [ + "device_token_hash" + ], + "isUnique": false + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "views": {}, + "enums": {}, + "_meta": { + "schemas": {}, + "tables": {}, + "columns": {} + }, + "internal": { + "indexes": {} + } +} \ No newline at end of file diff --git a/apps/wh/migrations/meta/_journal.json b/apps/wh/migrations/meta/_journal.json index c05b2a6e..83bb46bd 100644 --- a/apps/wh/migrations/meta/_journal.json +++ b/apps/wh/migrations/meta/_journal.json @@ -15,6 +15,13 @@ "when": 1778238601783, "tag": "0001_delivery-failure-tracking", "breakpoints": true + }, + { + "idx": 2, + "version": "6", + "when": 1778501327588, + "tag": "0002_delivery-sent-tracking", + "breakpoints": true } ] -} +} \ No newline at end of file diff --git a/apps/wh/package.json b/apps/wh/package.json index d1e5a0e2..8a070761 100644 --- a/apps/wh/package.json +++ b/apps/wh/package.json @@ -5,9 +5,9 @@ "type": "module", "scripts": { "dev": "wrangler dev --local", - "deploy": "if [ -z \"$PAYKIT_WEBHOOK_API_BASE_URL\" ]; then printf '%s\n' 'Missing PAYKIT_WEBHOOK_API_BASE_URL'; exit 1; fi; wrangler deploy --var \"PAYKIT_WEBHOOK_API_BASE_URL:$PAYKIT_WEBHOOK_API_BASE_URL\"", - "db:migrate:local": "wrangler d1 migrations apply paykit-wh --local", - "db:migrate:remote": "wrangler d1 migrations apply paykit-wh --remote", + "deploy": "if [ -z \"${PAYKIT_WEBHOOK_PUBLIC_BASE_URL:-$PAYKIT_WEBHOOK_API_BASE_URL}\" ]; then printf '%s\n' 'Missing PAYKIT_WEBHOOK_PUBLIC_BASE_URL'; exit 1; fi; wrangler deploy --var \"PAYKIT_WEBHOOK_PUBLIC_BASE_URL:${PAYKIT_WEBHOOK_PUBLIC_BASE_URL:-$PAYKIT_WEBHOOK_API_BASE_URL}\"", + "db:migrate:local": "wrangler d1 migrations apply DB --local", + "db:migrate:remote": "wrangler d1 migrations apply DB --remote", "db:studio": "drizzle-kit studio --config drizzle.config.ts", "ship": "bun run db:migrate:remote && bun run deploy", "typecheck": "tsc --noEmit" diff --git a/apps/wh/src/db/schema.ts b/apps/wh/src/db/schema.ts index 01928af2..ce4a8027 100644 --- a/apps/wh/src/db/schema.ts +++ b/apps/wh/src/db/schema.ts @@ -38,6 +38,7 @@ export const delivery = sqliteTable( headers: text("headers", { mode: "json" }).$type>().notNull(), body: text("body").notNull(), receivedAt: integer("received_at", { mode: "number" }).notNull(), + sentAt: integer("sent_at", { mode: "number" }), deliveredAt: integer("delivered_at", { mode: "number" }), failedAt: integer("failed_at", { mode: "number" }), error: text("error"), @@ -46,6 +47,7 @@ export const delivery = sqliteTable( index("delivery_tunnel_delivery_idx").on( table.tunnelId, table.deliveredAt, + table.sentAt, table.failedAt, table.receivedAt, ), diff --git a/apps/wh/src/index.ts b/apps/wh/src/index.ts index e298b483..e91b284b 100644 --- a/apps/wh/src/index.ts +++ b/apps/wh/src/index.ts @@ -4,17 +4,22 @@ import { Hono, type Context } from "hono"; import { HTTPException } from "hono/http-exception"; import { delivery, tunnel } from "./db/schema"; +import { TunnelObject, type TunnelObjectBindings } from "./tunnel-object"; -interface Bindings { - DB: D1Database; - MAX_BODY_BYTES: string; - MAX_DELIVERIES_PER_TUNNEL: string; - PAYKIT_WEBHOOK_API_BASE_URL?: string; - RETENTION_DAYS: string; +interface Bindings extends TunnelObjectBindings { + TUNNEL_OBJECT: DurableObjectNamespace; } const app = new Hono<{ Bindings: Bindings }>(); type AppContext = Context<{ Bindings: Bindings }>; +const MIN_CLI_VERSION = "0.0.4"; + +interface ParsedVersion { + major: number; + minor: number; + patch: number; + prerelease: string[]; +} function getDb(env: Bindings) { return drizzle(env.DB); @@ -28,15 +33,128 @@ function clamp(value: number, min: number, max: number): number { return Math.min(Math.max(value, min), max); } +function parseVersion(version: string): ParsedVersion | null { + const match = /^(\d+)\.(\d+)\.(\d+)(?:-([0-9A-Za-z.-]+))?(?:\+[0-9A-Za-z.-]+)?$/.exec(version); + if (!match) { + return null; + } + + return { + major: Number(match[1]), + minor: Number(match[2]), + patch: Number(match[3]), + prerelease: match[4]?.split(".") ?? [], + }; +} + +function comparePrerelease(left: string[], right: string[]): number { + if (left.length === 0 && right.length === 0) { + return 0; + } + + if (left.length === 0) { + return 1; + } + + if (right.length === 0) { + return -1; + } + + for (let index = 0; index < Math.max(left.length, right.length); index++) { + const leftPart = left[index]; + const rightPart = right[index]; + + if (leftPart === undefined) { + return -1; + } + + if (rightPart === undefined) { + return 1; + } + + const leftNumeric = /^\d+$/.test(leftPart); + const rightNumeric = /^\d+$/.test(rightPart); + + if (leftNumeric && rightNumeric) { + const leftNumber = Number(leftPart); + const rightNumber = Number(rightPart); + if (leftNumber !== rightNumber) { + return leftNumber > rightNumber ? 1 : -1; + } + continue; + } + + if (leftNumeric !== rightNumeric) { + return leftNumeric ? -1 : 1; + } + + if (leftPart !== rightPart) { + return leftPart > rightPart ? 1 : -1; + } + } + + return 0; +} + +function compareVersions(left: string, right: string): number { + const leftVersion = parseVersion(left); + const rightVersion = parseVersion(right); + + if (!leftVersion || !rightVersion) { + return -1; + } + + const leftParts = [leftVersion.major, leftVersion.minor, leftVersion.patch]; + const rightParts = [rightVersion.major, rightVersion.minor, rightVersion.patch]; + + for (let index = 0; index < Math.max(leftParts.length, rightParts.length); index++) { + const leftPart = leftParts[index] ?? 0; + const rightPart = rightParts[index] ?? 0; + if (leftPart !== rightPart) { + return leftPart > rightPart ? 1 : -1; + } + } + + return comparePrerelease(leftVersion.prerelease, rightVersion.prerelease); +} + +function getCliVersion(c: AppContext): string | undefined { + return c.req.header("x-paykit-cli-version") ?? c.req.query("cliVersion"); +} + +function getCliVersionFromRequest(request: Request): string | undefined { + return ( + request.headers.get("x-paykit-cli-version") ?? + new URL(request.url).searchParams.get("cliVersion") ?? + undefined + ); +} + +function createCliUpgradeResponse(): Response { + return Response.json( + { + code: "CLI_UPGRADE_REQUIRED", + message: `This paykitjs CLI version is no longer supported. Upgrade paykitjs to ${MIN_CLI_VERSION} or newer.`, + minVersion: MIN_CLI_VERSION, + }, + { status: 426 }, + ); +} + +function isSupportedCliVersion(version: string | undefined): boolean { + return typeof version === "string" && compareVersions(version, MIN_CLI_VERSION) >= 0; +} + function getNumericVar(value: string, fallback: number): number { const parsed = Number(value); return Number.isFinite(parsed) ? parsed : fallback; } function getRequiredWebhookBaseUrl(env: Bindings): string { - const baseUrl = env.PAYKIT_WEBHOOK_API_BASE_URL?.trim(); + const baseUrl = + env.PAYKIT_WEBHOOK_PUBLIC_BASE_URL?.trim() ?? env.PAYKIT_WEBHOOK_API_BASE_URL?.trim(); if (!baseUrl) { - throw new Error("PAYKIT_WEBHOOK_API_BASE_URL is required"); + throw new Error("PAYKIT_WEBHOOK_PUBLIC_BASE_URL is required"); } return baseUrl.replace(/\/$/, ""); @@ -124,11 +242,12 @@ function buildPullableDeliveryWhere(params: { const conditions = [ eq(delivery.tunnelId, params.tunnelId), isNull(delivery.deliveredAt), + isNull(delivery.sentAt), isNull(delivery.failedAt), ]; if (params.retryWindowMs > 0 && typeof params.includeFailedBefore === "number") { - conditions[2] = or( + conditions[3] = or( isNull(delivery.failedAt), and( lt(delivery.failedAt, params.includeFailedBefore), @@ -151,6 +270,20 @@ async function getPullableCount( return rows[0]?.count ?? 0; } +function getTunnelStub(env: Bindings, tunnelId: string) { + return env.TUNNEL_OBJECT.get(env.TUNNEL_OBJECT.idFromName(tunnelId)); +} + +async function notifyTunnelObject(env: Bindings, params: { tunnelId: string }): Promise { + const response = await getTunnelStub(env, params.tunnelId).fetch( + new Request("https://internal/internal/push", { method: "POST" }), + ); + + if (!response.ok) { + throw new Error(await response.text()); + } +} + async function pruneDeliveries(params: { db: ReturnType; env: Bindings; @@ -183,8 +316,135 @@ async function pruneDeliveries(params: { } } +async function requireSocketDeviceTokenHashFromRequest(request: Request): Promise { + const authHeader = request.headers.get("authorization"); + if (authHeader?.startsWith("Bearer ")) { + const token = authHeader.slice("Bearer ".length).trim(); + if (token) { + return hashToken(token); + } + } + + const token = new URL(request.url).searchParams.get("deviceToken")?.trim(); + if (!token) { + throw new HTTPException(401, { message: "Missing bearer token" }); + } + + return hashToken(token); +} + +function getConnectTunnelId(pathname: string): string | null { + const segments = pathname.split("/").filter(Boolean); + if ( + segments.length === 4 && + segments[0] === "api" && + segments[1] === "tunnels" && + segments[3] === "connect" + ) { + return segments[2] ?? null; + } + + return null; +} + +async function maybeHandleTunnelSocketRequest( + request: Request, + env: Bindings, +): Promise { + const tunnelId = getConnectTunnelId(new URL(request.url).pathname); + if (!tunnelId || request.method !== "GET") { + return null; + } + + const upgradeHeader = request.headers.get("Upgrade"); + if (!upgradeHeader || upgradeHeader.toLowerCase() !== "websocket") { + return new Response("Expected websocket upgrade", { status: 426 }); + } + + if (!isSupportedCliVersion(getCliVersionFromRequest(request))) { + return createCliUpgradeResponse(); + } + + const deviceTokenHash = await requireSocketDeviceTokenHashFromRequest(request); + const db = getDb(env); + const current = await getOwnedTunnel({ db, deviceTokenHash, tunnelId }); + + if (!current) { + return new Response("Tunnel not found", { status: 404 }); + } + + if (current.status === "disabled") { + return new Response("Tunnel disabled", { status: 410 }); + } + + return getTunnelStub(env, current.id).fetch(request); +} + +async function maybeHandleProviderWebhookRequest( + request: Request, + env: Bindings, + ctx: ExecutionContext, +): Promise { + const segments = new URL(request.url).pathname.split("/").filter(Boolean); + if (request.method !== "POST" || segments.length !== 1) { + return null; + } + + const tunnelId = segments[0]; + if (!tunnelId) { + return null; + } + + const db = getDb(env); + const current = await db.select().from(tunnel).where(eq(tunnel.id, tunnelId)).limit(1); + const currentTunnel = current[0]; + + if (!currentTunnel) { + return new Response("Not found", { status: 404 }); + } + + if (currentTunnel.status !== "active") { + return new Response("Tunnel disabled", { status: 410 }); + } + + const body = await request.text(); + const bodyBytes = new TextEncoder().encode(body).byteLength; + if (bodyBytes > getNumericVar(env.MAX_BODY_BYTES, 262_144)) { + return new Response("Payload too large", { status: 413 }); + } + + await db.insert(delivery).values({ + body, + error: null, + failedAt: null, + headers: getRequestHeaders(request), + id: generateId("del"), + method: request.method, + receivedAt: now(), + sentAt: null, + tunnelId: currentTunnel.id, + }); + await pruneDeliveries({ db, env, tunnelId: currentTunnel.id }); + + ctx.waitUntil(notifyTunnelObject(env, { tunnelId: currentTunnel.id })); + + return Response.json({ received: true }); +} + app.get("/api/health", (c) => c.json({ ok: true })); +app.use("/api/*", async (c, next) => { + if (c.req.path === "/api/health") { + return next(); + } + + if (!isSupportedCliVersion(getCliVersion(c))) { + return createCliUpgradeResponse(); + } + + return next(); +}); + app.post("/api/tunnels/ensure", async (c) => { const deviceTokenHash = await requireDeviceTokenHash(c); const db = getDb(c.env); @@ -326,9 +586,10 @@ app.post("/api/tunnels/:tunnelId/provider-webhook", async (c) => { return c.text("providerWebhookEndpointId is required", 400); } + const timestamp = now(); await db .update(tunnel) - .set({ providerWebhookEndpointId: body.providerWebhookEndpointId, updatedAt: now() }) + .set({ providerWebhookEndpointId: body.providerWebhookEndpointId, updatedAt: timestamp }) .where(eq(tunnel.id, current.id)); return c.json({ ok: true }); @@ -447,9 +708,11 @@ app.post("/api/deliveries/:deliveryId/ack", async (c) => { await db .update(delivery) - .set({ deliveredAt: now(), error: null, failedAt: null }) + .set({ deliveredAt: now(), error: null, failedAt: null, sentAt: null }) .where(eq(delivery.id, currentDelivery.id)); + await notifyTunnelObject(c.env, { tunnelId: currentTunnel.id }); + return c.json({ ok: true }); }); @@ -483,9 +746,11 @@ app.post("/api/deliveries/:deliveryId/fail", async (c) => { const body = (await c.req.json()) as { error?: string }; await db .update(delivery) - .set({ error: body.error ?? null, failedAt: now() }) + .set({ error: body.error ?? null, failedAt: now(), sentAt: null }) .where(eq(delivery.id, currentDelivery.id)); + await notifyTunnelObject(c.env, { tunnelId: currentTunnel.id }); + return c.json({ ok: true }); }); @@ -511,42 +776,19 @@ app.post("/api/tunnels/:tunnelId/disable", async (c) => { return c.json({ ok: true }); }); -app.post("/:tunnelId", async (c) => { - const db = getDb(c.env); - const current = await db - .select() - .from(tunnel) - .where(eq(tunnel.id, c.req.param("tunnelId"))) - .limit(1); - const currentTunnel = current[0]; - - if (!currentTunnel) { - return c.text("Not found", 404); - } - - if (currentTunnel.status !== "active") { - return c.text("Tunnel disabled", 410); - } - - const body = await c.req.text(); - const bodyBytes = new TextEncoder().encode(body).byteLength; - if (bodyBytes > getNumericVar(c.env.MAX_BODY_BYTES, 262_144)) { - return c.text("Payload too large", 413); - } - - await db.insert(delivery).values({ - body, - error: null, - failedAt: null, - headers: getRequestHeaders(c.req.raw), - id: generateId("del"), - method: c.req.method, - receivedAt: now(), - tunnelId: currentTunnel.id, - }); - await pruneDeliveries({ db, env: c.env, tunnelId: currentTunnel.id }); +export default { + async fetch(request: Request, env: Bindings, ctx: ExecutionContext) { + const socketResponse = await maybeHandleTunnelSocketRequest(request, env); + if (socketResponse) { + return socketResponse; + } - return c.json({ received: true }); -}); + const providerWebhookResponse = await maybeHandleProviderWebhookRequest(request, env, ctx); + if (providerWebhookResponse) { + return providerWebhookResponse; + } -export default app; + return app.fetch(request, env, ctx); + }, +}; +export { TunnelObject }; diff --git a/apps/wh/src/tunnel-object.ts b/apps/wh/src/tunnel-object.ts new file mode 100644 index 00000000..e620d108 --- /dev/null +++ b/apps/wh/src/tunnel-object.ts @@ -0,0 +1,450 @@ +import { DurableObject } from "cloudflare:workers"; +import { and, asc, count, eq, gte, isNotNull, isNull, lt, or, type SQL } from "drizzle-orm"; +import { drizzle } from "drizzle-orm/d1"; + +import { delivery, tunnel } from "./db/schema"; + +export interface TunnelObjectBindings { + DB: D1Database; + MAX_BODY_BYTES: string; + MAX_DELIVERIES_PER_TUNNEL: string; + PAYKIT_WEBHOOK_API_BASE_URL?: string; + PAYKIT_WEBHOOK_PUBLIC_BASE_URL?: string; + RETENTION_DAYS: string; +} + +interface SocketAttachment { + deviceTokenHash: string; + includeFailedBefore?: number; + replayCompleteSent: boolean; + retryWindowMs: number; + role: "cli"; + sessionId: string; + tunnelId: string; +} + +const REPLACED_SESSION_CLOSE_CODE = 4001; +const REPLACED_SESSION_CLOSE_REASON = "paykit.session_replaced"; + +function now(): number { + return Date.now(); +} + +async function hashToken(token: string): Promise { + const digest = await crypto.subtle.digest("SHA-256", new TextEncoder().encode(token)); + return [...new Uint8Array(digest)].map((byte) => byte.toString(16).padStart(2, "0")).join(""); +} + +function readNumberParam(value: string | null, fallback: number): number { + if (value === null) { + return fallback; + } + + const parsed = Number(value); + return Number.isNaN(parsed) ? fallback : parsed; +} + +function readOptionalNumberParam(value: string | null): number | undefined { + if (value === null) { + return undefined; + } + + const parsed = Number(value); + return Number.isNaN(parsed) ? undefined : parsed; +} + +function readDeviceToken(request: Request): string | null { + const authHeader = request.headers.get("authorization"); + if (authHeader?.startsWith("Bearer ")) { + const token = authHeader.slice("Bearer ".length).trim(); + if (token) { + return token; + } + } + + const token = new URL(request.url).searchParams.get("deviceToken")?.trim(); + return token ? token : null; +} + +function buildDeliverableWhere(params: { + includeFailedBefore?: number; + retryWindowMs: number; + tunnelId: string; +}): SQL | undefined { + const conditions = [ + eq(delivery.tunnelId, params.tunnelId), + isNull(delivery.deliveredAt), + isNull(delivery.sentAt), + isNull(delivery.failedAt), + ]; + + if (params.retryWindowMs > 0 && typeof params.includeFailedBefore === "number") { + conditions[3] = or( + isNull(delivery.failedAt), + and( + lt(delivery.failedAt, params.includeFailedBefore), + gte(delivery.receivedAt, now() - params.retryWindowMs), + ), + )!; + } + + return and(...conditions); +} + +export class TunnelObject extends DurableObject { + private db: ReturnType; + + constructor(ctx: DurableObjectState, env: TunnelObjectBindings) { + super(ctx, env); + this.db = drizzle(env.DB); + } + + async fetch(request: Request): Promise { + const url = new URL(request.url); + + if (url.pathname === "/internal/push" && request.method === "POST") { + await this.pushToConnectedClient(); + return Response.json({ ok: true }); + } + + if (request.headers.get("Upgrade") === "websocket") { + return this.handleSocketConnect(request); + } + + return new Response("Not found", { status: 404 }); + } + + async webSocketMessage(ws: WebSocket, message: ArrayBuffer | string): Promise { + const attachment = this.readSocketAttachment(ws); + if (!attachment) { + ws.close(1008, "invalid socket state"); + return; + } + + if (typeof message !== "string") { + ws.close(1003, "expected text message"); + return; + } + + let parsed: { deliveryId?: string; error?: string; type?: string }; + try { + parsed = JSON.parse(message) as { deliveryId?: string; error?: string; type?: string }; + } catch { + ws.close(1003, "invalid message"); + return; + } + + switch (parsed.type) { + case "ack": + if (!parsed.deliveryId) { + ws.close(1008, "deliveryId is required"); + return; + } + await this.ackDelivery({ deliveryId: parsed.deliveryId, tunnelId: attachment.tunnelId }); + await this.sendNextDelivery(ws); + return; + case "fail": + if (!parsed.deliveryId) { + ws.close(1008, "deliveryId is required"); + return; + } + await this.failDelivery({ + deliveryId: parsed.deliveryId, + error: parsed.error ?? "failed", + tunnelId: attachment.tunnelId, + }); + await this.sendNextDelivery(ws); + return; + case "ping": + ws.send(JSON.stringify({ type: "pong" })); + return; + default: + ws.close(1003, "unsupported message"); + } + } + + async webSocketClose(ws: WebSocket): Promise { + await this.resetInFlightDeliveriesForActiveSocket(ws); + } + + async webSocketError(ws: WebSocket): Promise { + await this.resetInFlightDeliveriesForActiveSocket(ws); + } + + private async resetInFlightDeliveriesForActiveSocket(ws: WebSocket): Promise { + const attachment = this.readSocketAttachment(ws); + if (!attachment) { + return; + } + + const activeSessionId = await this.ctx.storage.get( + this.activeSessionKey(attachment.tunnelId), + ); + if (activeSessionId !== attachment.sessionId) { + return; + } + + await this.resetInFlightDeliveries(attachment.tunnelId); + } + + private async handleSocketConnect(request: Request): Promise { + const tunnelId = this.extractTunnelIdFromConnectPath(request.url); + if (!tunnelId) { + return new Response("Tunnel not found", { status: 404 }); + } + + const token = readDeviceToken(request); + if (!token) { + return new Response("Missing bearer token", { status: 401 }); + } + + const deviceTokenHash = await hashToken(token); + const rows = await this.db + .select() + .from(tunnel) + .where(and(eq(tunnel.id, tunnelId), eq(tunnel.deviceTokenHash, deviceTokenHash))) + .limit(1); + const currentTunnel = rows[0]; + + if (!currentTunnel) { + return new Response("Tunnel not found", { status: 404 }); + } + + if (currentTunnel.status !== "active") { + return new Response("Tunnel disabled", { status: 410 }); + } + + const url = new URL(request.url); + const retryWindowMs = Math.max(0, readNumberParam(url.searchParams.get("retryWindowMs"), 0)); + const includeFailedBefore = readOptionalNumberParam( + url.searchParams.get("includeFailedBefore"), + ); + + const pair = new WebSocketPair(); + const [client, server] = Object.values(pair) as [WebSocket, WebSocket]; + const sessionId = crypto.randomUUID(); + + await this.ctx.storage.put(this.activeSessionKey(tunnelId), sessionId); + await this.resetInFlightDeliveries(tunnelId); + this.closeClientSockets(); + this.ctx.acceptWebSocket(server, ["cli"]); + server.serializeAttachment({ + deviceTokenHash, + includeFailedBefore, + replayCompleteSent: false, + retryWindowMs, + role: "cli", + sessionId, + tunnelId, + } satisfies SocketAttachment); + + server.send( + JSON.stringify({ + pendingCount: await this.countDeliverableDeliveries({ + includeFailedBefore, + retryWindowMs, + tunnelId, + }), + tunnelId, + type: "hello", + }), + ); + await this.sendNextDelivery(server); + + return new Response(null, { status: 101, webSocket: client }); + } + + private readSocketAttachment(ws: WebSocket): SocketAttachment | null { + const attachment = ws.deserializeAttachment(); + if (!attachment || typeof attachment !== "object") { + return null; + } + + const socketAttachment = attachment as Partial; + if (socketAttachment.role !== "cli" || typeof socketAttachment.tunnelId !== "string") { + return null; + } + + return { + deviceTokenHash: socketAttachment.deviceTokenHash ?? "", + includeFailedBefore: socketAttachment.includeFailedBefore, + replayCompleteSent: socketAttachment.replayCompleteSent ?? false, + retryWindowMs: socketAttachment.retryWindowMs ?? 0, + role: "cli", + sessionId: socketAttachment.sessionId ?? "", + tunnelId: socketAttachment.tunnelId, + }; + } + + private async countDeliverableDeliveries(params: { + includeFailedBefore?: number; + retryWindowMs: number; + tunnelId: string; + }): Promise { + const rows = await this.db + .select({ count: count() }) + .from(delivery) + .where(buildDeliverableWhere(params)); + return rows[0]?.count ?? 0; + } + + private async hasInFlightDelivery(tunnelId: string): Promise { + const rows = await this.db + .select({ count: count() }) + .from(delivery) + .where( + and( + eq(delivery.tunnelId, tunnelId), + isNull(delivery.deliveredAt), + isNull(delivery.failedAt), + isNotNull(delivery.sentAt), + ), + ); + return (rows[0]?.count ?? 0) > 0; + } + + private async sendNextDelivery(ws: WebSocket): Promise { + const attachment = this.readSocketAttachment(ws); + if (!attachment) { + return; + } + + if (await this.hasInFlightDelivery(attachment.tunnelId)) { + return; + } + + const rows = await this.db + .select() + .from(delivery) + .where( + buildDeliverableWhere({ + includeFailedBefore: attachment.includeFailedBefore, + retryWindowMs: attachment.retryWindowMs, + tunnelId: attachment.tunnelId, + }), + ) + .orderBy(asc(delivery.receivedAt), asc(delivery.id)) + .limit(1); + const nextDelivery = rows[0]; + + if (!nextDelivery) { + if (!attachment.replayCompleteSent) { + attachment.replayCompleteSent = true; + ws.serializeAttachment(attachment); + ws.send(JSON.stringify({ type: "replay_complete" })); + } + return; + } + + const claimed = await this.db + .update(delivery) + .set({ sentAt: now() }) + .where( + and( + eq(delivery.id, nextDelivery.id), + eq(delivery.tunnelId, attachment.tunnelId), + isNull(delivery.deliveredAt), + isNull(delivery.sentAt), + ), + ) + .returning({ id: delivery.id }); + + if (claimed.length !== 1) { + return; + } + + try { + ws.send( + JSON.stringify({ + delivery: { + body: nextDelivery.body, + headers: nextDelivery.headers, + id: nextDelivery.id, + method: nextDelivery.method, + receivedAt: new Date(nextDelivery.receivedAt).toISOString(), + }, + type: "delivery", + }), + ); + } catch (error) { + await this.db.update(delivery).set({ sentAt: null }).where(eq(delivery.id, nextDelivery.id)); + throw new Error("Failed to send delivery", { cause: error }); + } + } + + private async pushToConnectedClient(): Promise { + const ws = this.ctx.getWebSockets("cli")[0]; + if (!ws) { + return; + } + + await this.sendNextDelivery(ws); + } + + private async ackDelivery(params: { deliveryId: string; tunnelId: string }): Promise { + const acked = await this.db + .update(delivery) + .set({ deliveredAt: now(), error: null, failedAt: null, sentAt: null }) + .where( + and( + eq(delivery.id, params.deliveryId), + eq(delivery.tunnelId, params.tunnelId), + isNull(delivery.deliveredAt), + isNull(delivery.failedAt), + isNotNull(delivery.sentAt), + ), + ) + .returning({ id: delivery.id }); + + if (acked.length !== 1) { + return; + } + } + + private async failDelivery(params: { + deliveryId: string; + error: string; + tunnelId: string; + }): Promise { + await this.db + .update(delivery) + .set({ error: params.error, failedAt: now(), sentAt: null }) + .where(and(eq(delivery.id, params.deliveryId), eq(delivery.tunnelId, params.tunnelId))); + } + + private async resetInFlightDeliveries(tunnelId: string): Promise { + await this.db + .update(delivery) + .set({ sentAt: null }) + .where( + and( + eq(delivery.tunnelId, tunnelId), + isNull(delivery.deliveredAt), + isNull(delivery.failedAt), + ), + ); + } + + private closeClientSockets(): void { + for (const socket of this.ctx.getWebSockets("cli")) { + socket.close(REPLACED_SESSION_CLOSE_CODE, REPLACED_SESSION_CLOSE_REASON); + } + } + + private activeSessionKey(tunnelId: string): string { + return `active-session:${tunnelId}`; + } + + private extractTunnelIdFromConnectPath(urlValue: string): string | null { + const segments = new URL(urlValue).pathname.split("/").filter(Boolean); + if ( + segments.length === 4 && + segments[0] === "api" && + segments[1] === "tunnels" && + segments[3] === "connect" + ) { + return segments[2] ?? null; + } + return null; + } +} diff --git a/apps/wh/wrangler.jsonc b/apps/wh/wrangler.jsonc index a3562e8f..1548095a 100644 --- a/apps/wh/wrangler.jsonc +++ b/apps/wh/wrangler.jsonc @@ -3,11 +3,25 @@ "name": "paykit-wh", "main": "src/index.ts", "compatibility_date": "2026-05-03", + "durable_objects": { + "bindings": [ + { + "name": "TUNNEL_OBJECT", + "class_name": "TunnelObject" + } + ] + }, + "migrations": [ + { + "tag": "v1", + "new_sqlite_classes": ["TunnelObject"] + } + ], "d1_databases": [ { "binding": "DB", - "database_name": "paykit-wh", - "database_id": "1d9fdd66-9eb5-43fe-ace2-843abce58bf4" + "database_name": "paykit-wh-enam", + "database_id": "0958661f-897b-45db-8782-16e2bca5aa36" } ], "vars": { diff --git a/packages/paykit/src/cli/commands/listen.ts b/packages/paykit/src/cli/commands/listen.ts index c6495801..125554e7 100644 --- a/packages/paykit/src/cli/commands/listen.ts +++ b/packages/paykit/src/cli/commands/listen.ts @@ -10,13 +10,14 @@ import { getPayKitConfig } from "../utils/get-config"; import { capture } from "../utils/telemetry"; const DEFAULT_CLOUD_BASE_URL = "https://wh.paykit.sh"; -const DEFAULT_URL = "http://localhost:3000"; -const DEFAULT_BATCH_SIZE = 30; const DEFAULT_ERROR_BACKOFF_MS = 2_000; const MAX_ERROR_BACKOFF_MS = 15_000; -const DEFAULT_POLL_INTERVAL_MS = 2_000; const DEFAULT_RETRY_WINDOW = "5m"; +const CLI_VERSION = "0.0.4"; +const STABLE_SOCKET_RESET_MS = 30_000; +const FORWARD_REPLAY_TIMEOUT_MS = 5_000; const REPLAY_HEADER = "x-paykit-cloud-replay"; +const REPLACED_SESSION_CLOSE_CODE = 4001; interface TunnelResponse { found: boolean; @@ -62,11 +63,19 @@ interface ReplayResult { status?: number; } +type DeliveryMode = "direct" | "forward"; + interface DeliveryDetails { eventId?: string; eventType?: string; } +type TunnelServerMessage = + | { pendingCount: number; tunnelId: string; type: "hello" } + | { delivery: DeliveryResponse; type: "delivery" } + | { type: "pong" } + | { type: "replay_complete" }; + interface RelayRuntimeContext { account: TunnelAccountSummary; config: Awaited>; @@ -108,7 +117,7 @@ function parseRetryWindowMs(value: string): number { function normalizeLocalOrigin(url: string): string { const parsed = new URL(url); if (parsed.pathname !== "/" || parsed.search || parsed.hash) { - throw new Error(`--url must be an origin only, received "${url}"`); + throw new Error(`--forward-to must be an origin only, received "${url}"`); } return parsed.origin; @@ -150,7 +159,8 @@ function printReadyBlock( devLogger: ReturnType, params: { account: TunnelAccountSummary; - localWebhookUrl: string; + deliveryMode: DeliveryMode; + localWebhookUrl?: string; webhookSecret?: string; webhookUrl: string; }, @@ -168,7 +178,9 @@ function printReadyBlock( : ""; devLogger.print( - `Webhooks forwarding to ${picocolors.cyan(params.localWebhookUrl)}\n\n` + + (params.deliveryMode === "forward" && params.localWebhookUrl + ? `Webhooks forwarding to ${picocolors.cyan(params.localWebhookUrl)}\n\n` + : "Webhooks forwarding directly to your PayKit instance\n\n") + `${bullet} ${providerLabel} ${accountSummary}\n` + `${bullet} ${endpointLabel} ${params.webhookUrl}\n` + `${bullet} ${secretLabel} ${params.webhookSecret ?? picocolors.dim("(existing secret hidden)")}${reminder}\n` + @@ -250,23 +262,39 @@ async function requestCloud( ): Promise { const headers = new Headers(init.headers); headers.set("authorization", `Bearer ${deviceToken}`); + headers.set("x-paykit-cli-version", CLI_VERSION); if (init.body && !headers.has("content-type")) { headers.set("content-type", "application/json"); } - const cloudBaseUrl = - process.env.PAYKIT_WEBHOOK_API_BASE_URL ?? - process.env.PAYKIT_CLOUD_URL ?? - DEFAULT_CLOUD_BASE_URL; + const cloudBaseUrl = getCloudBaseUrl(); - const response = await fetch(`${cloudBaseUrl}${pathname}`, { - ...init, - headers, - }); + let response: Response; + try { + response = await fetch(`${cloudBaseUrl}${pathname}`, { + ...init, + headers, + }); + } catch (error) { + throw new Error( + `Could not connect to the PayKit webhook server at ${cloudBaseUrl}. Is the worker running?`, + { cause: error }, + ); + } if (!response.ok) { const contentType = response.headers.get("content-type") ?? ""; const body = await response.text(); + if (response.status === 426) { + let message = body || "This paykitjs CLI version is no longer supported."; + try { + const parsed = JSON.parse(body) as { message?: string }; + message = parsed.message ?? message; + } catch { + // Non-JSON upgrade responses can still carry a useful text body. + } + throw new Error(message); + } const message = contentType.includes("text/html") ? `PayKit server request failed (${response.status} ${response.statusText})` : body || `PayKit server request failed (${response.status} ${response.statusText})`; @@ -276,6 +304,183 @@ async function requestCloud( return (await response.json()) as T; } +function getCloudBaseUrl(): string { + return ( + process.env.PAYKIT_CLOUD_URL ?? + process.env.PAYKIT_WEBHOOK_API_BASE_URL ?? + DEFAULT_CLOUD_BASE_URL + ); +} + +function buildTunnelSocketUrl(params: { + deviceToken: string; + includeFailedBefore?: number; + retryWindowMs: number; + tunnelId: string; +}): string { + const cloudUrl = new URL(getCloudBaseUrl()); + cloudUrl.protocol = cloudUrl.protocol === "https:" ? "wss:" : "ws:"; + cloudUrl.pathname = `/api/tunnels/${params.tunnelId}/connect`; + cloudUrl.search = ""; + cloudUrl.searchParams.set("deviceToken", params.deviceToken); + cloudUrl.searchParams.set("cliVersion", CLI_VERSION); + cloudUrl.searchParams.set("retryWindowMs", String(params.retryWindowMs)); + if (typeof params.includeFailedBefore === "number") { + cloudUrl.searchParams.set("includeFailedBefore", String(params.includeFailedBefore)); + } + return cloudUrl.toString(); +} + +async function connectTunnelSocket(params: { + deviceToken: string; + includeFailedBefore?: number; + retryWindowMs: number; + tunnelId: string; +}): Promise { + const socket = new WebSocket( + buildTunnelSocketUrl({ + deviceToken: params.deviceToken, + includeFailedBefore: params.includeFailedBefore, + retryWindowMs: params.retryWindowMs, + tunnelId: params.tunnelId, + }), + ); + + await new Promise((resolve, reject) => { + const onOpen = () => { + cleanup(); + resolve(); + }; + const onError = () => { + cleanup(); + reject(new Error("websocket connection failed")); + }; + const onClose = (event: CloseEvent) => { + cleanup(); + reject(new Error(`websocket closed (${event.code})`)); + }; + const cleanup = () => { + socket.removeEventListener("open", onOpen); + socket.removeEventListener("error", onError); + socket.removeEventListener("close", onClose); + }; + + socket.addEventListener("open", onOpen); + socket.addEventListener("error", onError); + socket.addEventListener("close", onClose); + }); + + return socket; +} + +async function consumeTunnelSocket(params: { + config: Awaited>; + devLogger: ReturnType; + forwardTo?: string; + onReplayComplete: () => void; + socket: WebSocket; +}): Promise<{ code?: number; reason?: string }> { + return new Promise<{ code?: number; reason?: string }>((resolve, reject) => { + let settled = false; + let replayCompleteSeen = false; + let processing = Promise.resolve(); + + const cleanup = () => { + params.socket.removeEventListener("close", onClose); + params.socket.removeEventListener("error", onError); + params.socket.removeEventListener("message", onMessage); + }; + + const settle = (callback: () => void) => { + if (settled) { + return; + } + + settled = true; + cleanup(); + callback(); + }; + + const onClose = (event: CloseEvent) => { + processing.finally(() => + settle(() => resolve({ code: event.code, reason: event.reason || undefined })), + ); + }; + const onError = () => { + processing.finally(() => settle(() => reject(new Error("websocket stream failed")))); + }; + const onMessage = (event: MessageEvent) => { + processing = processing.then(async () => { + const data = typeof event.data === "string" ? event.data : String(event.data); + const message = JSON.parse(data) as TunnelServerMessage; + + switch (message.type) { + case "delivery": { + const result = await deliverWebhook({ + config: params.config, + delivery: message.delivery, + forwardTo: params.forwardTo, + }); + const details = parseDeliveryDetails(message.delivery.body); + const eventId = details.eventId ?? message.delivery.id; + const eventType = details.eventType ?? "unknown"; + + if (!result.ok) { + const statusLabel = result.error ?? String(result.status ?? "failed"); + params.socket.send( + JSON.stringify({ + deliveryId: message.delivery.id, + error: statusLabel, + type: "fail", + }), + ); + params.devLogger.event({ + eventId, + eventType, + replay: !replayCompleteSeen, + status: statusLabel, + }); + return; + } + + params.socket.send(JSON.stringify({ deliveryId: message.delivery.id, type: "ack" })); + params.devLogger.event({ + eventId, + eventType, + replay: !replayCompleteSeen, + status: result.status ?? 200, + }); + return; + } + case "replay_complete": + replayCompleteSeen = true; + params.onReplayComplete(); + return; + case "hello": + case "pong": + return; + default: + throw new Error( + `Unsupported websocket message type: ${(message as { type?: string }).type}`, + ); + } + }); + processing.catch((error) => { + settle(() => reject(error)); + try { + params.socket.close(); + } catch { + // ignore close failures while unwinding the socket loop + } + }); + }; + + params.socket.addEventListener("close", onClose); + params.socket.addEventListener("error", onError); + params.socket.addEventListener("message", onMessage); + }); +} + async function ensureTunnel(params: { account: TunnelAccountSummary; createIfMissing: boolean; @@ -315,32 +520,6 @@ async function ackDelivery(params: { deliveryId: string; deviceToken: string }): }); } -async function pullDeliveries(params: { - deviceToken: string; - includeFailedBefore?: number; - limit: number; - offset?: number; - retryWindowMs: number; - tunnelId: string; -}): Promise { - const search = new URLSearchParams({ - limit: String(params.limit), - retryWindowMs: String(params.retryWindowMs), - }); - if (typeof params.includeFailedBefore === "number") { - search.set("includeFailedBefore", String(params.includeFailedBefore)); - } - if (params.offset) { - search.set("offset", String(params.offset)); - } - - const response = await requestCloud<{ deliveries: DeliveryResponse[] }>( - params.deviceToken, - `/api/tunnels/${params.tunnelId}/pull?${search.toString()}`, - ); - return response.deliveries; -} - async function getDelivery(params: { deliveryId: string; deviceToken: string; @@ -362,12 +541,14 @@ async function failDelivery(params: { async function replayDelivery(params: { delivery: DeliveryResponse; localWebhookUrl: string; + signal?: AbortSignal; }): Promise { try { const response = await fetch(params.localWebhookUrl, { body: params.delivery.body, headers: sanitizeReplayHeaders(params.delivery.headers), method: params.delivery.method, + signal: params.signal, }); return { ok: response.ok, status: response.status }; @@ -376,6 +557,61 @@ async function replayDelivery(params: { } } +async function replayDeliveryWithTimeout(params: { + delivery: DeliveryResponse; + localWebhookUrl: string; + timeoutMs: number; +}): Promise { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), params.timeoutMs); + + try { + const result = await replayDelivery({ + delivery: params.delivery, + localWebhookUrl: params.localWebhookUrl, + signal: controller.signal, + }); + if (!result.ok && controller.signal.aborted) { + return { error: `forward-to timeout: ${params.localWebhookUrl}`, ok: false }; + } + return result; + } finally { + clearTimeout(timeout); + } +} + +async function applyDeliveryDirectly(params: { + config: Awaited>; + delivery: DeliveryResponse; +}): Promise { + try { + await params.config.paykit.handleWebhook({ + allowStaleSignatures: true, + body: params.delivery.body, + headers: params.delivery.headers, + }); + return { ok: true, status: 200 }; + } catch (error) { + return { error: error instanceof Error ? error.message : String(error), ok: false }; + } +} + +async function deliverWebhook(params: { + config: Awaited>; + delivery: DeliveryResponse; + forwardTo?: string; +}): Promise { + if (params.forwardTo) { + return replayDeliveryWithTimeout({ + delivery: params.delivery, + localWebhookUrl: params.forwardTo, + timeoutMs: FORWARD_REPLAY_TIMEOUT_MS, + }); + } + + return applyDeliveryDirectly({ config: params.config, delivery: params.delivery }); +} + async function syncProviderWebhook(params: { deviceToken: string; provider: TunnelCapableProvider; @@ -397,62 +633,14 @@ async function syncProviderWebhook(params: { return { webhookSecret: providerWebhook.webhookSecret }; } -async function processPendingDeliveries(params: { - devLogger: ReturnType; - deliveries: DeliveryResponse[]; - deviceToken: string; - localWebhookUrl: string; - mode: "live" | "replay"; -}): Promise<{ - hadDeliveries: boolean; - processedCount: number; -}> { - const deliveries = params.deliveries; - - if (deliveries.length === 0) { - return { hadDeliveries: false, processedCount: 0 }; - } - - for (const delivery of deliveries) { - const result = await replayDelivery({ delivery, localWebhookUrl: params.localWebhookUrl }); - const details = parseDeliveryDetails(delivery.body); - const eventId = details.eventId ?? delivery.id; - const eventType = details.eventType ?? "unknown"; - - if (!result.ok) { - const statusLabel = result.error ?? String(result.status ?? "failed"); - await failDelivery({ - deliveryId: delivery.id, - deviceToken: params.deviceToken, - error: statusLabel, - }); - - params.devLogger.event({ - eventId, - eventType, - replay: params.mode === "replay", - status: statusLabel, - }); - continue; - } - - params.devLogger.event({ - eventId, - eventType, - replay: params.mode === "replay", - status: result.status ?? 200, - }); - - await ackDelivery({ deliveryId: delivery.id, deviceToken: params.deviceToken }); - } - - return { hadDeliveries: true, processedCount: deliveries.length }; -} - function getNextErrorBackoff(currentMs: number): number { return currentMs === 0 ? DEFAULT_ERROR_BACKOFF_MS : Math.min(currentMs * 2, MAX_ERROR_BACKOFF_MS); } +function isReplacedSessionClose(close: { code?: number; reason?: string }): boolean { + return close.code === REPLACED_SESSION_CLOSE_CODE; +} + async function loadRelayRuntimeContext(params: { configPath?: string; cwd: string; @@ -478,13 +666,12 @@ async function loadRelayRuntimeContext(params: { async function listenAction(options: { config?: string; cwd: string; + forwardTo?: string; retry: string; - url: string; }): Promise { const cwd = path.resolve(options.cwd); capture("cli_command", { command: "listen" }); const devLogger = createDevLogger(); - const localOrigin = normalizeLocalOrigin(options.url); const retryWindowMs = parseRetryWindowMs(options.retry); const relayStartedAt = Date.now(); @@ -509,10 +696,16 @@ async function listenAction(options: { devLogger.update("Ensuring webhook endpoint"); const { webhookSecret } = await syncProviderWebhook({ deviceToken, provider, tunnel }); - const localWebhookUrl = buildLocalWebhookUrl(localOrigin, config.options.basePath ?? "/paykit"); + const localWebhookUrl = options.forwardTo + ? buildLocalWebhookUrl( + normalizeLocalOrigin(options.forwardTo), + config.options.basePath ?? "/paykit", + ) + : undefined; devLogger.stop(); printReadyBlock(devLogger, { account, + deliveryMode: localWebhookUrl ? "forward" : "direct", localWebhookUrl, webhookSecret, webhookUrl: tunnel.webhookUrl, @@ -524,36 +717,49 @@ async function listenAction(options: { ); } - let mode: "live" | "replay" = "replay"; let errorBackoffMs = 0; + let replayCompleteLogged = false; for (;;) { try { - const deliveries = await pullDeliveries({ + const socketConnectedAt = Date.now(); + const socket = await connectTunnelSocket({ deviceToken, - includeFailedBefore: mode === "replay" ? relayStartedAt : undefined, - limit: DEFAULT_BATCH_SIZE, - retryWindowMs: mode === "replay" ? retryWindowMs : 0, + includeFailedBefore: relayStartedAt, + retryWindowMs, tunnelId: tunnel.tunnelId, }); - const result = await processPendingDeliveries({ + const close = await consumeTunnelSocket({ + config, devLogger, - deliveries, - deviceToken, - localWebhookUrl, - mode, + forwardTo: localWebhookUrl, + onReplayComplete: () => { + if (!replayCompleteLogged) { + replayCompleteLogged = true; + devLogger.info("replay complete, listening for new webhooks"); + } + }, + socket, }); - errorBackoffMs = 0; - - if (!result.hadDeliveries && mode === "replay") { - devLogger.info("replay complete, listening for new webhooks"); - mode = "live"; - continue; + if (Date.now() - socketConnectedAt >= STABLE_SOCKET_RESET_MS) { + errorBackoffMs = 0; + } + const closeLabel = close.reason + ? `${String(close.code ?? "unknown")} ${close.reason}` + : String(close.code ?? "unknown"); + + if (isReplacedSessionClose(close)) { + devLogger.warn( + "Another paykitjs listen session connected for this tunnel. Stopping this older session.", + ); + return; } - await sleep(result.processedCount > 0 ? 250 : DEFAULT_POLL_INTERVAL_MS); + devLogger.warn(`Listen connection closed: ${closeLabel}`); + errorBackoffMs = getNextErrorBackoff(errorBackoffMs); + await sleep(errorBackoffMs); } catch (error) { const message = error instanceof Error ? error.message : String(error); devLogger.warn(`Listen loop failed: ${message}`); @@ -563,13 +769,12 @@ async function listenAction(options: { } } -async function enableAction(options: { config?: string; cwd: string; url: string }): Promise { +async function enableAction(options: { config?: string; cwd: string }): Promise { const cwd = path.resolve(options.cwd); capture("cli_command", { command: "listen_enable" }); const devLogger = createDevLogger(); - const localOrigin = normalizeLocalOrigin(options.url); - const { account, config, deviceToken, provider } = await loadRelayRuntimeContext({ + const { account, deviceToken, provider } = await loadRelayRuntimeContext({ configPath: options.config, cwd, devLogger, @@ -589,7 +794,6 @@ async function enableAction(options: { config?: string; cwd: string; url: string devLogger.update("Ensuring webhook endpoint"); const { webhookSecret } = await syncProviderWebhook({ deviceToken, provider, tunnel }); - buildLocalWebhookUrl(localOrigin, config.options.basePath ?? "/paykit"); devLogger.stop(); printEnableSummary(devLogger, { account, @@ -641,24 +845,28 @@ async function retryAction(options: { config?: string; cwd: string; deliveryId: string; - url: string; + forwardTo?: string; }): Promise { const cwd = path.resolve(options.cwd); capture("cli_command", { command: "listen_retry" }); const devLogger = createDevLogger(); - const localOrigin = normalizeLocalOrigin(options.url); const { config, deviceToken } = await loadRelayRuntimeContext({ configPath: options.config, cwd, devLogger, }); - const localWebhookUrl = buildLocalWebhookUrl(localOrigin, config.options.basePath ?? "/paykit"); + const forwardTo = options.forwardTo + ? buildLocalWebhookUrl( + normalizeLocalOrigin(options.forwardTo), + config.options.basePath ?? "/paykit", + ) + : undefined; const delivery = await getDelivery({ deliveryId: options.deliveryId, deviceToken }); devLogger.stop(); const details = parseDeliveryDetails(delivery.body); - const result = await replayDelivery({ delivery, localWebhookUrl }); + const result = await deliverWebhook({ config, delivery, forwardTo }); if (!result.ok) { const statusLabel = result.error ?? String(result.status ?? "failed"); await failDelivery({ deliveryId: delivery.id, deviceToken, error: statusLabel }); @@ -688,25 +896,25 @@ async function retryAction(options: { } function mergeRelaySubcommandOptions< - TOptions extends { config?: string; cwd?: string; retry?: string; url?: string }, + TOptions extends { config?: string; cwd?: string; forwardTo?: string; retry?: string }, >( options: TOptions, command: Command, -): { config?: string; cwd: string; retry?: string; url: string } { +): { config?: string; cwd: string; forwardTo?: string; retry?: string } { const parentOptions = command.parent?.opts() as - | { config?: string; cwd?: string; retry?: string; url?: string } + | { config?: string; cwd?: string; forwardTo?: string; retry?: string } | undefined; return { config: options.config ?? parentOptions?.config, cwd: options.cwd ?? parentOptions?.cwd ?? process.cwd(), + forwardTo: options.forwardTo ?? parentOptions?.forwardTo, retry: options.retry ?? parentOptions?.retry, - url: options.url ?? parentOptions?.url ?? DEFAULT_URL, }; } export const listenCommand = new Command("listen") - .description("Register a provider webhook tunnel, replay missed events, and keep polling") + .description("Register a provider webhook tunnel, replay missed events, and stream new webhooks") .option( "-c, --cwd ", "the working directory. defaults to the current directory.", @@ -718,7 +926,10 @@ export const listenCommand = new Command("listen") "retry failed deliveries received within this window", DEFAULT_RETRY_WINDOW, ) - .option("--url ", "local app origin", DEFAULT_URL) + .option( + "--forward-to ", + "forward webhooks to a local app origin instead of applying directly", + ) .action(listenAction) .addCommand( new Command("enable") @@ -729,7 +940,6 @@ export const listenCommand = new Command("listen") process.cwd(), ) .option("--config ", "the path to the PayKit configuration file to load.") - .option("--url ", "local app origin") .action((options, command) => enableAction(mergeRelaySubcommandOptions(options, command))), ) .addCommand( @@ -742,7 +952,10 @@ export const listenCommand = new Command("listen") process.cwd(), ) .option("--config ", "the path to the PayKit configuration file to load.") - .option("--url ", "local app origin") + .option( + "--forward-to ", + "forward webhook to a local app origin instead of applying directly", + ) .action((deliveryId, options, command) => retryAction({ ...mergeRelaySubcommandOptions(options, command), diff --git a/packages/paykit/src/cli/utils/get-config.ts b/packages/paykit/src/cli/utils/get-config.ts index d1f178cb..c2be1732 100644 --- a/packages/paykit/src/cli/utils/get-config.ts +++ b/packages/paykit/src/cli/utils/get-config.ts @@ -132,19 +132,27 @@ async function loadModule(cwd: string, configPath: string): Promise { return jiti.import(configPath); } -function getPayKit(moduleValue: unknown) { +type ConfiguredPayKit = { + handleWebhook(input: { + allowStaleSignatures?: boolean; + body: string; + headers: Record; + }): Promise; + options: PayKitOptions; +}; + +function getPayKit(moduleValue: unknown): ConfiguredPayKit | null { if (!moduleValue || typeof moduleValue !== "object") return null; const moduleObject = moduleValue as Record; return ( [moduleObject.paykit, moduleObject.default].find( - (value): value is { options: PayKitOptions } => - isPayKitInstance(value) || isPayKitLike(value), + (value): value is ConfiguredPayKit => isPayKitInstance(value) || isPayKitLike(value), ) ?? null ); } -function isPayKitLike(value: unknown): value is { options: PayKitOptions } { +function isPayKitLike(value: unknown): value is ConfiguredPayKit { if (!value || typeof value !== "object") return false; const paykit = value as Record; @@ -187,6 +195,7 @@ async function loadConfiguredPayKit(cwd: string, resolvedPath: string) { return { path: resolvedPath, + paykit, options: paykit.options, }; } diff --git a/packages/paykit/src/core/__tests__/logger.test.ts b/packages/paykit/src/core/__tests__/logger.test.ts index a2d3a489..b0e60bd2 100644 --- a/packages/paykit/src/core/__tests__/logger.test.ts +++ b/packages/paykit/src/core/__tests__/logger.test.ts @@ -49,6 +49,7 @@ async function flushLogger(logger: pino.Logger): Promise { describe("core/logger", () => { afterEach(() => { delete process.env.NODE_ENV; + delete process.env.PAYKIT_CLI; }); it("enables pretty logs for all non-production environments", () => { @@ -65,12 +66,42 @@ describe("core/logger", () => { expect(options.timestamp).toBeTypeOf("function"); expect(getPrettyLoggerOptions()).toEqual({ colorize: true, + colorizeObjects: false, customPrettifiers: { + actionType: expect.any(Function), + duration: expect.any(Function), + event: expect.any(Function), + msg: expect.any(Function), + providerEventId: expect.any(Function), time: expect.any(Function), + traceId: expect.any(Function), }, ignore: "pid,hostname", - levelFirst: true, - translateTime: "SYS:HH:MM:ss.l", + levelFirst: false, + messageFormat: expect.any(Function), + translateTime: "SYS:HH:MM:ss", + }); + }); + + it("hides the logger name in CLI pretty logs", () => { + process.env.PAYKIT_CLI = "1"; + + expect(getPrettyLoggerOptions()).toEqual({ + colorize: true, + colorizeObjects: false, + customPrettifiers: { + actionType: expect.any(Function), + duration: expect.any(Function), + event: expect.any(Function), + msg: expect.any(Function), + providerEventId: expect.any(Function), + time: expect.any(Function), + traceId: expect.any(Function), + }, + ignore: "pid,hostname,name,traceId", + levelFirst: false, + messageFormat: expect.any(Function), + translateTime: "SYS:HH:MM:ss", }); }); diff --git a/packages/paykit/src/core/logger.ts b/packages/paykit/src/core/logger.ts index 6fdf7d86..c8663d38 100644 --- a/packages/paykit/src/core/logger.ts +++ b/packages/paykit/src/core/logger.ts @@ -1,4 +1,5 @@ import { AsyncLocalStorage } from "node:async_hooks"; +import { Transform } from "node:stream"; import pino from "pino"; import pretty from "pino-pretty"; @@ -8,8 +9,37 @@ import { generateId } from "./utils"; const storage = new AsyncLocalStorage(); const PRETTY_LOG_IGNORE_FIELDS = "pid,hostname"; -const PRETTY_LOG_TIMESTAMP = "SYS:HH:MM:ss.l"; +const PRETTY_LOG_TIMESTAMP = "SYS:HH:MM:ss"; const DEFAULT_LOG_LEVEL = "info"; +const dim = (value: unknown) => `\x1b[2m${String(value)}\x1b[0m`; +const plain = (value: unknown) => `\x1b[39m${String(value)}\x1b[39m`; +const DETAIL_LINE_PATTERN = /(^|\n)(\s+[^\n]+)/g; + +/** + * Dims pretty-log detail lines matched by DETAIL_LINE_PATTERN. + * @param input The formatted pretty-log output. + * @returns The input with detail lines dimmed. + */ +function dimDetailLines(input: string): string { + return input.replace(DETAIL_LINE_PATTERN, (_match, prefix: string, line: string) => { + return `${prefix}${dim(line)}`; + }); +} + +/** + * Creates a pretty logger stream that dims detail lines before stdout. + * @returns The configured pretty logger stream. + */ +function createPrettyStream() { + const output = new Transform({ + transform(chunk, _encoding, callback) { + callback(null, dimDetailLines(String(chunk))); + }, + }); + + output.pipe(process.stdout); + return pretty({ ...getPrettyLoggerOptions(), destination: output }); +} export interface PayKitInternalLogger extends pino.Logger { trace: pino.Logger["trace"] & { @@ -42,13 +72,26 @@ export function getDefaultLoggerOptions( } export function getPrettyLoggerOptions(): pretty.PrettyOptions { + const ignore = + process.env.PAYKIT_CLI === "1" + ? `${PRETTY_LOG_IGNORE_FIELDS},name,traceId` + : PRETTY_LOG_IGNORE_FIELDS; + return { colorize: true, - ignore: PRETTY_LOG_IGNORE_FIELDS, - levelFirst: true, + colorizeObjects: false, + ignore, + levelFirst: false, + messageFormat: (log, messageKey) => plain(log[messageKey] ?? ""), translateTime: PRETTY_LOG_TIMESTAMP, customPrettifiers: { - time: (timestamp) => `\x1b[2m${String(timestamp)}\x1b[0m`, + actionType: dim, + duration: dim, + event: dim, + msg: (message) => String(message), + providerEventId: dim, + time: dim, + traceId: dim, }, }; } @@ -60,7 +103,7 @@ export function createPayKitLogger( const base = logging?.logger ?? (shouldUsePrettyLogs(environment) - ? pino(getDefaultLoggerOptions(logging), pretty(getPrettyLoggerOptions())) + ? pino(getDefaultLoggerOptions(logging), createPrettyStream()) : pino(getDefaultLoggerOptions(logging))); const handler: ProxyHandler = { diff --git a/packages/paykit/src/webhook/webhook.service.ts b/packages/paykit/src/webhook/webhook.service.ts index 6337c837..b9ec0064 100644 --- a/packages/paykit/src/webhook/webhook.service.ts +++ b/packages/paykit/src/webhook/webhook.service.ts @@ -136,7 +136,6 @@ async function processWebhookEvent( ctx: PayKitContext, event: AnyNormalizedWebhookEvent, providerEventId: string, - startTime: number, ): Promise { // Record the webhook outside the business transaction so failures are preserved. const shouldProcess = await beginWebhookEvent(ctx, { @@ -182,17 +181,15 @@ async function processWebhookEvent( await emitCustomerUpdated(ctx, customerId); } - const duration = Date.now() - startTime; - ctx.logger.info({ event: event.name, duration }, "webhook processed"); + ctx.logger.info({ event: event.name }, "webhook processed"); await finishWebhookEvent(ctx, { providerEventId, status: "processed", }); } catch (error) { - const duration = Date.now() - startTime; const errorDetail = error instanceof Error ? (error.stack ?? error.message) : String(error); - ctx.logger.error({ event: event.name, duration, err: error }, "webhook failed"); + ctx.logger.error({ event: event.name, err: error }, "webhook failed"); await finishWebhookEvent(ctx, { error: errorDetail, @@ -208,7 +205,6 @@ export async function handleWebhook( input: HandleWebhookInput, ): Promise<{ received: true }> { return ctx.logger.trace.run("wh", async () => { - const startTime = Date.now(); const events = await ctx.provider.handleWebhook({ allowStaleSignatures: input.allowStaleSignatures, body: input.body, @@ -219,7 +215,7 @@ export async function handleWebhook( for (const [index, event] of events.entries()) { const providerEventId = getProviderEventId(event, index, parentEventId); ctx.logger.info({ event: event.name, providerEventId }, "webhook received"); - await processWebhookEvent(ctx, event, providerEventId, startTime); + await processWebhookEvent(ctx, event, providerEventId); } return { received: true }; diff --git a/packages/stripe/src/stripe-provider.ts b/packages/stripe/src/stripe-provider.ts index 51c324a3..0375f5c2 100644 --- a/packages/stripe/src/stripe-provider.ts +++ b/packages/stripe/src/stripe-provider.ts @@ -958,7 +958,7 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): } const tolerance = data.allowStaleSignatures ? Number.POSITIVE_INFINITY : undefined; - const event = client.webhooks.constructEvent( + const event = await client.webhooks.constructEventAsync( data.body, signature, options.webhookSecret,