diff --git a/Cargo.lock b/Cargo.lock index 7c8a52b11..2525043f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4634,6 +4634,8 @@ dependencies = [ "gbnf-validator", "indoc", "insta", + "serde", + "specta", ] [[package]] @@ -14375,6 +14377,7 @@ dependencies = [ name = "tauri-plugin-template" version = "0.1.0" dependencies = [ + "gbnf", "insta", "serde_json", "specta", diff --git a/apps/desktop/src/components/editor-area/index.tsx b/apps/desktop/src/components/editor-area/index.tsx index e5a6909b4..8fa50a771 100644 --- a/apps/desktop/src/components/editor-area/index.tsx +++ b/apps/desktop/src/components/editor-area/index.tsx @@ -12,7 +12,7 @@ import { commands as analyticsCommands } from "@hypr/plugin-analytics"; import { commands as connectorCommands } from "@hypr/plugin-connector"; import { commands as dbCommands } from "@hypr/plugin-db"; import { commands as miscCommands } from "@hypr/plugin-misc"; -import { commands as templateCommands } from "@hypr/plugin-template"; +import { commands as templateCommands, type Grammar } from "@hypr/plugin-template"; import Editor, { type TiptapEditor } from "@hypr/tiptap/editor"; import Renderer from "@hypr/tiptap/renderer"; import { extractHashtags } from "@hypr/tiptap/shared"; @@ -282,48 +282,28 @@ export function useEnhanceMutation({ return; } - // Get current config for default template const config = await dbCommands.getConfig(); - // Use provided templateId or fall back to config - const effectiveTemplateId = templateId !== undefined - ? templateId - : config.general?.selected_template_id; + const getTemplate = async () => { + const effectiveTemplateId = templateId !== undefined + ? templateId + : config.general?.selected_template_id; - let templateInfo = ""; - let customGrammar: string | null = null; + if (!effectiveTemplateId) { + return null; + } - if (effectiveTemplateId) { const templates = await dbCommands.listTemplates(); - const selectedTemplate = templates.find(t => t.id === effectiveTemplateId); - - if (selectedTemplate) { - if (selectedTemplate.sections && selectedTemplate.sections.length > 0) { - customGrammar = generateCustomGBNF(selectedTemplate.sections); - } - - templateInfo = ` -SELECTED TEMPLATE: -Template Title: ${selectedTemplate.title || "Untitled"} -Template Description: ${selectedTemplate.description || "No description"} + return templates.find(t => t.id === effectiveTemplateId) || null; + }; -Sections:`; - - selectedTemplate.sections?.forEach((section, index) => { - templateInfo += ` - ${index + 1}. ${section.title || "Untitled Section"} - └─ ${section.description || "No description"}`; - }); - } - } else { - console.log("Using default template (no custom template selected)"); - } + const selectedTemplate = await getTemplate(); const participants = await dbCommands.sessionListParticipants(sessionId); const systemMessage = await templateCommands.render( "enhance.system", - { config, type, templateInfo }, + { config, type, templateInfo: selectedTemplate }, ); const userMessage = await templateCommands.render( @@ -375,14 +355,12 @@ Sections:`; ...(freshIsLocalLlm && { providerOptions: { [localProviderName]: { - metadata: customGrammar - ? { - grammar: "custom", - customGrammar: customGrammar, - } - : { - grammar: "enhance", - }, + metadata: { + grammar: { + task: "enhance", + sections: selectedTemplate?.sections.map(s => s.title) || null, + } satisfies Grammar, + }, }, }, }), @@ -478,7 +456,9 @@ function useGenerateTitleMutation({ sessionId }: { sessionId: string }) { providerOptions: { [localProviderName]: { metadata: { - grammar: "title", + grammar: { + task: "title", + } satisfies Grammar, }, }, }, @@ -536,55 +516,3 @@ function useAutoEnhance({ prevOngoingSessionStatus, ]); } - -// function to dynamically generate the grammar for the custom template -function generateCustomGBNF(templateSections: any[]): string { - if (!templateSections || templateSections.length === 0) { - return ""; - } - - function escapeForGBNF(text: string): string { - return text - .replace(/\\/g, "\\\\") - .replace(/"/g, "\\\"") - .replace(/\n/g, "\\n") - .replace(/\r/g, "\\r") - .replace(/\t/g, "\\t"); - } - - const validatedSections = templateSections.map((section, index) => { - let title = section.title || `Section ${index + 1}`; - - title = title - .trim() - .replace(/[\x00-\x1F\x7F]/g, "") - .substring(0, 100); - - return { - ...section, - safeTitle: title || `Section ${index + 1}`, - }; - }); - - const sectionRules = validatedSections.map((section, index) => { - const sectionName = `section${index + 1}`; - const escapedHeader = escapeForGBNF(section.safeTitle); - return `${sectionName} ::= "# ${escapedHeader}\\n\\n" bline bline bline? bline? bline? "\\n"`; - }).join("\n"); - - const sectionNames = validatedSections.map((_, index) => `section${index + 1}`).join(" "); - - const grammar = `root ::= thinking ${sectionNames} - -${sectionRules} - -bline ::= "- **" [^*\\n:]+ "**: " ([^*;,[.\\n] | link)+ ".\\n" - -hsf ::= "- Objective\\n" -hd ::= "- " [A-Z] [^[(*\\n]+ "\\n" -thinking ::= "\\n" hsf hd hd? hd? hd? "" - -link ::= "[" [^\\]]+ "]" "(" [^)]+ ")"`; - - return grammar; -} diff --git a/crates/gbnf/Cargo.toml b/crates/gbnf/Cargo.toml index 6952abaff..df78bbc64 100644 --- a/crates/gbnf/Cargo.toml +++ b/crates/gbnf/Cargo.toml @@ -3,6 +3,10 @@ name = "gbnf" version = "0.1.0" edition = "2021" +[dependencies] +serde = { workspace = true, features = ["derive"] } +specta = { workspace = true, features = ["derive"] } + [dev-dependencies] gbnf-validator = { workspace = true } diff --git a/crates/gbnf/assets/enhance-hypr.gbnf b/crates/gbnf/assets/enhance-hypr.gbnf deleted file mode 100644 index f81d9fd58..000000000 --- a/crates/gbnf/assets/enhance-hypr.gbnf +++ /dev/null @@ -1,5 +0,0 @@ -root ::= think content - -line ::= "- " [A-Z] [^*.\n[(]+ ".\n" -think ::= "\n" line line? line? line? "" -content ::= .* diff --git a/crates/gbnf/assets/enhance-other.gbnf b/crates/gbnf/assets/enhance-other.gbnf deleted file mode 100644 index 22ffc80ef..000000000 --- a/crates/gbnf/assets/enhance-other.gbnf +++ /dev/null @@ -1,14 +0,0 @@ -root ::= thinking sectionf section section section? section? - -sectionf ::= "# Objective\n\n" line line? line? "\n" -section ::= header "\n\n" bline bline bline? bline? bline? "\n" -header ::= "# " [^*.\n]+ - -line ::= "- " [A-Z] [^*.\n[(]+ ".\n" -bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n" - -hsf ::= "- Objective\n" -hd ::= "- " [A-Z] [^[(*\n]+ "\n" -thinking ::= "\n" hsf hd hd? hd? hd? "" - -link ::= "[" [^\]]+ "]" "(" [^)]+ ")" \ No newline at end of file diff --git a/crates/gbnf/assets/tags.gbnf b/crates/gbnf/assets/tags.gbnf deleted file mode 100644 index b17423cc1..000000000 --- a/crates/gbnf/assets/tags.gbnf +++ /dev/null @@ -1,3 +0,0 @@ -root ::= "[" "\"" word "\"" ("," ws "\"" word "\"")* "]" -word ::= [a-zA-Z0-9_-]+ -ws ::= " "* \ No newline at end of file diff --git a/crates/gbnf/assets/title.gbnf b/crates/gbnf/assets/title.gbnf deleted file mode 100644 index 4c9ece904..000000000 --- a/crates/gbnf/assets/title.gbnf +++ /dev/null @@ -1,3 +0,0 @@ -char ::= [A-Za-z0-9] -start ::= [A-Z0-9] -root ::= start char* (" " char+)* diff --git a/crates/gbnf/src/lib.rs b/crates/gbnf/src/lib.rs index d404335fa..12c1411a3 100644 --- a/crates/gbnf/src/lib.rs +++ b/crates/gbnf/src/lib.rs @@ -1,26 +1,102 @@ -pub const ENHANCE_OTHER: &str = include_str!("../assets/enhance-other.gbnf"); -pub const ENHANCE_HYPR: &str = include_str!("../assets/enhance-hypr.gbnf"); -pub const TITLE: &str = include_str!("../assets/title.gbnf"); -pub const TAGS: &str = include_str!("../assets/tags.gbnf"); - -pub enum GBNF { - EnhanceOther, - EnhanceHypr, +#[derive(specta::Type, serde::Serialize, serde::Deserialize)] +#[serde(tag = "task")] +pub enum Grammar { + #[serde(rename = "enhance")] + Enhance { sections: Option> }, + #[serde(rename = "title")] Title, + #[serde(rename = "tags")] Tags, } -impl GBNF { +impl Grammar { pub fn build(&self) -> String { match self { - GBNF::EnhanceOther => ENHANCE_OTHER.to_string(), - GBNF::EnhanceHypr => ENHANCE_HYPR.to_string(), - GBNF::Title => TITLE.to_string(), - GBNF::Tags => TAGS.to_string(), + Grammar::Enhance { sections } => build_enhance_other_grammar(sections), + Grammar::Title => build_title_grammar(), + Grammar::Tags => build_tags_grammar(), } } } +#[allow(dead_code)] +fn build_enhance_hypr_grammar(_s: &Option>) -> String { + vec![ + r##"root ::= think content"##, + r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##, + r##"think ::= "\n" line line? line? line? """##, + r##"content ::= .*"##, + ] + .join("\n") +} + +fn build_known_sections_grammar(sections: &[String]) -> String { + let mut rules = vec![]; + + let mut root_parts = vec![]; + for i in 0..sections.len() { + root_parts.push(format!("section{}", i)); + } + + let root_rule = format!("root ::= {}", root_parts.join(" ")); + rules.push(root_rule); + + for (i, section) in sections.iter().enumerate() { + let section_rule = format!( + r##"section{} ::= "# {}\n\n" bline bline bline? bline? bline? "\n""##, + i, section + ); + rules.push(section_rule); + } + + rules + .push(r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##.to_string()); + rules.push(r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##.to_string()); + + rules.join("\n") +} + +fn build_enhance_other_grammar(s: &Option>) -> String { + let auto = vec![ + r##"root ::= thinking sectionf section section section? section?"##, + r##"sectionf ::= "# Objective\n\n" line line? line? "\n""##, + r##"section ::= header "\n\n" bline bline bline? bline? bline? "\n""##, + r##"header ::= "# " [^*.\n]+"##, + r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##, + r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##, + r##"hsf ::= "- Objective\n""##, + r##"hd ::= "- " [A-Z] [^[(*\n]+ "\n""##, + r##"thinking ::= "\n" hsf hd hd? hd? hd? """##, + r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##, + ] + .join("\n"); + + match s { + None => auto, + Some(v) if v.is_empty() => auto, + Some(v) => build_known_sections_grammar(v), + } +} + +fn build_title_grammar() -> String { + vec![ + r##"lowercase ::= [a-z0-9]"##, + r##"uppercase ::= [A-Z]"##, + r##"word ::= uppercase lowercase*"##, + r##"root ::= word (" " word)*"##, + ] + .join("\n") +} + +fn build_tags_grammar() -> String { + vec![ + r##"root ::= \"[\" \"\" word \"\" (\",\" ws \"\" word \"\")* \"]\""##, + r##"word ::= [a-zA-Z0-9_-]+"##, + r##"ws ::= \" \"*"##, + ] + .join("\n") +} + #[cfg(test)] mod tests { use super::*; @@ -40,7 +116,7 @@ mod tests { ("Meeting-Summary", false), ("", false), ] { - let result = gbnf.validate(TITLE, input).unwrap(); + let result = gbnf.validate(&build_title_grammar(), input).unwrap(); assert_eq!(result, expected, "failed: {}", input); } } @@ -53,7 +129,7 @@ mod tests { ("[\"meeting\", \"summary\"]", true), ("[\"meeting\", \"summary\", \"\"]", false), ] { - let result = gbnf.validate(TAGS, input).unwrap(); + let result = gbnf.validate(&build_tags_grammar(), input).unwrap(); assert_eq!(result, expected, "failed: {}", input); } } diff --git a/crates/llama/src/lib.rs b/crates/llama/src/lib.rs index 0ce59bd98..3f10588d3 100644 --- a/crates/llama/src/lib.rs +++ b/crates/llama/src/lib.rs @@ -391,7 +391,7 @@ mod tests { let llama = get_model(); let request = LlamaRequest { - grammar: Some(hypr_gbnf::GBNF::EnhanceOther.build()), + grammar: Some(hypr_gbnf::Grammar::Enhance { sections: None }.build()), messages: vec![ LlamaChatMessage::new( "system".into(), diff --git a/plugins/local-llm/src/server.rs b/plugins/local-llm/src/server.rs index b0b05f37a..8ef8f9de9 100644 --- a/plugins/local-llm/src/server.rs +++ b/plugins/local-llm/src/server.rs @@ -267,7 +267,12 @@ fn build_response( .map(hypr_llama::FromOpenAI::from_openai) .collect(); - let grammar = select_grammar(&request, &model.name); + let grammar = request + .metadata + .as_ref() + .and_then(|v| v.get("grammar")) + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .map(|g| g.build()); let request = hypr_llama::LlamaRequest { messages, grammar }; @@ -323,29 +328,3 @@ fn build_mock_response() -> Pin StreamEvent::Content(chunk) })) } - -fn select_grammar( - request: &CreateChatCompletionRequest, - model_name: &hypr_llama::ModelName, -) -> Option { - let grammar = request - .metadata - .as_ref() - .and_then(|v| v.get("grammar").and_then(|v| v.as_str())); - - let custom_grammar = request - .metadata - .as_ref() - .and_then(|v| v.get("customGrammar").and_then(|v| v.as_str())); - - match grammar { - Some("enhance") => match model_name { - hypr_llama::ModelName::HyprLLM => Some(hypr_gbnf::GBNF::EnhanceHypr.build()), - _ => Some(hypr_gbnf::GBNF::EnhanceOther.build()), - }, - Some("title") => Some(hypr_gbnf::GBNF::Title.build()), - Some("tags") => Some(hypr_gbnf::GBNF::Tags.build()), - Some("custom") => custom_grammar.map(|s| s.to_string()), - _ => None, - } -} diff --git a/plugins/template/Cargo.toml b/plugins/template/Cargo.toml index d2eecac4c..faa625c3e 100644 --- a/plugins/template/Cargo.toml +++ b/plugins/template/Cargo.toml @@ -15,7 +15,9 @@ insta = { workspace = true } specta-typescript = { workspace = true } [dependencies] +hypr-gbnf = { workspace = true } hypr-template = { workspace = true } + serde_json = { workspace = true } tracing = { workspace = true } diff --git a/plugins/template/js/bindings.gen.ts b/plugins/template/js/bindings.gen.ts index c880599b4..6cb4a5883 100644 --- a/plugins/template/js/bindings.gen.ts +++ b/plugins/template/js/bindings.gen.ts @@ -25,6 +25,7 @@ async registerTemplate(name: string, template: string) : Promise { /** user-defined types **/ +export type Grammar = { task: "enhance"; sections: string[] | null } | { task: "title" } | { task: "tags" } export type JsonValue = null | boolean | number | string | JsonValue[] | Partial<{ [key in string]: JsonValue }> /** tauri-specta globals **/ diff --git a/plugins/template/src/lib.rs b/plugins/template/src/lib.rs index 7e85e42c0..d77f4e1c2 100644 --- a/plugins/template/src/lib.rs +++ b/plugins/template/src/lib.rs @@ -29,6 +29,7 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::render::, commands::register_template::, ]) + .typ::() .error_handling(tauri_specta::ErrorHandlingMode::Throw) }