From 408e74ece252fe96685b18c79942e481e14a1fe9 Mon Sep 17 00:00:00 2001 From: Fraxy V Date: Tue, 7 May 2024 18:00:25 +0300 Subject: [PATCH] whisper grammar: experimental implementation with boost::spirit --- examples/CMakeLists.txt | 2 + examples/grammar-parser.h | 1 + examples/test-grammar-parser.cpp | 250 +++++++++++++++++++++++++ examples/test-grammar/CMakeLists.txt | 10 + examples/test-grammar/test-grammar.cpp | 110 +++++++++++ 5 files changed, 373 insertions(+) create mode 100644 examples/test-grammar-parser.cpp create mode 100644 examples/test-grammar/CMakeLists.txt create mode 100644 examples/test-grammar/test-grammar.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 104482f2133..0b2df9c0a16 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -29,6 +29,7 @@ add_library(${TARGET} STATIC common-ggml.cpp grammar-parser.h grammar-parser.cpp + test-grammar-parser.cpp ) include(DefaultTargetOptions) @@ -108,6 +109,7 @@ if (WHISPER_SDL2) set_target_properties(sycl PROPERTIES FOLDER "examples") endif() endif (WHISPER_SDL2) + add_subdirectory(test-grammar) endif() if (WHISPER_SDL2) diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index 47d019c33e1..e6f11808d95 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -26,4 +26,5 @@ namespace grammar_parser { parse_state parse(const char * src); void print_grammar(FILE * file, const parse_state & state); + std::vector> test_parse(const std::string& src); } diff --git a/examples/test-grammar-parser.cpp b/examples/test-grammar-parser.cpp new file mode 100644 index 00000000000..68910e22220 --- /dev/null +++ b/examples/test-grammar-parser.cpp @@ -0,0 +1,250 @@ +#include "grammar-parser.h" +#include + +#include +#include + +namespace qi = boost::spirit::qi; +namespace phx = boost::phoenix; + +namespace { +using Iterator = std::string::const_iterator; +using Skipper = boost::spirit::qi::rule; +using Rule = std::vector; +using Rules = std::vector; + +struct WGrammar : public qi::grammar { + WGrammar() : WGrammar::base_type(grammar) { + using namespace qi; + grammar = (omit[*space] >> *rule) + [_val = phx::bind(&WGrammar::get_rules, this)]; + rule = (rule_name >> omit[*blank] >> "::=" >> omit[*space] >> alternates >> (+eol | eoi)) + [phx::bind(&WGrammar::add_rule, this, _1, _2)]; + rule_name = (alpha >> *(alnum | char_('-'))) + [_val = phx::bind(&WGrammar::get_symbol_id, this, _1, _2)]; + alternates = (sequence >> omit[*blank] >> *( omit[*blank] >> '|' >> omit[*space] >> sequence)) + [_val = phx::bind(&WGrammar::add_alternate, this, _1, _2)]; + nested_alternates = (nested_sequence >> omit[*space] >> *(omit[*space] >> '|' >> omit[*space] >> nested_sequence)) + [_val = phx::bind(&WGrammar::add_alternate, this, _1, _2)]; + sequence = (repetition % *blank) + [_val = phx::bind(&WGrammar::merge, this, _1)]; + nested_sequence = (repetition % *space) + [_val = phx::bind(&WGrammar::merge, this, _1)]; + repetition = (value >> -char_("*+?")) + [_val = phx::bind(&WGrammar::add_repetition, this, _1, _2)]; + value = literal[_val = _1] | set[_val = _1] | + rule_name[_val = phx::bind(&WGrammar::add_rule_ref, this, _1)] | + group[_val = phx::bind(&WGrammar::add_rule_ref, this, _1)]; + literal = lexeme['"' >> *(char_ - '"') >> '"'] + [_val = phx::bind(&WGrammar::add_literal, this, _1)]; + set = ('[' >> -char_('^') >> (*range | *(char_ - ']')) >> ']') + [_val = phx::bind(&WGrammar::add_char_range, this, _1, _2)]; + range = (char_ - ']') >> '-' >> (char_ - ']'); + group = ('(' >> omit[*space] >> nested_alternates >> omit[*space] >> ')') + [_val = phx::bind(&WGrammar::add_group, this, _1)]; + } + + Rules get_rules() { return std::move(rules); } + + void add_rule(uint32_t rule_id, Rule& rule) { + rule.push_back({WHISPER_GRETYPE_END, 0}); + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = std::move(rule); + } + + Rule merge(Rules& rules) { + Rule result; + for(auto r : rules) result.insert(result.end(), r.begin(), r.end()); + return result; + } + + Rule add_alternate(Rule& rule, Rules& rules) { + for (auto& r : rules) { + rule.push_back({WHISPER_GRETYPE_ALT, 0}); + rule.insert(rule.end(), r.begin(), r.end()); + } + return std::move(rule); + } + + uint32_t generate_symbol_id() { + uint32_t next_id = static_cast(symbol_ids.size()); + std::string id = "__AUTO_GEN__" + std::to_string(next_id); + symbol_ids[id] = next_id; + return next_id; + } + + uint32_t get_symbol_id(char begin, std::vector& id) { + id.insert(id.begin(), begin); + uint32_t next_id = static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(id.begin(), id.end()), next_id); + // fprintf(stderr, "%s: id{%d} => name{%s} \n", __func__, int(result.first->second), result.first->first.c_str()); + return result.first->second; + } + + Rule add_repetition(Rule& rule, boost::optional op) { + if (!op) return std::move(rule); + // apply transformation to previous symbol according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + auto rule_id = generate_symbol_id(); + Rule auto_rule = rule; + if (*op == '*' || *op == '+') { + // cause generated rule to recurse + auto_rule.push_back({WHISPER_GRETYPE_RULE_REF, rule_id}); + } + // mark start of alternate def + auto_rule.push_back({WHISPER_GRETYPE_ALT, 0}); + if (*op == '+') { + // add preceding symbol as alternate only for '+' (otherwise empty) + auto_rule.insert(auto_rule.end(), rule.begin(), rule.end()); + } + add_rule(rule_id, auto_rule); + return Rule(1, {WHISPER_GRETYPE_RULE_REF, rule_id}); + } + + Rule add_rule_ref(uint32_t rule_id) { + return Rule(1, {WHISPER_GRETYPE_RULE_REF, rule_id}); + } + + uint32_t add_group(Rule& rule) { + auto rule_id = generate_symbol_id(); + add_rule(rule_id, rule); + return rule_id; + } + + Rule add_literal(std::vector& str) { + Rule result; + std::string tmp(str.begin(), str.end()); + const char * pos = str.data(); + const char* end = pos + str.size(); + uint32_t value = 0; + while(pos != end) { + std::tie(value, pos) = parse_char(pos); + result.push_back({WHISPER_GRETYPE_CHAR, value}); + } + return result; + } + + Rule add_char_range(boost::optional neg, boost::variant, std::vector >& content) { + Rule result; + auto type = neg ? WHISPER_GRETYPE_CHAR_NOT : WHISPER_GRETYPE_CHAR; + switch(content.which()) { + case 0: { + auto& vec = boost::get>(content); + for (auto& range : vec) { + assert(range.size() == 2); + result.push_back({type, (uint32_t)range[0]}); + result.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, (uint32_t)range[1]}); + type = WHISPER_GRETYPE_CHAR_ALT; + } + break; + } + case 1: { + auto& vec = boost::get>(content); + const char* pos = &vec[0]; + const char* end = pos + vec.size(); + uint32_t value = 0; + while(pos != end) { + std::tie(value, pos) = parse_char(pos); + result.push_back({type, value}); + type = WHISPER_GRETYPE_CHAR_ALT; + } + } + } + + return result; + } + // TODO: define rules for escape sequences + static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); + } + // TODO: rule for hex values + static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); + } + // TODO: check if this can be replaced with boost::u8_to_u32_iterator + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from whisper.cpp + static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + Rules rules; + std::map symbol_ids; + qi::rule rule_name; + qi::rule group; + qi::rule range; + qi::rule alternates, value, literal, repetition, sequence, set, nested_alternates, nested_sequence; + qi::rule rule; + qi::rule grammar; +}; +} + +Rules grammar_parser::test_parse(const std::string& src) { + Skipper comment = '#' >> *(qi::char_ - qi::eol) >> (qi::eol|qi::eoi); + Rules result; + Iterator begin = src.begin(); + Iterator end = src.end(); + try { + boost::spirit::qi::phrase_parse(begin, end, WGrammar(), comment, result); + if (begin != end) throw std::runtime_error("Parsing failed."); + + } catch (std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return {}; + } + return result; +} diff --git a/examples/test-grammar/CMakeLists.txt b/examples/test-grammar/CMakeLists.txt new file mode 100644 index 00000000000..8b85d9461b2 --- /dev/null +++ b/examples/test-grammar/CMakeLists.txt @@ -0,0 +1,10 @@ +add_executable(test-grammar + test-grammar.cpp + ) + +include(DefaultTargetOptions) + +target_link_libraries(test-grammar PRIVATE + whisper + common + ) \ No newline at end of file diff --git a/examples/test-grammar/test-grammar.cpp b/examples/test-grammar/test-grammar.cpp new file mode 100644 index 00000000000..023fb215698 --- /dev/null +++ b/examples/test-grammar/test-grammar.cpp @@ -0,0 +1,110 @@ +#include +#include +#include + +auto src = +R"( root ::= init " " (command | question) "." +prompt ::= init + +test ::= "你好" + +# leading space is very important! +init ::= " Ok Whisper, start listening for commands." #asdfda + +command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value | + "Increase " device " by " value | "Decrease " device " by " value | + "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task | + "Remind me to " task " at " time | "Show me " device | "Hide " device + +question ::= "What is the " device " status?" | "What is the current " device " value?" | + "What is the " device " temperature?" | "What is the " device " humidity?" | + "What is the " device " power consumption?" | "What is the " device " battery level?" | + "What is the weather like today?" | "What is the forecast for tomorrow?" | + "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" | + "What reminders do I have?" + +device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" | + "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" | + "vacuum cleaner" + +value ::= [0-9]+ + +media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie" + +task ::= [a-zA-Z]+ (" " [a-zA-Z]+)? + +time ::= [0-9] [0-9]? ("am" | "pm")? +)"; + +using namespace grammar_parser; + +int main() { + auto state = parse(src); + auto rules1 = test_parse(src); + + if (state.rules.empty() != rules1.empty()) { + fprintf(stderr, "Parsing success differs {%d} != {%d}\n", state.rules.empty(), rules1.empty()); + exit(1); + } + if (rules1.empty()) return 0; + + // traverse grammar, BFS comparison + std::set visited, visited1; + std::queue ref, ref1; + ref.push(0); + ref1.push(0); + while(!ref.empty() || !ref1.empty()) { + if (ref.size() != ref1.size()) { + fprintf(stderr, "Current references differs {%d} != {%d}\n", (int)ref.size(), (int)ref1.size()); + exit(1); + } + uint32_t current = ref.front(); ref.pop(); + uint32_t current1 = ref1.front(); ref1.pop(); + bool check = visited.find(current) == visited.end(); + bool check1 = visited1.find(current1) == visited1.end(); + if (check != check1) { + fprintf(stderr, "Current node status differs {%d} != {%d}\n", check, check1); + exit(1); + } + if (check) { + visited.insert(current); + visited1.insert(current1); + if (state.rules[current].size() != rules1[current1].size()) { + fprintf(stderr, "Current rule size differs {%d, %d} != {%d, %d}\n", (int) current, (int)state.rules[current].size(), (int) current1, (int)rules1[current1].size()); + exit(1); + } + for (size_t i = 0; i < rules1[current1].size(); ++i) { + if (state.rules[current][i].type != rules1[current1][i].type) { + fprintf(stderr, "Current %ith rule element type differs {%d, %d} != {%d, %d}\n",(int)i, (int)current, (int)state.rules[current][i].type, (int) current1, (int)rules1[current][i].type); + exit(1); + } + if (rules1[current1][i].type == WHISPER_GRETYPE_RULE_REF) { + ref.push(state.rules[current][i].value); + ref1.push(rules1[current1][i].value); + continue; + } + if (state.rules[current][i].value != rules1[current1][i].value) { + fprintf(stderr, "Current %ith rule element value differs {%d, %d} != {%d, %d}\n",(int)i, (int)current, (int)state.rules[current][i].value, (int)current1, (int)rules1[current][i].value); + exit(1); + } + } + fprintf(stderr, "Rules {%d} and {%d} are identical. Sizes {%d} {%d}\n", (int) current, (int) current1, (int)state.rules[current].size(), (int)rules1[current1].size()); + } + if (ref1.empty() && visited1.size() != rules1.size()) { + for (uint32_t i = 0; i < state.rules.size(); ++i) { + if (visited.find(i) == visited.end()) { + ref.push(i); + break; + } + } + for (uint32_t i = 0; i < rules1.size(); ++i) { + if (visited1.find(i) == visited1.end()) { + ref1.push(i); + break; + } + } + } + } + fprintf(stderr, "Grammars are identical\n"); + return 0; +} \ No newline at end of file