diff --git a/.github/actions/pnpm_install/action.yaml b/.github/actions/pnpm_install/action.yaml index b8dcc4cd0..3816ed546 100644 --- a/.github/actions/pnpm_install/action.yaml +++ b/.github/actions/pnpm_install/action.yaml @@ -3,7 +3,7 @@ runs: steps: - uses: pnpm/action-setup@v4 with: - version: "10.18.3" + version: "10.19.0" run_install: false - uses: actions/setup-node@v4 diff --git a/apps/desktop/package.json b/apps/desktop/package.json index 8ff057078..1695a1c1a 100644 --- a/apps/desktop/package.json +++ b/apps/desktop/package.json @@ -36,6 +36,7 @@ "@hypr/utils": "workspace:^", "@iconify-icon/react": "^3.0.1", "@lobehub/icons": "^2.43.1", + "@openrouter/ai-sdk-provider": "^1.2.0", "@orama/highlight": "^0.1.9", "@orama/orama": "^3.1.16", "@orama/plugin-qps": "^3.1.16", diff --git a/apps/desktop/src/chat/transport.ts b/apps/desktop/src/chat/transport.ts index e58a439b8..3161f3086 100644 --- a/apps/desktop/src/chat/transport.ts +++ b/apps/desktop/src/chat/transport.ts @@ -1,5 +1,5 @@ -import type { ChatRequestOptions, ChatTransport, LanguageModel, UIMessageChunk } from "ai"; -import { convertToModelMessages, smoothStream, stepCountIs, streamText } from "ai"; +import type { ChatTransport, LanguageModel } from "ai"; +import { convertToModelMessages, Experimental_Agent as Agent, stepCountIs } from "ai"; import { ToolRegistry } from "../contexts/tool"; import type { HyprUIMessage } from "./types"; @@ -7,25 +7,13 @@ import type { HyprUIMessage } from "./types"; export class CustomChatTransport implements ChatTransport { constructor(private registry: ToolRegistry, private model: LanguageModel) {} - async sendMessages( - options: - & { - chatId: string; - messages: HyprUIMessage[]; - abortSignal: AbortSignal | undefined; - } - & { trigger: "submit-message" | "regenerate-message"; messageId: string | undefined } - & ChatRequestOptions, - ): Promise> { + sendMessages: ChatTransport["sendMessages"] = async (options) => { const tools = this.registry.getForTransport(); - const result = streamText({ + const agent = new Agent({ model: this.model, - messages: convertToModelMessages(options.messages), - experimental_transform: smoothStream({ chunking: "word" }), tools, stopWhen: stepCountIs(5), - abortSignal: options.abortSignal, prepareStep: async ({ messages }) => { if (messages.length > 20) { return { messages: messages.slice(-10) }; @@ -35,6 +23,8 @@ export class CustomChatTransport implements ChatTransport { }, }); + const result = agent.stream({ messages: convertToModelMessages(options.messages) }); + return result.toUIMessageStream({ originalMessages: options.messages, messageMetadata: ({ part }) => { @@ -47,9 +37,9 @@ export class CustomChatTransport implements ChatTransport { return error instanceof Error ? error.message : String(error); }, }); - } + }; - async reconnectToStream(): Promise | null> { + reconnectToStream: ChatTransport["reconnectToStream"] = async () => { return null; - } + }; } diff --git a/apps/desktop/src/components/main/body/index.tsx b/apps/desktop/src/components/main/body/index.tsx index 4bf323bbc..ae468c703 100644 --- a/apps/desktop/src/components/main/body/index.tsx +++ b/apps/desktop/src/components/main/body/index.tsx @@ -270,12 +270,21 @@ function TabChatButton() { } export function StandardTabWrapper( - { children, afterBorder }: { children: React.ReactNode; afterBorder?: React.ReactNode }, + { + children, + afterBorder, + floatingButton, + }: { + children: React.ReactNode; + afterBorder?: React.ReactNode; + floatingButton?: React.ReactNode; + }, ) { return (
{children} + {floatingButton}
{afterBorder} diff --git a/apps/desktop/src/components/main/body/sessions/floating/generate.tsx b/apps/desktop/src/components/main/body/sessions/floating/generate.tsx index 4851e2724..58c779aeb 100644 --- a/apps/desktop/src/components/main/body/sessions/floating/generate.tsx +++ b/apps/desktop/src/components/main/body/sessions/floating/generate.tsx @@ -1,20 +1,54 @@ +import { cn } from "@hypr/utils"; import { SparklesIcon } from "lucide-react"; import { useState } from "react"; -import { cn } from "@hypr/utils"; +import { useAITask } from "../../../../../contexts/ai-task"; +import { useLanguageModel } from "../../../../../hooks/useLLMConnection"; import * as persisted from "../../../../../store/tinybase/persisted"; import { FloatingButton } from "./shared"; -export function GenerateButton() { +export function GenerateButton({ sessionId }: { sessionId: string }) { const [showTemplates, setShowTemplates] = useState(false); + const model = useLanguageModel(); + + const taskId = `${sessionId}-enhance`; + + const { generate, status } = useAITask((state) => ({ + generate: state.generate, + status: state.tasks[taskId]?.status ?? "idle", + })); const templates = persisted.UI.useResultTable(persisted.QUERIES.visibleTemplates, persisted.STORE_ID); + const rawMd = persisted.UI.useCell("sessions", sessionId, "raw_md", persisted.STORE_ID); + + const updateEnhancedMd = persisted.UI.useSetPartialRowCallback( + "sessions", + sessionId, + (input: string) => ({ enhanced_md: input }), + [], + persisted.STORE_ID, + ); + + const onRegenerate = async (_templateId: string | null) => { + if (!model) { + return; + } - const onRegenerate = (templateId: string | null) => { - console.log("Regenerate clicked:", templateId); + await generate(taskId, { + model, + taskType: "enhance", + args: { rawMd }, + onComplete: updateEnhancedMd, + }); }; + const isGenerating = status === "generating"; + + if (isGenerating) { + return null; + } + return (
- } - onMouseEnter={() => setShowTemplates(true)} - onMouseLeave={() => setShowTemplates(false)} - onClick={() => { - setShowTemplates(false); - onRegenerate(null); - }} - > - Regenerate - +
+ } + onMouseEnter={() => setShowTemplates(true)} + onMouseLeave={() => setShowTemplates(false)} + onClick={() => { + setShowTemplates(false); + onRegenerate(null); + }} + > + Regenerate + +
); } diff --git a/apps/desktop/src/components/main/body/sessions/floating/index.tsx b/apps/desktop/src/components/main/body/sessions/floating/index.tsx index 626ba99db..c7a07b690 100644 --- a/apps/desktop/src/components/main/body/sessions/floating/index.tsx +++ b/apps/desktop/src/components/main/body/sessions/floating/index.tsx @@ -19,7 +19,7 @@ export function FloatingActionButton({ tab }: { tab: Extract - + ); } else if (tab.state.editor === "transcript") { diff --git a/apps/desktop/src/components/main/body/sessions/index.tsx b/apps/desktop/src/components/main/body/sessions/index.tsx index 890634eb5..882579a46 100644 --- a/apps/desktop/src/components/main/body/sessions/index.tsx +++ b/apps/desktop/src/components/main/body/sessions/index.tsx @@ -53,15 +53,17 @@ export function TabContentNote({ tab }: { tab: Extract - }> -
+ } + floatingButton={} + > +
-
+
-
- -
- +
+
+
diff --git a/apps/desktop/src/components/main/body/sessions/note-input/enhanced.tsx b/apps/desktop/src/components/main/body/sessions/note-input/enhanced.tsx deleted file mode 100644 index 052b12c91..000000000 --- a/apps/desktop/src/components/main/body/sessions/note-input/enhanced.tsx +++ /dev/null @@ -1,35 +0,0 @@ -import NoteEditor, { type TiptapEditor } from "@hypr/tiptap/editor"; -import { forwardRef } from "react"; - -import * as persisted from "../../../../../store/tinybase/persisted"; - -export const EnhancedEditor = forwardRef<{ editor: TiptapEditor | null }, { sessionId: string }>( - ({ sessionId }, ref) => { - const value = persisted.UI.useCell("sessions", sessionId, "enhanced_md", persisted.STORE_ID); - - const handleEnhancedChange = persisted.UI.useSetPartialRowCallback( - "sessions", - sessionId, - (input: string) => ({ enhanced_md: input }), - [], - persisted.STORE_ID, - ); - - return ( - { - return []; - }, - }} - /> - ); - }, -); - -EnhancedEditor.displayName = "EnhancedEditor"; diff --git a/apps/desktop/src/components/main/body/sessions/note-input/enhanced/editor.tsx b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/editor.tsx new file mode 100644 index 000000000..45b998dff --- /dev/null +++ b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/editor.tsx @@ -0,0 +1,36 @@ +import { forwardRef } from "react"; + +import { TiptapEditor } from "@hypr/tiptap/editor"; +import NoteEditor from "@hypr/tiptap/editor"; +import * as persisted from "../../../../../../store/tinybase/persisted"; + +export const EnhancedEditor = forwardRef<{ editor: TiptapEditor | null }, { sessionId: string }>( + ({ sessionId }, ref) => { + const value = persisted.UI.useCell("sessions", sessionId, "enhanced_md", persisted.STORE_ID); + + const handleEnhancedChange = persisted.UI.useSetPartialRowCallback( + "sessions", + sessionId, + (input: string) => ({ enhanced_md: input }), + [], + persisted.STORE_ID, + ); + + return ( +
+ { + return []; + }, + }} + /> +
+ ); + }, +); diff --git a/apps/desktop/src/components/main/body/sessions/note-input/enhanced/index.tsx b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/index.tsx new file mode 100644 index 000000000..7f4b494eb --- /dev/null +++ b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/index.tsx @@ -0,0 +1,28 @@ +import { type TiptapEditor } from "@hypr/tiptap/editor"; +import { forwardRef } from "react"; + +import { useAITask } from "../../../../../../contexts/ai-task"; +import { EnhancedEditor } from "./editor"; +import { StreamingView } from "./streaming"; + +export const Enhanced = forwardRef< + { editor: TiptapEditor | null }, + { sessionId: string } +>(({ sessionId }, ref) => { + const taskId = `${sessionId}-enhance`; + + const { status, error } = useAITask((state) => ({ + status: state.tasks[taskId]?.status ?? "idle", + error: state.tasks[taskId]?.error, + })); + + if (status === "error" && error) { + return
{error.message}
; + } + + if (status === "generating") { + return ; + } + + return ; +}); diff --git a/apps/desktop/src/components/main/body/sessions/note-input/enhanced/streaming.tsx b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/streaming.tsx new file mode 100644 index 000000000..aa429de27 --- /dev/null +++ b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/streaming.tsx @@ -0,0 +1,90 @@ +import { Loader2Icon } from "lucide-react"; +import { motion } from "motion/react"; +import { useEffect, useRef } from "react"; +import { Streamdown } from "streamdown"; + +import { cn } from "@hypr/utils"; +import { useAITask } from "../../../../../../contexts/ai-task"; + +export function StreamingView({ sessionId }: { sessionId: string }) { + const taskId = `${sessionId}-enhance`; + + const { text, step } = useAITask((state) => ({ + text: state.tasks[taskId]?.streamedText ?? "", + step: state.tasks[taskId]?.currentStep, + })); + + const containerRef = useAutoScrollToBottom(text); + + const components = { + h2: (props: React.HTMLAttributes) => { + return

{props.children as React.ReactNode}

; + }, + ul: (props: React.HTMLAttributes) => { + return
    {props.children as React.ReactNode}
; + }, + ol: (props: React.HTMLAttributes) => { + return
    {props.children as React.ReactNode}
; + }, + li: (props: React.HTMLAttributes) => { + return
  • {props.children as React.ReactNode}
  • ; + }, + } as const; + + return ( +
    +
    + + {text} + +
    + + + Generating... ({JSON.stringify(step)}) + +
    + ); +} + +function useAutoScrollToBottom(text: string) { + const containerRef = useRef(null); + + useEffect(() => { + const container = containerRef.current; + if (!container) { + return; + } + + const scrollableParent = container.parentElement; + if (!scrollableParent) { + return; + } + + const { scrollTop, scrollHeight, clientHeight } = scrollableParent; + const isNearBottom = scrollHeight - scrollTop - clientHeight < 100; + + if (isNearBottom) { + scrollableParent.scrollTop = scrollHeight; + } + }, [text]); + + return containerRef; +} diff --git a/apps/desktop/src/components/main/body/sessions/note-input/index.tsx b/apps/desktop/src/components/main/body/sessions/note-input/index.tsx index 951fa37b8..236af01fc 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/index.tsx +++ b/apps/desktop/src/components/main/body/sessions/note-input/index.tsx @@ -1,12 +1,13 @@ -import { useRef } from "react"; +import { useEffect, useRef } from "react"; import type { TiptapEditor } from "@hypr/tiptap/editor"; import { cn } from "@hypr/utils"; +import { useAITask } from "../../../../../contexts/ai-task"; import { useListener } from "../../../../../contexts/listener"; import * as persisted from "../../../../../store/tinybase/persisted"; import { type Tab, useTabs } from "../../../../../store/zustand/tabs"; import { type EditorView } from "../../../../../store/zustand/tabs/schema"; -import { EnhancedEditor } from "./enhanced"; +import { Enhanced } from "./enhanced"; import { RawEditor } from "./raw"; import { Transcript } from "./transcript"; @@ -15,6 +16,10 @@ export function NoteInput({ tab }: { tab: Extract }) const { updateSessionTabState } = useTabs(); const editorRef = useRef<{ editor: TiptapEditor | null }>(null); + const taskId = `${tab.id}-enhance`; + + const taskStatus = useAITask((state) => state.tasks[taskId]?.status ?? "idle"); + const handleTabChange = (view: EditorView) => { updateSessionTabState(tab, { editor: view }); }; @@ -23,6 +28,12 @@ export function NoteInput({ tab }: { tab: Extract }) editorRef.current?.editor?.commands.focus(); }; + useEffect(() => { + if (taskStatus === "generating" && tab.state.editor !== "enhanced") { + updateSessionTabState(tab, { editor: "enhanced" }); + } + }, [taskStatus, tab.state.editor, updateSessionTabState, tab]); + const sessionId = tab.id; const currentTab = tab.state.editor ?? "raw"; @@ -30,7 +41,7 @@ export function NoteInput({ tab }: { tab: Extract })
    - {currentTab === "enhanced" && } + {currentTab === "enhanced" && } {currentTab === "raw" && } {currentTab === "transcript" && }
    @@ -75,19 +86,14 @@ function Header( function useEditorTabs({ sessionId }: { sessionId: string }): EditorView[] { const status = useListener((state) => state.status); - const enhanced = !!persisted.UI.useCell("sessions", sessionId, "enhanced_md", persisted.STORE_ID); const hasTranscript = useHasTranscript(sessionId); if (status === "running_active") { return ["raw", "transcript"]; } - if (enhanced) { - return ["enhanced", "raw", "transcript"]; - } - if (hasTranscript) { - return ["raw", "transcript"]; + return ["enhanced", "raw", "transcript"]; } return ["raw"]; diff --git a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx index 6ef3ad96f..9e8283c97 100644 --- a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx +++ b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx @@ -192,6 +192,7 @@ export const openaiCompatibleListModels = async (baseUrl: string, apiKey: string }[]; }; const removeNonToolModels = data + .filter((model) => !["audio", "image", "code"].some((keyword) => model.id.includes(keyword))) .filter((model) => { if ( Array.isArray(model.architecture?.input_modalities) @@ -204,7 +205,7 @@ export const openaiCompatibleListModels = async (baseUrl: string, apiKey: string return true; } - return model.supported_parameters.includes("tools"); + return ["tools", "tool_choice"].every((parameter) => model.supported_parameters?.includes(parameter)); }); return removeNonToolModels.map((model) => model.id); diff --git a/apps/desktop/src/contexts/ai-task/enhancing.ts b/apps/desktop/src/contexts/ai-task/enhancing.ts new file mode 100644 index 000000000..46dcae775 --- /dev/null +++ b/apps/desktop/src/contexts/ai-task/enhancing.ts @@ -0,0 +1,70 @@ +import { Experimental_Agent as Agent, generateText, type LanguageModel, stepCountIs, Tool, tool } from "ai"; +import { z } from "zod"; + +export function createEnhancingAgent(model: LanguageModel) { + const system = ` + You are an expert at creating structured, comprehensive meeting summaries. + + Format requirements: + - Do not use h1, start with h2(##) + - Use h2 and h3 headers for sections (no deeper than h3) + - Each section should have at least 5 detailed bullet points + + Workflow: + 1. User provides raw meeting content. + 2. You analyze the content and decide the sections to use. (Using analyzeStructure) + 3. You generate a well-formatted markdown summary, following the format requirements. + + IMPORTANT: Your final output MUST be ONLY the markdown summary itself. + Do NOT include any explanations, commentary, or meta-discussion. + Do NOT say things like "Here's the summary" or "I've analyzed". +`.trim(); + + const tools: Record = { + analyzeStructure: tool({ + description: "Analyze raw meeting content to identify key themes, topics, and overall structure", + inputSchema: z.object({ + max_num_sections: z + .number() + .describe(`Maximum number of sections to generate. + Based on the content, decide the number of sections to generate.`), + }), + execute: async ({ max_num_sections }, { messages }) => { + const lastMessage = messages[messages.length - 1]; + const input = typeof lastMessage.content === "string" + ? lastMessage.content + : lastMessage.content.map((part) => part.type === "text" ? part.text : "").join("\n"); + + const { content: output } = await generateText({ + model, + prompt: ` + Analyze this meeting content and suggest appropriate section headings for a comprehensive summary. + The sections should cover the main themes and topics discussed. + Generate around ${max_num_sections} sections based on the content depth. + Give me in bullet points. + + Content: ${input}`, + }); + + return output; + }, + }), + }; + + return new Agent({ + model, + stopWhen: stepCountIs(10), + system, + tools, + prepareStep: async ({ stepNumber }) => { + console.log("prepareStep", stepNumber); + if (stepNumber === 0) { + return { + toolChoice: { type: "tool", toolName: "analyzeStructure" }, + }; + } + + return { toolChoice: "none" }; + }, + }); +} diff --git a/apps/desktop/src/contexts/ai-task/index.tsx b/apps/desktop/src/contexts/ai-task/index.tsx new file mode 100644 index 000000000..6e65074dd --- /dev/null +++ b/apps/desktop/src/contexts/ai-task/index.tsx @@ -0,0 +1,42 @@ +import React, { createContext, useContext, useRef } from "react"; +import { useStore } from "zustand"; +import { useShallow } from "zustand/shallow"; + +import { type AITaskStore, createAITaskStore } from "../../store/zustand/ai-task"; + +const AITaskContext = createContext(null); + +export const AITaskProvider = ({ + children, + store, +}: { + children: React.ReactNode; + store: AITaskStore; +}) => { + const storeRef = useRef(null); + if (!storeRef.current) { + storeRef.current = store; + } + + return ( + + {children} + + ); +}; + +export const useAITask = ( + selector: Parameters< + typeof useStore, T> + >[1], +) => { + const store = useContext(AITaskContext); + + if (!store) { + throw new Error( + "'useAITask' must be used within a 'AITaskProvider'", + ); + } + + return useStore(store, useShallow(selector)); +}; diff --git a/apps/desktop/src/hooks/useLLMConnection.ts b/apps/desktop/src/hooks/useLLMConnection.ts index 3e39014d2..8b89b3fef 100644 --- a/apps/desktop/src/hooks/useLLMConnection.ts +++ b/apps/desktop/src/hooks/useLLMConnection.ts @@ -1,5 +1,7 @@ import { createAnthropic } from "@ai-sdk/anthropic"; +import { createOpenAI } from "@ai-sdk/openai"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; +import { createOpenRouter } from "@openrouter/ai-sdk-provider"; import type { LanguageModel } from "ai"; import { useMemo } from "react"; @@ -22,6 +24,22 @@ export const useLanguageModel = (): LanguageModel | null => { return anthropicProvider(connection.modelId); } + if (connection.providerId === "openrouter") { + const openRouterProvider = createOpenRouter({ + apiKey: connection.apiKey, + }); + + return openRouterProvider(connection.modelId); + } + + if (connection.providerId === "openai") { + const openAIProvider = createOpenAI({ + apiKey: connection.apiKey, + }); + + return openAIProvider(connection.modelId); + } + const openAICompatibleProvider = createOpenAICompatible({ name: connection.providerId, baseURL: connection.baseUrl, diff --git a/apps/desktop/src/main.tsx b/apps/desktop/src/main.tsx index 657ac9c76..58c6786df 100644 --- a/apps/desktop/src/main.tsx +++ b/apps/desktop/src/main.tsx @@ -18,9 +18,11 @@ import { } from "./store/tinybase/persisted"; import { routeTree } from "./routeTree.gen"; +import { createAITaskStore } from "./store/zustand/ai-task"; import { createListenerStore } from "./store/zustand/listener"; const listenerStore = createListenerStore(); +const aiTaskStore = createAITaskStore(); const queryClient = new QueryClient(); const router = createRouter({ routeTree, context: undefined }); @@ -48,6 +50,7 @@ function App() { persistedStore, internalStore, listenerStore, + aiTaskStore, }} /> ); diff --git a/apps/desktop/src/routes/app/main/_layout.tsx b/apps/desktop/src/routes/app/main/_layout.tsx index a33a542ed..b17fa0dff 100644 --- a/apps/desktop/src/routes/app/main/_layout.tsx +++ b/apps/desktop/src/routes/app/main/_layout.tsx @@ -2,6 +2,7 @@ import { createFileRoute, Outlet, useRouteContext } from "@tanstack/react-router import { useCallback, useEffect } from "react"; import { toolFactories } from "../../../chat/tools"; +import { AITaskProvider } from "../../../contexts/ai-task"; import { useSearchEngine } from "../../../contexts/search/engine"; import { SearchEngineProvider } from "../../../contexts/search/engine"; import { SearchUIProvider } from "../../../contexts/search/ui"; @@ -16,7 +17,7 @@ export const Route = createFileRoute("/app/main/_layout")({ }); function Component() { - const { persistedStore, internalStore } = useRouteContext({ from: "__root__" }); + const { persistedStore, internalStore, aiTaskStore } = useRouteContext({ from: "__root__" }); const { registerOnClose, registerOnEmpty, currentTab, openNew, invalidateResource } = useTabs(); const createDefaultSession = useCallback(() => { @@ -52,13 +53,19 @@ function Component() { registerOnEmpty(createDefaultSession); }, [createDefaultSession, registerOnEmpty]); + if (!aiTaskStore) { + return null; + } + return ( - - + + + + diff --git a/apps/desktop/src/store/zustand/ai-task/index.ts b/apps/desktop/src/store/zustand/ai-task/index.ts new file mode 100644 index 000000000..fae581f25 --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/index.ts @@ -0,0 +1,15 @@ +import { createStore } from "zustand"; + +import { createTasksSlice, type TasksActions, type TasksState } from "./tasks"; + +type State = TasksState; +type Actions = TasksActions; +type Store = State & Actions; + +export type AITaskStore = ReturnType; + +export const createAITaskStore = () => { + return createStore((set, get) => ({ + ...createTasksSlice(set, get), + })); +}; diff --git a/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.test.ts b/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.test.ts new file mode 100644 index 000000000..59cd32a6b --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.test.ts @@ -0,0 +1,279 @@ +import type { TextStreamPart, ToolSet } from "ai"; +import { beforeEach, describe, expect, it } from "vitest"; + +import { trimBeforeMarker } from "./transform_impl"; + +function convertArrayToReadableStream(values: T[]): ReadableStream { + return new ReadableStream({ + start(controller) { + for (const value of values) { + controller.enqueue(value); + } + controller.close(); + }, + }); +} + +describe("trimBeforeMarker", () => { + let events: any[] = []; + + beforeEach(() => { + events = []; + }); + + async function consumeStream(stream: ReadableStream) { + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + events.push(value); + } + } + + it("should trim text before marker (##)", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "ok. I will give you that! ", type: "text-delta", id: "1" }, + { text: "## Header", type: "text-delta", id: "1" }, + { text: " content", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "## Header", + "type": "text-delta", + }, + { + "id": "1", + "text": " content", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should handle marker split across chunks", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "ok. I will give you that! #", type: "text-delta", id: "1" }, + { text: "# Header", type: "text-delta", id: "1" }, + { text: " content", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "## Header", + "type": "text-delta", + }, + { + "id": "1", + "text": " content", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should trim before single # when looking for ##", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "# Wrong header\n", type: "text-delta", id: "1" }, + { text: "## Correct header", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "## Correct header", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should handle marker at the very start", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "## Header", type: "text-delta", id: "1" }, + { text: " content", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "## Header", + "type": "text-delta", + }, + { + "id": "1", + "text": " content", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should send all buffered chunks if marker is never found", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "No marker here", type: "text-delta", id: "1" }, + { text: " at all", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "No marker here", + "type": "text-delta", + }, + { + "id": "1", + "text": " at all", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should handle non-text-delta chunks before marker is found", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "prefix ", type: "text-delta", id: "1" }, + { + type: "tool-call", + toolCallId: "1", + toolName: "test", + input: {}, + }, + { text: "## Header", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough(trimBeforeMarker("##")({ tools: {}, stopStream: () => {} })); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "input": {}, + "toolCallId": "1", + "toolName": "test", + "type": "tool-call", + }, + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "## Header", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); + + it("should work with custom markers", async () => { + const stream = convertArrayToReadableStream>([ + { type: "text-start", id: "1" }, + { text: "Some intro text... ", type: "text-delta", id: "1" }, + { text: "START:", type: "text-delta", id: "1" }, + { text: " actual content", type: "text-delta", id: "1" }, + { type: "text-end", id: "1" }, + ]).pipeThrough( + trimBeforeMarker("START:")({ tools: {}, stopStream: () => {} }), + ); + + await consumeStream(stream); + + expect(events).toMatchInlineSnapshot(` + [ + { + "id": "1", + "type": "text-start", + }, + { + "id": "1", + "text": "START:", + "type": "text-delta", + }, + { + "id": "1", + "text": " actual content", + "type": "text-delta", + }, + { + "id": "1", + "type": "text-end", + }, + ] + `); + }); +}); diff --git a/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.ts b/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.ts new file mode 100644 index 000000000..561da65f8 --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/shared/transform_impl.ts @@ -0,0 +1,70 @@ +import type { TextStreamPart, ToolSet } from "ai"; +import type { StreamTransform } from "./transform_infra"; + +export function trimBeforeMarker( + marker: string, +): StreamTransform { + return () => { + let fullText = ""; + let hasFoundMarker = false; + let bufferedChunks: TextStreamPart[] = []; + + return new TransformStream, TextStreamPart>({ + transform(chunk, controller) { + if ( + chunk.type === "tool-call" + || chunk.type === "tool-result" + || chunk.type === "tool-error" + || chunk.type === "tool-input-start" + || chunk.type === "tool-input-delta" + || chunk.type === "tool-input-end" + || chunk.type === "start-step" + || chunk.type === "finish-step" + ) { + controller.enqueue(chunk); + return; + } + + if (!hasFoundMarker) { + if (chunk.type === "text-delta") { + fullText += chunk.text; + } + + bufferedChunks.push(chunk); + + if (chunk.type === "text-delta") { + const markerIndex = fullText.indexOf(marker); + if (markerIndex !== -1) { + hasFoundMarker = true; + const trimmedText = fullText.substring(markerIndex); + + for (const buffered of bufferedChunks) { + if (buffered.type === "text-delta") { + controller.enqueue({ + ...buffered, + text: trimmedText, + }); + break; + } else { + controller.enqueue(buffered); + } + } + + bufferedChunks = []; + } + } + } else { + controller.enqueue(chunk); + } + }, + + flush(controller) { + if (!hasFoundMarker) { + for (const chunk of bufferedChunks) { + controller.enqueue(chunk); + } + } + }, + }); + }; +} diff --git a/apps/desktop/src/store/zustand/ai-task/shared/transform_infra.ts b/apps/desktop/src/store/zustand/ai-task/shared/transform_infra.ts new file mode 100644 index 000000000..f80efe3c8 --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/shared/transform_infra.ts @@ -0,0 +1,68 @@ +// https://github.com/vercel/ai/blob/282f062922cb59167dd3a11e3af67cfa0b75f317/packages/ai/src/generate-text/stream-text.ts + +import type { TextStreamPart, ToolSet } from "ai"; + +export type StreamTransform = (options: { + tools: TOOLS; + stopStream: () => void; +}) => TransformStream, TextStreamPart>; + +export async function* applyTransforms( + stream: AsyncIterable>, + transforms: StreamTransform[], + options: { + tools?: TOOLS; + stopStream?: () => void; + } = {}, +): AsyncIterable> { + if (transforms.length === 0) { + return yield* stream; + } + + const stopStream = options.stopStream ?? (() => {}); + const tools = options.tools ?? ({} as TOOLS); + + let readableStream = streamToReadable(stream); + + for (const transform of transforms) { + readableStream = readableStream.pipeThrough( + transform({ tools, stopStream }), + ); + } + + yield* streamToAsyncIterable(readableStream); +} + +function streamToReadable( + stream: AsyncIterable, +): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of stream) { + controller.enqueue(chunk); + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); +} + +async function* streamToAsyncIterable( + stream: ReadableStream, +): AsyncIterable { + const reader = stream.getReader(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + yield value; + } + } finally { + reader.releaseLock(); + } +} diff --git a/apps/desktop/src/store/zustand/ai-task/task-configs.ts b/apps/desktop/src/store/zustand/ai-task/task-configs.ts new file mode 100644 index 000000000..651578523 --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/task-configs.ts @@ -0,0 +1,25 @@ +import { type Experimental_Agent as Agent, type LanguageModel, smoothStream } from "ai"; + +import { createEnhancingAgent } from "../../../contexts/ai-task/enhancing"; +import { trimBeforeMarker } from "./shared/transform_impl"; + +export type TaskType = "enhance"; + +export interface TaskConfig { + getPrompt: (args?: Record) => string; + getAgent?: (model: LanguageModel) => Agent; + transforms?: any[]; +} + +export const TASK_CONFIGS: Record = { + enhance: { + getPrompt: () => { + return "Generate some random meeting summary, following markdown format. Start with h2 header(##) and no more than h3. Each header should have more than 5 points, bullet points."; + }, + getAgent: (model) => createEnhancingAgent(model), + transforms: [ + trimBeforeMarker("##"), + smoothStream({ delayInMs: 100, chunking: "line" }), + ], + }, +}; diff --git a/apps/desktop/src/store/zustand/ai-task/tasks.ts b/apps/desktop/src/store/zustand/ai-task/tasks.ts new file mode 100644 index 000000000..276940f91 --- /dev/null +++ b/apps/desktop/src/store/zustand/ai-task/tasks.ts @@ -0,0 +1,205 @@ +import { Experimental_Agent as Agent, type LanguageModel, stepCountIs } from "ai"; +import { create as mutate } from "mutative"; +import type { StoreApi } from "zustand"; + +import { applyTransforms } from "./shared/transform_infra"; +import { TASK_CONFIGS, type TaskType } from "./task-configs"; + +export type TasksState = { + tasks: Record; +}; + +export type TasksActions = { + generate: ( + taskId: string, + config: { + model: LanguageModel; + taskType: TaskType; + args?: Record; + onComplete?: (text: string) => void; + }, + ) => Promise; + cancel: (taskId: string) => void; + getState: (taskId: string) => TaskState; +}; + +type StepInfo = + | { type: "generating" } + | { type: "tool-call" | "tool-result"; toolName: string }; + +type TaskState = { + status: "idle" | "generating" | "success" | "error"; + streamedText: string; + error?: Error; + abortController: AbortController | null; + currentStep?: StepInfo; +}; + +const initialState: TasksState = { + tasks: {}, +}; + +export const createTasksSlice = ( + set: StoreApi["setState"], + get: StoreApi["getState"], +): TasksState & TasksActions => ({ + ...initialState, + getState: (taskId: string) => { + const state = get().tasks[taskId]; + return { + status: state?.status ?? "idle", + streamedText: state?.streamedText ?? "", + error: state?.error, + abortController: state?.abortController ?? null, + currentStep: state?.currentStep, + }; + }, + cancel: (taskId: string) => { + const state = get().tasks[taskId]; + if (state?.abortController) { + state.abortController.abort(); + } + }, + generate: async ( + taskId: string, + config: { + model: LanguageModel; + taskType: TaskType; + args?: Record; + onComplete?: (text: string) => void; + }, + ) => { + const abortController = new AbortController(); + const taskConfig = TASK_CONFIGS[config.taskType]; + const prompt = taskConfig.getPrompt(config.args); + + set((state) => + mutate(state, (draft) => { + draft.tasks[taskId] = { + status: "generating", + streamedText: "", + error: undefined, + abortController, + currentStep: undefined, + }; + }) + ); + + try { + const agent = getAgentForTask(config.taskType, config.model); + const result = agent.stream({ prompt }); + + let fullText = ""; + + const checkAbort = () => { + if (abortController.signal.aborted) { + const error = new Error("Aborted"); + error.name = "AbortError"; + throw error; + } + }; + + const transforms = taskConfig.transforms ?? []; + const transformedStream = applyTransforms(result.fullStream, transforms, { + tools: result.toolCalls, + stopStream: () => abortController.abort(), + }); + + for await (const chunk of transformedStream) { + checkAbort(); + + if (chunk.type === "text-delta") { + fullText += chunk.text; + + set((state) => + mutate(state, (draft) => { + const currentState = draft.tasks[taskId]; + if (currentState) { + currentState.streamedText = fullText; + currentState.currentStep = { type: "generating" }; + } + }) + ); + } else if (chunk.type === "tool-call") { + set((state) => + mutate(state, (draft) => { + const currentState = draft.tasks[taskId]; + if (currentState) { + currentState.currentStep = { + type: "tool-call", + toolName: chunk.toolName, + }; + } + }) + ); + } else if (chunk.type === "tool-result") { + set((state) => + mutate(state, (draft) => { + const currentState = draft.tasks[taskId]; + if (currentState) { + currentState.currentStep = { + type: "tool-result", + toolName: chunk.toolName, + }; + } + }) + ); + } + } + + set((state) => + mutate(state, (draft) => { + draft.tasks[taskId] = { + status: "success", + streamedText: fullText, + error: undefined, + abortController: null, + currentStep: undefined, + }; + }) + ); + + config.onComplete?.(fullText); + } catch (err) { + if (err instanceof Error && (err.name === "AbortError" || err.message === "Aborted")) { + set((state) => + mutate(state, (draft) => { + draft.tasks[taskId] = { + status: "idle", + streamedText: "", + error: undefined, + abortController: null, + currentStep: undefined, + }; + }) + ); + } else { + const error = err instanceof Error ? err : new Error(String(err)); + set((state) => + mutate(state, (draft) => { + draft.tasks[taskId] = { + status: "error", + streamedText: "", + error, + abortController: null, + currentStep: undefined, + }; + }) + ); + } + } + }, +}); + +function getAgentForTask(taskType: TaskType, model: LanguageModel) { + const taskConfig = TASK_CONFIGS[taskType]; + + if (taskConfig.getAgent) { + return taskConfig.getAgent(model); + } + + return new Agent({ + model, + stopWhen: stepCountIs(10), + }); +} diff --git a/apps/desktop/src/types/index.ts b/apps/desktop/src/types/index.ts index 56a924a04..2c7b23ecd 100644 --- a/apps/desktop/src/types/index.ts +++ b/apps/desktop/src/types/index.ts @@ -1,10 +1,12 @@ import { type Store as InternalStore } from "../store/tinybase/internal"; import { type Store as PersistedStore } from "../store/tinybase/persisted"; +import type { AITaskStore } from "../store/zustand/ai-task"; import type { ListenerStore } from "../store/zustand/listener"; export type Context = { persistedStore: PersistedStore; internalStore: InternalStore; listenerStore: ListenerStore; + aiTaskStore: AITaskStore; }; diff --git a/package.json b/package.json index 95f178301..257c3b0c5 100644 --- a/package.json +++ b/package.json @@ -1,8 +1,8 @@ { - "packageManager": "pnpm@10.18.3", + "packageManager": "pnpm@10.19.0", "devDependencies": { - "turbo": "^2.5.8", - "esbuild": "0.25.11" + "esbuild": "0.25.11", + "turbo": "^2.5.8" }, "pnpm": { "overrides": { diff --git a/packages/tiptap/package.json b/packages/tiptap/package.json index 4c94ce19a..982c44dbb 100644 --- a/packages/tiptap/package.json +++ b/packages/tiptap/package.json @@ -36,6 +36,7 @@ "@tiptap/extension-text": "^3.7.2", "@tiptap/extension-typography": "^3.7.2", "@tiptap/extension-underline": "^3.7.2", + "@tiptap/markdown": "^3.7.2", "@tiptap/pm": "^3.7.2", "@tiptap/react": "^3.7.2", "@tiptap/starter-kit": "^3.7.2", diff --git a/packages/tiptap/src/editor/index.tsx b/packages/tiptap/src/editor/index.tsx index b0785b1de..a90f4fe08 100644 --- a/packages/tiptap/src/editor/index.tsx +++ b/packages/tiptap/src/editor/index.tsx @@ -1,6 +1,7 @@ import "../styles/tiptap.css"; import "../styles/mention.css"; +import { Markdown } from "@tiptap/markdown"; import { type Editor as TiptapEditor, EditorContent, type HTMLContent, useEditor } from "@tiptap/react"; import { forwardRef, useEffect, useRef } from "react"; @@ -45,9 +46,11 @@ const Editor = forwardRef<{ editor: TiptapEditor | null }, EditorProps>( extensions: [ ...shared.getExtensions(placeholderComponent), mention(mentionConfig), + Markdown, ], editable, - content: initialContent || "

    ", + contentType: "markdown", + content: initialContent || "", onCreate: ({ editor }) => { editor.view.dom.setAttribute("spellcheck", "false"); editor.view.dom.setAttribute("autocomplete", "off"); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c8ae917c1..89f2de07e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -86,6 +86,9 @@ importers: '@lobehub/icons': specifier: ^2.43.1 version: 2.43.1(@babel/core@7.28.4)(@types/mdast@4.0.4)(@types/react@19.2.2)(antd@5.27.5(date-fns@4.1.0)(react-dom@19.2.0(react@19.2.0))(react@19.2.0))(framer-motion@11.18.2(react-dom@19.2.0(react@19.2.0))(react@19.2.0))(micromark-util-types@2.0.2)(micromark@4.0.2)(react-dom@19.2.0(react@19.2.0))(react@19.2.0) + '@openrouter/ai-sdk-provider': + specifier: ^1.2.0 + version: 1.2.0(ai@5.0.76(zod@4.1.12))(zod@4.1.12) '@orama/highlight': specifier: ^0.1.9 version: 0.1.9 @@ -514,6 +517,9 @@ importers: '@tiptap/extension-underline': specifier: ^3.7.2 version: 3.7.2(@tiptap/core@3.7.2(@tiptap/pm@3.7.2)) + '@tiptap/markdown': + specifier: ^3.7.2 + version: 3.7.2(@tiptap/core@3.7.2(@tiptap/pm@3.7.2))(@tiptap/pm@3.7.2) '@tiptap/pm': specifier: ^3.7.2 version: 3.7.2 @@ -1977,6 +1983,13 @@ packages: resolution: {integrity: sha512-T8TbSnGsxo6TDBJx/Sgv/BlVJL3tshxZP7Aq5R1mSnM5OcHY2dQaxLMu2+E8u3gN0MLOzdjurqN4ZRVuzQycOQ==} engines: {node: '>=8.0'} + '@openrouter/ai-sdk-provider@1.2.0': + resolution: {integrity: sha512-stuIwq7Yb7DNmk3GuCtz+oS3nZOY4TXEV3V5KsknDGQN7Fpu3KRMQVWRc1J073xKdf0FC9EHOctSyzsACmp5Ag==} + engines: {node: '>=18'} + peerDependencies: + ai: ^5.0.0 + zod: ^3.24.1 || ^v4 + '@opentelemetry/api-logs@0.204.0': resolution: {integrity: sha512-DqxY8yoAaiBPivoJD4UtgrMS8gEmzZ5lnaxzPojzLVHBGqPxgWm4zcuvcUHZiqQ6kRX2Klel2r9y8cA2HAtqpw==} engines: {node: '>=8.0.0'} @@ -3937,6 +3950,12 @@ packages: '@tiptap/core': ^3.7.2 '@tiptap/pm': ^3.7.2 + '@tiptap/markdown@3.7.2': + resolution: {integrity: sha512-0cdCYYHdBDXcwjZsTOSySbdHQuHZct6nxvcp4dSVpP25kbZL3ONSJvLY5Nsy3rkXlmhk9qbyFwsexGiSIdFy8Q==} + peerDependencies: + '@tiptap/core': ^3.7.2 + '@tiptap/pm': ^3.7.2 + '@tiptap/pm@3.7.2': resolution: {integrity: sha512-i2fvXDapwo/TWfHM6STYEbkYyF3qyfN6KEBKPrleX/Z80G5bLxom0gB79TsjLNxTLi6mdf0vTHgAcXMG1avc2g==} @@ -10288,6 +10307,11 @@ snapshots: '@oozcitak/util@8.3.8': {} + '@openrouter/ai-sdk-provider@1.2.0(ai@5.0.76(zod@4.1.12))(zod@4.1.12)': + dependencies: + ai: 5.0.76(zod@4.1.12) + zod: 4.1.12 + '@opentelemetry/api-logs@0.204.0': dependencies: '@opentelemetry/api': 1.9.0 @@ -12657,6 +12681,12 @@ snapshots: '@tiptap/core': 3.7.2(@tiptap/pm@3.7.2) '@tiptap/pm': 3.7.2 + '@tiptap/markdown@3.7.2(@tiptap/core@3.7.2(@tiptap/pm@3.7.2))(@tiptap/pm@3.7.2)': + dependencies: + '@tiptap/core': 3.7.2(@tiptap/pm@3.7.2) + '@tiptap/pm': 3.7.2 + marked: 16.4.1 + '@tiptap/pm@3.7.2': dependencies: prosemirror-changeset: 2.3.1