diff --git a/packages/mcp-client/src/Agent.ts b/packages/mcp-client/src/Agent.ts index 823769a00e..63a9e42b14 100644 --- a/packages/mcp-client/src/Agent.ts +++ b/packages/mcp-client/src/Agent.ts @@ -46,6 +46,7 @@ const exitLoopTools = [taskCompletionTool, askQuestionTool]; export class Agent extends McpClient { private readonly servers: (ServerConfig | StdioServerParameters)[]; + public readonly prompt: string; protected messages: ChatCompletionInputMessage[]; constructor({ @@ -73,10 +74,11 @@ export class Agent extends McpClient { super(provider ? { provider, endpointUrl, model, apiKey } : { provider, endpointUrl, model, apiKey }); /// ^This shenanigan is just here to please an overzealous TS type-checker. this.servers = servers; + this.prompt = prompt ?? DEFAULT_SYSTEM_PROMPT; this.messages = [ { role: "system", - content: prompt ?? DEFAULT_SYSTEM_PROMPT, + content: this.prompt, }, ]; } @@ -86,19 +88,27 @@ export class Agent extends McpClient { } async *run( - input: string, + input: string | ChatCompletionInputMessage[], opts: { abortSignal?: AbortSignal } = {} ): AsyncGenerator { - this.messages.push({ - role: "user", - content: input, - }); + let messages: ChatCompletionInputMessage[]; + if (typeof input === "string") { + /// Use internal array of messages + this.messages.push({ + role: "user", + content: input, + }); + messages = this.messages; + } else { + /// Use the passed messages directly + messages = input; + } let numOfTurns = 0; let nextTurnShouldCallTools = true; while (true) { try { - yield* this.processSingleTurnWithTools(this.messages, { + yield* this.processSingleTurnWithTools(messages, { exitLoopTools, exitIfFirstChunkNoTool: numOfTurns > 0 && nextTurnShouldCallTools, abortSignal: opts.abortSignal, @@ -111,7 +121,7 @@ export class Agent extends McpClient { } numOfTurns++; // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const currentLast = this.messages.at(-1)!; + const currentLast = messages.at(-1)!; debug("current role", currentLast.role); if ( currentLast.role === "tool" && diff --git a/packages/tiny-agents/package.json b/packages/tiny-agents/package.json index 29a9942b1d..01312259bd 100644 --- a/packages/tiny-agents/package.json +++ b/packages/tiny-agents/package.json @@ -34,7 +34,8 @@ "prepare": "pnpm run build", "test": "vitest run", "check": "tsc", - "cli": "tsx src/cli.ts" + "cli": "tsx src/cli.ts", + "cli:watch": "tsx watch src/cli.ts" }, "files": [ "src", diff --git a/packages/tiny-agents/src/cli.ts b/packages/tiny-agents/src/cli.ts index 2a670ee8d7..2f9d503e58 100644 --- a/packages/tiny-agents/src/cli.ts +++ b/packages/tiny-agents/src/cli.ts @@ -7,6 +7,7 @@ import { version as packageVersion } from "../package.json"; import { ServerConfigSchema } from "./lib/types"; import { debug, error } from "./lib/utils"; import { mainCliLoop } from "./lib/mainCliLoop"; +import { startServer } from "./lib/webServer"; import { loadConfigFrom } from "./lib/loadConfigFrom"; const USAGE_HELP = ` @@ -104,13 +105,13 @@ async function main() { } ); - if (command === "serve") { - error(`Serve is not implemented yet, coming soon!`); - process.exit(1); + debug(agent); + await agent.loadTools(); + + if (command === "run") { + mainCliLoop(agent); } else { - debug(agent); - // main loop from mcp-client/cli.ts - await mainCliLoop(agent); + startServer(agent); } } diff --git a/packages/tiny-agents/src/example.ts b/packages/tiny-agents/src/example.ts new file mode 100644 index 0000000000..1dd044124c --- /dev/null +++ b/packages/tiny-agents/src/example.ts @@ -0,0 +1,18 @@ +import { chatCompletionStream } from "@huggingface/inference"; + +async function main() { + const endpointUrl = `http://localhost:9999/v1/chat/completions`; + // launch "tiny-agents serve" before running this + + for await (const chunk of chatCompletionStream({ + endpointUrl, + model: "", + messages: [{ role: "user", content: "What are the top 5 trending models on Hugging Face?" }], + })) { + console.log(chunk.choices[0]?.delta.content); + } +} + +if (require.main === module) { + main(); +} diff --git a/packages/tiny-agents/src/lib/mainCliLoop.ts b/packages/tiny-agents/src/lib/mainCliLoop.ts index cdc20056ce..0cd35d3aad 100644 --- a/packages/tiny-agents/src/lib/mainCliLoop.ts +++ b/packages/tiny-agents/src/lib/mainCliLoop.ts @@ -5,7 +5,8 @@ import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; import type { Agent } from "../index"; /** - * From mcp-client/cli.ts + * From mcp-client/cli.ts, + * minus the agent.loadTools() done upstream. */ export async function mainCliLoop(agent: Agent): Promise { const rl = readline.createInterface({ input: stdin, output: stdout }); @@ -40,8 +41,6 @@ export async function mainCliLoop(agent: Agent): Promise { throw err; }); - await agent.loadTools(); - stdout.write(ANSI.BLUE); stdout.write(`Agent loaded with ${agent.availableTools.length} tools:\n`); stdout.write(agent.availableTools.map((t) => `- ${t.function.name}`).join("\n")); diff --git a/packages/tiny-agents/src/lib/webServer.ts b/packages/tiny-agents/src/lib/webServer.ts new file mode 100644 index 0000000000..21cf17a53c --- /dev/null +++ b/packages/tiny-agents/src/lib/webServer.ts @@ -0,0 +1,134 @@ +import type { IncomingMessage } from "node:http"; +import { createServer, ServerResponse } from "node:http"; +import type { AddressInfo } from "node:net"; +import { z } from "zod"; +import type { Agent } from "../index"; +import { ANSI } from "./utils"; +import { stdout } from "node:process"; +import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; + +const REQUEST_ID_HEADER = "X-Request-Id"; + +const ChatCompletionInputSchema = z.object({ + messages: z.array( + z.object({ + role: z.enum(["user", "assistant"]), + content: z.string().or( + z.array( + z + .object({ + type: z.literal("text"), + text: z.string(), + }) + .or( + z.object({ + type: z.literal("image_url"), + image_url: z.object({ + url: z.string(), + }), + }) + ) + ) + ), + }) + ), + /// Only allow stream: true + stream: z.literal(true), +}); +function getJsonBody(req: IncomingMessage) { + return new Promise((resolve, reject) => { + let data = ""; + req.on("data", (chunk) => (data += chunk)); + req.on("end", () => { + try { + resolve(JSON.parse(data)); + } catch (e) { + reject(e); + } + }); + req.on("error", reject); + }); +} +class ServerResp extends ServerResponse { + error(statusCode: number, reason: string) { + this.writeHead(statusCode).end(JSON.stringify({ error: reason })); + } +} + +export function startServer(agent: Agent): void { + const server = createServer({ ServerResponse: ServerResp }, async (req, res) => { + res.setHeader(REQUEST_ID_HEADER, crypto.randomUUID()); + res.setHeader("Content-Type", "application/json"); + if (req.method === "POST" && req.url === "/v1/chat/completions") { + let body: unknown; + let requestBody: z.infer; + try { + body = await getJsonBody(req); + } catch { + return res.error(400, "Invalid JSON"); + } + try { + requestBody = ChatCompletionInputSchema.parse(body); + } catch (err) { + if (err instanceof z.ZodError) { + return res.error(400, "Invalid ChatCompletionInput body \n" + JSON.stringify(err)); + } + return res.error(400, "Invalid ChatCompletionInput body"); + } + /// Ok, from now on we will send a SSE (Server-Sent Events) response. + res.setHeaders( + new Headers({ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }) + ); + + /// Prepend the agent's prompt + const messages = [ + { + role: "system", + content: agent.prompt, + }, + ...requestBody.messages, + ]; + + for await (const chunk of agent.run(messages)) { + if ("choices" in chunk) { + res.write(`data: ${JSON.stringify(chunk)}\n\n`); + } else { + /// Tool call info + /// /!\ We format it as a regular chunk of role = "tool" + const chunkToolcallInfo = { + choices: [ + { + index: 0, + delta: { + role: "tool", + content: `Tool[${chunk.name}] ${chunk.tool_call_id}\n` + chunk.content, + }, + }, + ], + created: Math.floor(Date.now() / 1000), + id: chunk.tool_call_id, + model: "", + system_fingerprint: "", + } satisfies ChatCompletionStreamOutput; + + res.write(`data: ${JSON.stringify(chunkToolcallInfo)}\n\n`); + } + } + res.end(); + } else { + res.error(404, "Route or method not found, try POST /v1/chat/completions"); + } + }); + server.listen(process.env.PORT ? parseInt(process.env.PORT) : 9_999, () => { + stdout.write(ANSI.BLUE); + stdout.write(`Agent loaded with ${agent.availableTools.length} tools:\n`); + stdout.write(agent.availableTools.map((t) => `- ${t.function.name}`).join("\n")); + stdout.write(ANSI.RESET); + stdout.write("\n"); + console.log(ANSI.GRAY + `listening on http://localhost:${(server.address() as AddressInfo).port}` + ANSI.RESET); + }); +}