diff --git a/Cargo.lock b/Cargo.lock
index 7c8a52b11..2525043f7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4634,6 +4634,8 @@ dependencies = [
"gbnf-validator",
"indoc",
"insta",
+ "serde",
+ "specta",
]
[[package]]
@@ -14375,6 +14377,7 @@ dependencies = [
name = "tauri-plugin-template"
version = "0.1.0"
dependencies = [
+ "gbnf",
"insta",
"serde_json",
"specta",
diff --git a/apps/desktop/src/components/editor-area/index.tsx b/apps/desktop/src/components/editor-area/index.tsx
index e5a6909b4..8fa50a771 100644
--- a/apps/desktop/src/components/editor-area/index.tsx
+++ b/apps/desktop/src/components/editor-area/index.tsx
@@ -12,7 +12,7 @@ import { commands as analyticsCommands } from "@hypr/plugin-analytics";
import { commands as connectorCommands } from "@hypr/plugin-connector";
import { commands as dbCommands } from "@hypr/plugin-db";
import { commands as miscCommands } from "@hypr/plugin-misc";
-import { commands as templateCommands } from "@hypr/plugin-template";
+import { commands as templateCommands, type Grammar } from "@hypr/plugin-template";
import Editor, { type TiptapEditor } from "@hypr/tiptap/editor";
import Renderer from "@hypr/tiptap/renderer";
import { extractHashtags } from "@hypr/tiptap/shared";
@@ -282,48 +282,28 @@ export function useEnhanceMutation({
return;
}
- // Get current config for default template
const config = await dbCommands.getConfig();
- // Use provided templateId or fall back to config
- const effectiveTemplateId = templateId !== undefined
- ? templateId
- : config.general?.selected_template_id;
+ const getTemplate = async () => {
+ const effectiveTemplateId = templateId !== undefined
+ ? templateId
+ : config.general?.selected_template_id;
- let templateInfo = "";
- let customGrammar: string | null = null;
+ if (!effectiveTemplateId) {
+ return null;
+ }
- if (effectiveTemplateId) {
const templates = await dbCommands.listTemplates();
- const selectedTemplate = templates.find(t => t.id === effectiveTemplateId);
-
- if (selectedTemplate) {
- if (selectedTemplate.sections && selectedTemplate.sections.length > 0) {
- customGrammar = generateCustomGBNF(selectedTemplate.sections);
- }
-
- templateInfo = `
-SELECTED TEMPLATE:
-Template Title: ${selectedTemplate.title || "Untitled"}
-Template Description: ${selectedTemplate.description || "No description"}
+ return templates.find(t => t.id === effectiveTemplateId) || null;
+ };
-Sections:`;
-
- selectedTemplate.sections?.forEach((section, index) => {
- templateInfo += `
- ${index + 1}. ${section.title || "Untitled Section"}
- └─ ${section.description || "No description"}`;
- });
- }
- } else {
- console.log("Using default template (no custom template selected)");
- }
+ const selectedTemplate = await getTemplate();
const participants = await dbCommands.sessionListParticipants(sessionId);
const systemMessage = await templateCommands.render(
"enhance.system",
- { config, type, templateInfo },
+ { config, type, templateInfo: selectedTemplate },
);
const userMessage = await templateCommands.render(
@@ -375,14 +355,12 @@ Sections:`;
...(freshIsLocalLlm && {
providerOptions: {
[localProviderName]: {
- metadata: customGrammar
- ? {
- grammar: "custom",
- customGrammar: customGrammar,
- }
- : {
- grammar: "enhance",
- },
+ metadata: {
+ grammar: {
+ task: "enhance",
+ sections: selectedTemplate?.sections.map(s => s.title) || null,
+ } satisfies Grammar,
+ },
},
},
}),
@@ -478,7 +456,9 @@ function useGenerateTitleMutation({ sessionId }: { sessionId: string }) {
providerOptions: {
[localProviderName]: {
metadata: {
- grammar: "title",
+ grammar: {
+ task: "title",
+ } satisfies Grammar,
},
},
},
@@ -536,55 +516,3 @@ function useAutoEnhance({
prevOngoingSessionStatus,
]);
}
-
-// function to dynamically generate the grammar for the custom template
-function generateCustomGBNF(templateSections: any[]): string {
- if (!templateSections || templateSections.length === 0) {
- return "";
- }
-
- function escapeForGBNF(text: string): string {
- return text
- .replace(/\\/g, "\\\\")
- .replace(/"/g, "\\\"")
- .replace(/\n/g, "\\n")
- .replace(/\r/g, "\\r")
- .replace(/\t/g, "\\t");
- }
-
- const validatedSections = templateSections.map((section, index) => {
- let title = section.title || `Section ${index + 1}`;
-
- title = title
- .trim()
- .replace(/[\x00-\x1F\x7F]/g, "")
- .substring(0, 100);
-
- return {
- ...section,
- safeTitle: title || `Section ${index + 1}`,
- };
- });
-
- const sectionRules = validatedSections.map((section, index) => {
- const sectionName = `section${index + 1}`;
- const escapedHeader = escapeForGBNF(section.safeTitle);
- return `${sectionName} ::= "# ${escapedHeader}\\n\\n" bline bline bline? bline? bline? "\\n"`;
- }).join("\n");
-
- const sectionNames = validatedSections.map((_, index) => `section${index + 1}`).join(" ");
-
- const grammar = `root ::= thinking ${sectionNames}
-
-${sectionRules}
-
-bline ::= "- **" [^*\\n:]+ "**: " ([^*;,[.\\n] | link)+ ".\\n"
-
-hsf ::= "- Objective\\n"
-hd ::= "- " [A-Z] [^[(*\\n]+ "\\n"
-thinking ::= "\\n" hsf hd hd? hd? hd? ""
-
-link ::= "[" [^\\]]+ "]" "(" [^)]+ ")"`;
-
- return grammar;
-}
diff --git a/crates/gbnf/Cargo.toml b/crates/gbnf/Cargo.toml
index 6952abaff..df78bbc64 100644
--- a/crates/gbnf/Cargo.toml
+++ b/crates/gbnf/Cargo.toml
@@ -3,6 +3,10 @@ name = "gbnf"
version = "0.1.0"
edition = "2021"
+[dependencies]
+serde = { workspace = true, features = ["derive"] }
+specta = { workspace = true, features = ["derive"] }
+
[dev-dependencies]
gbnf-validator = { workspace = true }
diff --git a/crates/gbnf/assets/enhance-hypr.gbnf b/crates/gbnf/assets/enhance-hypr.gbnf
deleted file mode 100644
index f81d9fd58..000000000
--- a/crates/gbnf/assets/enhance-hypr.gbnf
+++ /dev/null
@@ -1,5 +0,0 @@
-root ::= think content
-
-line ::= "- " [A-Z] [^*.\n[(]+ ".\n"
-think ::= "\n" line line? line? line? ""
-content ::= .*
diff --git a/crates/gbnf/assets/enhance-other.gbnf b/crates/gbnf/assets/enhance-other.gbnf
deleted file mode 100644
index 22ffc80ef..000000000
--- a/crates/gbnf/assets/enhance-other.gbnf
+++ /dev/null
@@ -1,14 +0,0 @@
-root ::= thinking sectionf section section section? section?
-
-sectionf ::= "# Objective\n\n" line line? line? "\n"
-section ::= header "\n\n" bline bline bline? bline? bline? "\n"
-header ::= "# " [^*.\n]+
-
-line ::= "- " [A-Z] [^*.\n[(]+ ".\n"
-bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n"
-
-hsf ::= "- Objective\n"
-hd ::= "- " [A-Z] [^[(*\n]+ "\n"
-thinking ::= "\n" hsf hd hd? hd? hd? ""
-
-link ::= "[" [^\]]+ "]" "(" [^)]+ ")"
\ No newline at end of file
diff --git a/crates/gbnf/assets/tags.gbnf b/crates/gbnf/assets/tags.gbnf
deleted file mode 100644
index b17423cc1..000000000
--- a/crates/gbnf/assets/tags.gbnf
+++ /dev/null
@@ -1,3 +0,0 @@
-root ::= "[" "\"" word "\"" ("," ws "\"" word "\"")* "]"
-word ::= [a-zA-Z0-9_-]+
-ws ::= " "*
\ No newline at end of file
diff --git a/crates/gbnf/assets/title.gbnf b/crates/gbnf/assets/title.gbnf
deleted file mode 100644
index 4c9ece904..000000000
--- a/crates/gbnf/assets/title.gbnf
+++ /dev/null
@@ -1,3 +0,0 @@
-char ::= [A-Za-z0-9]
-start ::= [A-Z0-9]
-root ::= start char* (" " char+)*
diff --git a/crates/gbnf/src/lib.rs b/crates/gbnf/src/lib.rs
index d404335fa..12c1411a3 100644
--- a/crates/gbnf/src/lib.rs
+++ b/crates/gbnf/src/lib.rs
@@ -1,26 +1,102 @@
-pub const ENHANCE_OTHER: &str = include_str!("../assets/enhance-other.gbnf");
-pub const ENHANCE_HYPR: &str = include_str!("../assets/enhance-hypr.gbnf");
-pub const TITLE: &str = include_str!("../assets/title.gbnf");
-pub const TAGS: &str = include_str!("../assets/tags.gbnf");
-
-pub enum GBNF {
- EnhanceOther,
- EnhanceHypr,
+#[derive(specta::Type, serde::Serialize, serde::Deserialize)]
+#[serde(tag = "task")]
+pub enum Grammar {
+ #[serde(rename = "enhance")]
+ Enhance { sections: Option> },
+ #[serde(rename = "title")]
Title,
+ #[serde(rename = "tags")]
Tags,
}
-impl GBNF {
+impl Grammar {
pub fn build(&self) -> String {
match self {
- GBNF::EnhanceOther => ENHANCE_OTHER.to_string(),
- GBNF::EnhanceHypr => ENHANCE_HYPR.to_string(),
- GBNF::Title => TITLE.to_string(),
- GBNF::Tags => TAGS.to_string(),
+ Grammar::Enhance { sections } => build_enhance_other_grammar(sections),
+ Grammar::Title => build_title_grammar(),
+ Grammar::Tags => build_tags_grammar(),
}
}
}
+#[allow(dead_code)]
+fn build_enhance_hypr_grammar(_s: &Option>) -> String {
+ vec![
+ r##"root ::= think content"##,
+ r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##,
+ r##"think ::= "\n" line line? line? line? """##,
+ r##"content ::= .*"##,
+ ]
+ .join("\n")
+}
+
+fn build_known_sections_grammar(sections: &[String]) -> String {
+ let mut rules = vec![];
+
+ let mut root_parts = vec![];
+ for i in 0..sections.len() {
+ root_parts.push(format!("section{}", i));
+ }
+
+ let root_rule = format!("root ::= {}", root_parts.join(" "));
+ rules.push(root_rule);
+
+ for (i, section) in sections.iter().enumerate() {
+ let section_rule = format!(
+ r##"section{} ::= "# {}\n\n" bline bline bline? bline? bline? "\n""##,
+ i, section
+ );
+ rules.push(section_rule);
+ }
+
+ rules
+ .push(r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##.to_string());
+ rules.push(r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##.to_string());
+
+ rules.join("\n")
+}
+
+fn build_enhance_other_grammar(s: &Option>) -> String {
+ let auto = vec![
+ r##"root ::= thinking sectionf section section section? section?"##,
+ r##"sectionf ::= "# Objective\n\n" line line? line? "\n""##,
+ r##"section ::= header "\n\n" bline bline bline? bline? bline? "\n""##,
+ r##"header ::= "# " [^*.\n]+"##,
+ r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##,
+ r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##,
+ r##"hsf ::= "- Objective\n""##,
+ r##"hd ::= "- " [A-Z] [^[(*\n]+ "\n""##,
+ r##"thinking ::= "\n" hsf hd hd? hd? hd? """##,
+ r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##,
+ ]
+ .join("\n");
+
+ match s {
+ None => auto,
+ Some(v) if v.is_empty() => auto,
+ Some(v) => build_known_sections_grammar(v),
+ }
+}
+
+fn build_title_grammar() -> String {
+ vec![
+ r##"lowercase ::= [a-z0-9]"##,
+ r##"uppercase ::= [A-Z]"##,
+ r##"word ::= uppercase lowercase*"##,
+ r##"root ::= word (" " word)*"##,
+ ]
+ .join("\n")
+}
+
+fn build_tags_grammar() -> String {
+ vec![
+ r##"root ::= \"[\" \"\" word \"\" (\",\" ws \"\" word \"\")* \"]\""##,
+ r##"word ::= [a-zA-Z0-9_-]+"##,
+ r##"ws ::= \" \"*"##,
+ ]
+ .join("\n")
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -40,7 +116,7 @@ mod tests {
("Meeting-Summary", false),
("", false),
] {
- let result = gbnf.validate(TITLE, input).unwrap();
+ let result = gbnf.validate(&build_title_grammar(), input).unwrap();
assert_eq!(result, expected, "failed: {}", input);
}
}
@@ -53,7 +129,7 @@ mod tests {
("[\"meeting\", \"summary\"]", true),
("[\"meeting\", \"summary\", \"\"]", false),
] {
- let result = gbnf.validate(TAGS, input).unwrap();
+ let result = gbnf.validate(&build_tags_grammar(), input).unwrap();
assert_eq!(result, expected, "failed: {}", input);
}
}
diff --git a/crates/llama/src/lib.rs b/crates/llama/src/lib.rs
index 0ce59bd98..3f10588d3 100644
--- a/crates/llama/src/lib.rs
+++ b/crates/llama/src/lib.rs
@@ -391,7 +391,7 @@ mod tests {
let llama = get_model();
let request = LlamaRequest {
- grammar: Some(hypr_gbnf::GBNF::EnhanceOther.build()),
+ grammar: Some(hypr_gbnf::Grammar::Enhance { sections: None }.build()),
messages: vec![
LlamaChatMessage::new(
"system".into(),
diff --git a/plugins/local-llm/src/server.rs b/plugins/local-llm/src/server.rs
index b0b05f37a..8ef8f9de9 100644
--- a/plugins/local-llm/src/server.rs
+++ b/plugins/local-llm/src/server.rs
@@ -267,7 +267,12 @@ fn build_response(
.map(hypr_llama::FromOpenAI::from_openai)
.collect();
- let grammar = select_grammar(&request, &model.name);
+ let grammar = request
+ .metadata
+ .as_ref()
+ .and_then(|v| v.get("grammar"))
+ .and_then(|v| serde_json::from_value::(v.clone()).ok())
+ .map(|g| g.build());
let request = hypr_llama::LlamaRequest { messages, grammar };
@@ -323,29 +328,3 @@ fn build_mock_response() -> Pin
StreamEvent::Content(chunk)
}))
}
-
-fn select_grammar(
- request: &CreateChatCompletionRequest,
- model_name: &hypr_llama::ModelName,
-) -> Option {
- let grammar = request
- .metadata
- .as_ref()
- .and_then(|v| v.get("grammar").and_then(|v| v.as_str()));
-
- let custom_grammar = request
- .metadata
- .as_ref()
- .and_then(|v| v.get("customGrammar").and_then(|v| v.as_str()));
-
- match grammar {
- Some("enhance") => match model_name {
- hypr_llama::ModelName::HyprLLM => Some(hypr_gbnf::GBNF::EnhanceHypr.build()),
- _ => Some(hypr_gbnf::GBNF::EnhanceOther.build()),
- },
- Some("title") => Some(hypr_gbnf::GBNF::Title.build()),
- Some("tags") => Some(hypr_gbnf::GBNF::Tags.build()),
- Some("custom") => custom_grammar.map(|s| s.to_string()),
- _ => None,
- }
-}
diff --git a/plugins/template/Cargo.toml b/plugins/template/Cargo.toml
index d2eecac4c..faa625c3e 100644
--- a/plugins/template/Cargo.toml
+++ b/plugins/template/Cargo.toml
@@ -15,7 +15,9 @@ insta = { workspace = true }
specta-typescript = { workspace = true }
[dependencies]
+hypr-gbnf = { workspace = true }
hypr-template = { workspace = true }
+
serde_json = { workspace = true }
tracing = { workspace = true }
diff --git a/plugins/template/js/bindings.gen.ts b/plugins/template/js/bindings.gen.ts
index c880599b4..6cb4a5883 100644
--- a/plugins/template/js/bindings.gen.ts
+++ b/plugins/template/js/bindings.gen.ts
@@ -25,6 +25,7 @@ async registerTemplate(name: string, template: string) : Promise {
/** user-defined types **/
+export type Grammar = { task: "enhance"; sections: string[] | null } | { task: "title" } | { task: "tags" }
export type JsonValue = null | boolean | number | string | JsonValue[] | Partial<{ [key in string]: JsonValue }>
/** tauri-specta globals **/
diff --git a/plugins/template/src/lib.rs b/plugins/template/src/lib.rs
index 7e85e42c0..d77f4e1c2 100644
--- a/plugins/template/src/lib.rs
+++ b/plugins/template/src/lib.rs
@@ -29,6 +29,7 @@ fn make_specta_builder() -> tauri_specta::Builder {
commands::render::,
commands::register_template::,
])
+ .typ::()
.error_handling(tauri_specta::ErrorHandlingMode::Throw)
}