Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

114 changes: 21 additions & 93 deletions apps/desktop/src/components/editor-area/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { commands as analyticsCommands } from "@hypr/plugin-analytics";
import { commands as connectorCommands } from "@hypr/plugin-connector";
import { commands as dbCommands } from "@hypr/plugin-db";
import { commands as miscCommands } from "@hypr/plugin-misc";
import { commands as templateCommands } from "@hypr/plugin-template";
import { commands as templateCommands, type Grammar } from "@hypr/plugin-template";
import Editor, { type TiptapEditor } from "@hypr/tiptap/editor";
import Renderer from "@hypr/tiptap/renderer";
import { extractHashtags } from "@hypr/tiptap/shared";
Expand Down Expand Up @@ -282,48 +282,28 @@ export function useEnhanceMutation({
return;
}

// Get current config for default template
const config = await dbCommands.getConfig();

// Use provided templateId or fall back to config
const effectiveTemplateId = templateId !== undefined
? templateId
: config.general?.selected_template_id;
const getTemplate = async () => {
const effectiveTemplateId = templateId !== undefined
? templateId
: config.general?.selected_template_id;

let templateInfo = "";
let customGrammar: string | null = null;
if (!effectiveTemplateId) {
return null;
}

if (effectiveTemplateId) {
const templates = await dbCommands.listTemplates();
const selectedTemplate = templates.find(t => t.id === effectiveTemplateId);

if (selectedTemplate) {
if (selectedTemplate.sections && selectedTemplate.sections.length > 0) {
customGrammar = generateCustomGBNF(selectedTemplate.sections);
}

templateInfo = `
SELECTED TEMPLATE:
Template Title: ${selectedTemplate.title || "Untitled"}
Template Description: ${selectedTemplate.description || "No description"}
return templates.find(t => t.id === effectiveTemplateId) || null;
};

Sections:`;

selectedTemplate.sections?.forEach((section, index) => {
templateInfo += `
${index + 1}. ${section.title || "Untitled Section"}
└─ ${section.description || "No description"}`;
});
}
} else {
console.log("Using default template (no custom template selected)");
}
const selectedTemplate = await getTemplate();

const participants = await dbCommands.sessionListParticipants(sessionId);

const systemMessage = await templateCommands.render(
"enhance.system",
{ config, type, templateInfo },
{ config, type, templateInfo: selectedTemplate },
);

const userMessage = await templateCommands.render(
Expand Down Expand Up @@ -375,14 +355,12 @@ Sections:`;
...(freshIsLocalLlm && {
providerOptions: {
[localProviderName]: {
metadata: customGrammar
? {
grammar: "custom",
customGrammar: customGrammar,
}
: {
grammar: "enhance",
},
metadata: {
grammar: {
task: "enhance",
sections: selectedTemplate?.sections.map(s => s.title) || null,
} satisfies Grammar,
},
},
},
}),
Expand Down Expand Up @@ -478,7 +456,9 @@ function useGenerateTitleMutation({ sessionId }: { sessionId: string }) {
providerOptions: {
[localProviderName]: {
metadata: {
grammar: "title",
grammar: {
task: "title",
} satisfies Grammar,
},
},
},
Expand Down Expand Up @@ -536,55 +516,3 @@ function useAutoEnhance({
prevOngoingSessionStatus,
]);
}

// function to dynamically generate the grammar for the custom template
function generateCustomGBNF(templateSections: any[]): string {
if (!templateSections || templateSections.length === 0) {
return "";
}

function escapeForGBNF(text: string): string {
return text
.replace(/\\/g, "\\\\")
.replace(/"/g, "\\\"")
.replace(/\n/g, "\\n")
.replace(/\r/g, "\\r")
.replace(/\t/g, "\\t");
}

const validatedSections = templateSections.map((section, index) => {
let title = section.title || `Section ${index + 1}`;

title = title
.trim()
.replace(/[\x00-\x1F\x7F]/g, "")
.substring(0, 100);

return {
...section,
safeTitle: title || `Section ${index + 1}`,
};
});

const sectionRules = validatedSections.map((section, index) => {
const sectionName = `section${index + 1}`;
const escapedHeader = escapeForGBNF(section.safeTitle);
return `${sectionName} ::= "# ${escapedHeader}\\n\\n" bline bline bline? bline? bline? "\\n"`;
}).join("\n");

const sectionNames = validatedSections.map((_, index) => `section${index + 1}`).join(" ");

const grammar = `root ::= thinking ${sectionNames}

${sectionRules}

bline ::= "- **" [^*\\n:]+ "**: " ([^*;,[.\\n] | link)+ ".\\n"

hsf ::= "- Objective\\n"
hd ::= "- " [A-Z] [^[(*\\n]+ "\\n"
thinking ::= "<thinking>\\n" hsf hd hd? hd? hd? "</thinking>"

link ::= "[" [^\\]]+ "]" "(" [^)]+ ")"`;

return grammar;
}
4 changes: 4 additions & 0 deletions crates/gbnf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ name = "gbnf"
version = "0.1.0"
edition = "2021"

[dependencies]
serde = { workspace = true, features = ["derive"] }
specta = { workspace = true, features = ["derive"] }

[dev-dependencies]
gbnf-validator = { workspace = true }

Expand Down
5 changes: 0 additions & 5 deletions crates/gbnf/assets/enhance-hypr.gbnf

This file was deleted.

14 changes: 0 additions & 14 deletions crates/gbnf/assets/enhance-other.gbnf

This file was deleted.

3 changes: 0 additions & 3 deletions crates/gbnf/assets/tags.gbnf

This file was deleted.

3 changes: 0 additions & 3 deletions crates/gbnf/assets/title.gbnf

This file was deleted.

106 changes: 91 additions & 15 deletions crates/gbnf/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,102 @@
pub const ENHANCE_OTHER: &str = include_str!("../assets/enhance-other.gbnf");
pub const ENHANCE_HYPR: &str = include_str!("../assets/enhance-hypr.gbnf");
pub const TITLE: &str = include_str!("../assets/title.gbnf");
pub const TAGS: &str = include_str!("../assets/tags.gbnf");

pub enum GBNF {
EnhanceOther,
EnhanceHypr,
#[derive(specta::Type, serde::Serialize, serde::Deserialize)]
#[serde(tag = "task")]
pub enum Grammar {
#[serde(rename = "enhance")]
Enhance { sections: Option<Vec<String>> },
#[serde(rename = "title")]
Title,
#[serde(rename = "tags")]
Tags,
}

impl GBNF {
impl Grammar {
pub fn build(&self) -> String {
match self {
GBNF::EnhanceOther => ENHANCE_OTHER.to_string(),
GBNF::EnhanceHypr => ENHANCE_HYPR.to_string(),
GBNF::Title => TITLE.to_string(),
GBNF::Tags => TAGS.to_string(),
Grammar::Enhance { sections } => build_enhance_other_grammar(sections),
Grammar::Title => build_title_grammar(),
Grammar::Tags => build_tags_grammar(),
}
}
}

#[allow(dead_code)]
fn build_enhance_hypr_grammar(_s: &Option<Vec<String>>) -> String {
vec![
r##"root ::= think content"##,
r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##,
r##"think ::= "<think>\n" line line? line? line? "</think>""##,
r##"content ::= .*"##,
]
.join("\n")
}

fn build_known_sections_grammar(sections: &[String]) -> String {
let mut rules = vec![];

let mut root_parts = vec![];
for i in 0..sections.len() {
root_parts.push(format!("section{}", i));
}

let root_rule = format!("root ::= {}", root_parts.join(" "));
rules.push(root_rule);

for (i, section) in sections.iter().enumerate() {
let section_rule = format!(
r##"section{} ::= "# {}\n\n" bline bline bline? bline? bline? "\n""##,
i, section
);
rules.push(section_rule);
}

rules
.push(r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##.to_string());
rules.push(r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##.to_string());

rules.join("\n")
}

fn build_enhance_other_grammar(s: &Option<Vec<String>>) -> String {
let auto = vec![
r##"root ::= thinking sectionf section section section? section?"##,
r##"sectionf ::= "# Objective\n\n" line line? line? "\n""##,
r##"section ::= header "\n\n" bline bline bline? bline? bline? "\n""##,
r##"header ::= "# " [^*.\n]+"##,
r##"line ::= "- " [A-Z] [^*.\n[(]+ ".\n""##,
r##"bline ::= "- **" [A-Z] [^*\n:]+ "**: " ([^*;,[.\n] | link)+ ".\n""##,
r##"hsf ::= "- Objective\n""##,
r##"hd ::= "- " [A-Z] [^[(*\n]+ "\n""##,
r##"thinking ::= "<thinking>\n" hsf hd hd? hd? hd? "</thinking>""##,
r##"link ::= "[" [^\]]+ "]" "(" [^)]+ ")""##,
]
.join("\n");

match s {
None => auto,
Some(v) if v.is_empty() => auto,
Some(v) => build_known_sections_grammar(v),
}
}

fn build_title_grammar() -> String {
vec![
r##"lowercase ::= [a-z0-9]"##,
r##"uppercase ::= [A-Z]"##,
r##"word ::= uppercase lowercase*"##,
r##"root ::= word (" " word)*"##,
]
.join("\n")
}

fn build_tags_grammar() -> String {
vec![
r##"root ::= \"[\" \"\" word \"\" (\",\" ws \"\" word \"\")* \"]\""##,
r##"word ::= [a-zA-Z0-9_-]+"##,
r##"ws ::= \" \"*"##,
]
.join("\n")
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -40,7 +116,7 @@ mod tests {
("Meeting-Summary", false),
("", false),
] {
let result = gbnf.validate(TITLE, input).unwrap();
let result = gbnf.validate(&build_title_grammar(), input).unwrap();
assert_eq!(result, expected, "failed: {}", input);
}
}
Expand All @@ -53,7 +129,7 @@ mod tests {
("[\"meeting\", \"summary\"]", true),
("[\"meeting\", \"summary\", \"\"]", false),
] {
let result = gbnf.validate(TAGS, input).unwrap();
let result = gbnf.validate(&build_tags_grammar(), input).unwrap();
assert_eq!(result, expected, "failed: {}", input);
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ mod tests {
let llama = get_model();

let request = LlamaRequest {
grammar: Some(hypr_gbnf::GBNF::EnhanceOther.build()),
grammar: Some(hypr_gbnf::Grammar::Enhance { sections: None }.build()),
messages: vec![
LlamaChatMessage::new(
"system".into(),
Expand Down
Loading
Loading