# GenParse: Lark Interface

In [1]:
import lark

from genparse.util import LarkStuff
from arsenal import Integerizer
from arsenal.maths import compare
from collections import Counter

## Using lark as a front end

In [2]:
# Alternative SQL grammar, kept for comparison; the cell that builds `lark_stuff`
# compiles `grammar1`, not this one.
# Unlike grammar1, whitespace is NOT %ignore'd here: every gap is an explicit
# WS token (exactly one whitespace character), and queries must end with "</s>".
# NOTE(review): "GROUP BY" / "ORDER BY" are single literals containing exactly one
# space, so e.g. "GROUP  BY" would not match those keywords.
grammar2 = r"""
WS: /[ \t\f\r\n]/
STAR: "*"
NUMBER: /\d+/

start: WS "SELECT" WS select_expr WS "FROM" WS from_expr
  [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
EOS: "</s>"
select_expr: STAR | select_list
bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")" 
bool_expr: var "=" value | var ">" value | var "<" value
from_expr: "data"
orderby_expr: var_list WS "ASC" | var_list WS "DESC"
select_list: select_var ("," WS select_var)*
var_list: var ("," WS var)*
select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
value: NUMBER | "red" | "blue" | "white" | "black" | "latino" | "republican" | "democrat" | "male" | "female"

"""

In [3]:
# Main SQL-subset grammar (lark syntax); this is the grammar compiled into `lark_stuff`.
# Written as a raw string so regex escapes such as /\(/, /\)/, /\+/ and /\*/ reach
# lark verbatim. In a plain string they are invalid Python escape sequences, which
# trigger a DeprecationWarning (SyntaxWarning on Python 3.12+), although the value
# is the same.
# Whitespace is %ignore'd, and every query must end with the "</s>" EOS marker.
grammar1 = r"""
start: query_expr EOS

EOS: "</s>"

query_expr: select [ "ORDER" "BY" (order_by_expr ",")*  order_by_expr] [ "LIMIT" integer_ ] 

select: "SELECT" [(select_expr ",")*] select_expr "FROM" "data" [ "WHERE" bool_expression ] [ "GROUP" "BY" [(expression ",")*] expression ]

select_expr.0: expression_math [ [ "AS" ] alias ] -> select_expression

?expression_math: expression_product
               | expression_math PLUS expression_product -> expression_add
               | expression_math "-" expression_product -> expression_sub
               | AGGREGATION expression_math /\)/ -> sql_aggregation

?expression: (name | STAR) -> column_name
            | literal

?expression_product: expression_parens
                  | expression_product STAR expression_parens
                  | expression_product "/" expression_parens 

?expression_parens: expression
                  | /\(/ expression_parens STAR expression /\)/ 
                  | /\(/  expression_parens "/" expression /\)/ 
                  | /\(/  expression_parens PLUS expression /\)/
                  | /\(/  expression_parens "-" expression /\)/

bool_expression: bool_parentheses
                 | bool_expression "AND" bool_parentheses 
                 | bool_expression "OR" bool_parentheses
bool_parentheses: comparison_type
                 | /\(/   bool_expression "AND" comparison_type /\)/
                 | /\(/  bool_expression "OR" comparison_type /\)/
comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal
| less_than_or_equal | is_null | is_not_null
equals: expression_math "=" expression_math
not_equals: expression_math ("<>" | "!=") expression_math
greater_than: expression_math ">" expression_math
less_than: expression_math "<" expression_math
greater_than_or_equal: expression_math ">=" expression_math
less_than_or_equal: expression_math "<=" expression_math
is_null: expression_math "is" "null"
is_not_null: expression_math "is" "not" "null"

alias: /[A-Za-z]+/
name: /[A-Za-z]+/
PLUS: /\+/

order_by_expr: expression_math ["ASC"] -> order_asc
        | expression_math "DESC" -> order_desc

AGGREGATION.8: ("sum(" | "avg(" | "min(" | "max(" | "count(" "distinct" | "count(")
STAR: /\*/
integer_: /[1-9][0-9]*/
?literal: boolean -> bool
       | integer_ -> number
       | ESCAPED_STRING -> string

boolean: "true" -> true
       | "false" -> false

%import common.WS
%ignore WS
%import common.ESCAPED_STRING
    
"""

The following code is adapted from partenon.

In [4]:
# Compile grammar1 with genparse's lark front end, keeping its source text
# available as `raw_grammar`.
lark_stuff = LarkStuff(raw_grammar := grammar1)

In [5]:
# Convert lark's grammar into a genparse grammar, renaming nonterminals to
# integers through the `intern` Integerizer.
intern = Integerizer()
g = lark_stuff.convert().rename(intern)
# The converted grammar is expected to already be in CNF; if this assertion
# ever fails, fall back to `g = g.cnf`.
assert g.in_cnf()

In [6]:
# Display the converted, integer-renamed grammar.
g

In [7]:
# Sizes: number of rules, terminals (V) and (presumably) nonterminals (N).
tuple(len(part) for part in (g.rules, g.V, g.N))

(320, 38, 184)

In [8]:
# Terminal alphabet of the CNF grammar, sorted by name. These are lark's terminal
# names; anonymous terminals show up as __ANON_k.
sorted(g.cnf.V)

['AGGREGATION',
 'AND',
 'AS',
 'ASC',
 'BY',
 'COMMA',
 'DATA',
 'DESC',
 'EOS',
 'EQUAL',
 'ESCAPED_STRING',
 'FALSE',
 'FROM',
 'GROUP',
 'IS',
 'LESSTHAN',
 'LIMIT',
 'MINUS',
 'MORETHAN',
 'NOT',
 'NULL',
 'OR',
 'ORDER',
 'PLUS',
 'SELECT',
 'SLASH',
 'STAR',
 'TRUE',
 'WHERE',
 'WS',
 '__ANON_0',
 '__ANON_1',
 '__ANON_2',
 '__ANON_3',
 '__ANON_4',
 '__ANON_5',
 '__ANON_6',
 '__ANON_7']

In [9]:
#from newton.linking import LinkAnalysis
#f = Integerizer()
#links = LinkAnalysis(g.rename(f))
#links.dfs

In [10]:
# Strings of the grammar's language with their weights, as a Counter mapping
# token-name tuples -> weight.
# NOTE(review): the meaning of the argument (6) is not visible here; the output
# contains 7-token strings, so it is not a maximum string length.
g.language(6)

Counter({('SELECT', '__ANON_7', 'FROM', 'DATA', 'EOS'): 0.0034602076124567475,
         ('SELECT', 'STAR', 'FROM', 'DATA', 'EOS'): 0.0034602076124567475,
         ('SELECT', 'TRUE', 'FROM', 'DATA', 'EOS'): 0.0034602076124567475,
         ('SELECT', 'FALSE', 'FROM', 'DATA', 'EOS'): 0.0034602076124567475,
         ('SELECT',
          'ESCAPED_STRING',
          'FROM',
          'DATA',
          'EOS'): 0.0034602076124567475,
         ('SELECT', '__ANON_6', 'FROM', 'DATA', 'EOS'): 0.0034602076124567475,
         ('SELECT',
          '__ANON_7',
          'FROM',
          'DATA',
          'LIMIT',
          '__ANON_7',
          'EOS'): 0.0002883506343713956,
         ('SELECT',
          'STAR',
          'FROM',
          'DATA',
          'LIMIT',
          '__ANON_7',
          'EOS'): 0.0002883506343713956,
         ('SELECT',
          'TRUE',
          'FROM',
          'DATA',
          'LIMIT',
          '__ANON_7',
          'EOS'): 0.0002883506343713956,
         ('SELECT',

## Tokenization

We can extract lark's tokenizer in a format that we can build on.  We will even make a DIY tokenizer based on Python's `re` library.

| Terminology  |         |
|--------------|---------|
| tokenization | lexing  |
| tokenizers   | lexers  |
| tokens       | lexemes |



In [11]:
# lark's terminal definitions, highest priority first (ties keep lark's order,
# since reverse sorting is stable).
sorted(lark_stuff.all_terminals, key=lambda term: term.priority, reverse=True)

[TerminalDef('AGGREGATION', '(?:count\\(distinct|count\\(|sum\\(|avg\\(|min\\(|max\\()'),
 TerminalDef('ESCAPED_STRING', '".*?(?<!\\\\)(\\\\\\\\)*?"'),
 TerminalDef('WS', '(?:[ \t\x0c\r\n])+'),
 TerminalDef('EOS', '</s>'),
 TerminalDef('PLUS', '\\+'),
 TerminalDef('STAR', '\\*'),
 TerminalDef('COMMA', ','),
 TerminalDef('ORDER', 'ORDER'),
 TerminalDef('BY', 'BY'),
 TerminalDef('LIMIT', 'LIMIT'),
 TerminalDef('WHERE', 'WHERE'),
 TerminalDef('GROUP', 'GROUP'),
 TerminalDef('SELECT', 'SELECT'),
 TerminalDef('FROM', 'FROM'),
 TerminalDef('DATA', 'data'),
 TerminalDef('AS', 'AS'),
 TerminalDef('MINUS', '\\-'),
 TerminalDef('__ANON_0', '\\)'),
 TerminalDef('SLASH', '/'),
 TerminalDef('__ANON_1', '\\('),
 TerminalDef('AND', 'AND'),
 TerminalDef('OR', 'OR'),
 TerminalDef('EQUAL', '='),
 TerminalDef('__ANON_2', '<>'),
 TerminalDef('__ANON_3', '!='),
 TerminalDef('MORETHAN', '>'),
 TerminalDef('LESSTHAN', '<'),
 TerminalDef('__ANON_4', '>='),
 TerminalDef('__ANON_5', '<='),
 TerminalDef('IS', 'i

### DIY tokenizer

In [12]:
text = "12 + 24 - 36 * 48 / 60 SELECT table.name AS thing WHERE table.potato IS NOT 'banana'"

# Print one (terminal name, lexeme) pair per line.
# In the output, no tokens appear for '.' or the single quotes, and the uppercase
# IS / NOT lex as plain names (the grammar's keywords are lowercase "is" / "not").
for kind, lexeme in lark_stuff.simple_tokenizer(text):
    print(f'{kind:15s} -> {lexeme!r}')

__ANON_7        -> '12'
PLUS            -> '+'
__ANON_7        -> '24'
MINUS           -> '-'
__ANON_7        -> '36'
STAR            -> '*'
__ANON_7        -> '48'
SLASH           -> '/'
__ANON_7        -> '60'
SELECT          -> 'SELECT'
__ANON_6        -> 'table'
__ANON_6        -> 'name'
AS              -> 'AS'
__ANON_6        -> 'thing'
WHERE           -> 'WHERE'
__ANON_6        -> 'table'
__ANON_6        -> 'potato'
__ANON_6        -> 'IS'
__ANON_6        -> 'NOT'
__ANON_6        -> 'banana'


### Parsing tokenized input

In [13]:
text = "SELECT name FROM data </s>"  # minimal query, terminated by the EOS marker

In [14]:
# Lex the query into lark Tokens (terminal type + matched text).
tokens = [*lark_stuff.lex(text)]
tokens

[Token('SELECT', 'SELECT'),
 Token('__ANON_6', 'name'),
 Token('FROM', 'FROM'),
 Token('DATA', 'data'),
 Token('EOS', '</s>')]

Call the lark parser on the text:

In [15]:
# Lex and parse the raw text with lark's own parser; the result is a lark Tree.
lark_stuff.instance.parse(text)

Tree(Token('RULE', 'start'), [Tree(Token('RULE', 'query_expr'), [Tree(Token('RULE', 'select'), [Tree('select_expression', [Tree('column_name', [Tree(Token('RULE', 'name'), [Token('__ANON_6', 'name')])]), None]), None, None]), None, None]), Token('EOS', '</s>')])

We can call the lark parser on these tokens:

In [16]:
# Parse the already-lexed token list directly, starting from the `start` rule.
lark_stuff.parser.parse(tokens, 'start')

Tree(NonTerminal(Token('RULE', 'start')), [Tree(NonTerminal(Token('RULE', 'query_expr')), [Tree(NonTerminal(Token('RULE', 'select')), [Token('SELECT', 'SELECT'), Tree(NonTerminal(Token('RULE', 'select_expr')), [Tree(NonTerminal(Token('RULE', 'expression_math')), [Tree(NonTerminal(Token('RULE', 'expression_product')), [Tree(NonTerminal(Token('RULE', 'expression_parens')), [Tree(NonTerminal(Token('RULE', 'expression')), [Tree(NonTerminal(Token('RULE', 'name')), [Token('__ANON_6', 'name')])])])])])]), Token('FROM', 'FROM'), Token('DATA', 'data')])]), Token('EOS', '</s>')])

We can call our parser on the token types of this text to get its total weight:

In [17]:
# Total weight of the token-type sequence under the genparse grammar.
token_types = [tok.type for tok in tokens]
g(token_types)

0.0034602076124567475

### Tokenizer State Machines

**TODO**: Dive into the [greenery](https://github.com/qntm/greenery) FSMs to figure out how they work; we can work backward from the Partenon tensor building routine.  Another option might be the [interegular](https://github.com/MegaIng/interegular) library.


**Note**: Tokenizers are FSTs (they map characters to token names), not FSAs.  However, these libraries implement a restricted kind of FST: a separate FSA for each token type.

In [18]:
import greenery

def make_greenery_fsms(
    regex_list, 
    ignore = ("\s*",), 
    chars = tuple(chr(i) for i in range(128))
):
    match ignore:
        case []:
            ignore_regex = ""
        case [ignore]:
            ignore_regex = ignore
        case _:
            raise ValueError("ignore must be a list of length at most 1")

    patterns = []
    for regex in regex_list:
        # greenery does not escape spaces
        regex = regex.replace("\\ ", " ")
        patterns.append(greenery.parse(regex + ignore_regex))

    return [pattern.to_fsm() for pattern in patterns]

In [19]:
# NOTE(review): `tokens` holds lark Token *values* ('SELECT', 'name', ..., '</s>'),
# so these FSMs recognise those literal strings (plus trailing whitespace), not the
# terminal regexes in lark_stuff.all_terminals. A value containing regex
# metacharacters (e.g. '*' or '(') would be misread as a pattern.
fsms = make_greenery_fsms(list(tokens))

In [20]:
# FSM built from the first token value ('SELECT'), followed by optional whitespace.
fsms[0]

Fsm(alphabet=frozenset({Charclass((('E', 'E'),)), Charclass((('T', 'T'),)), Charclass((('!', 'B'), ('D', 'D'), ('F', 'K'), ('M', 'R'))), ~Charclass((('\t', 'T'),)), Charclass((('L', 'L'),)), Charclass((('\t', '\r'), (' ', ' '))), Charclass((('\x0e', '\x1f'),)), Charclass((('S', 'S'),)), Charclass((('C', 'C'),))}), states=frozenset({0, 1, 2, 3, 4, 5, 6, 7}), initial=0, finals=frozenset({7}), map={0: {Charclass((('\t', '\r'), (' ', ' '))): 1, Charclass((('\x0e', '\x1f'),)): 1, Charclass((('!', 'B'), ('D', 'D'), ('F', 'K'), ('M', 'R'))): 1, Charclass((('C', 'C'),)): 1, Charclass((('E', 'E'),)): 1, Charclass((('L', 'L'),)): 1, Charclass((('S', 'S'),)): 2, Charclass((('T', 'T'),)): 1, ~Charclass((('\t', 'T'),)): 1}, 1: {Charclass((('\t', '\r'), (' ', ' '))): 1, Charclass((('\x0e', '\x1f'),)): 1, Charclass((('!', 'B'), ('D', 'D'), ('F', 'K'), ('M', 'R'))): 1, Charclass((('C', 'C'),)): 1, Charclass((('E', 'E'),)): 1, Charclass((('L', 'L'),)): 1, Charclass((('S', 'S'),)): 1, Charclass((('T', '

The code below was used in [partenon](https://github.com/probcomp/partenon) to convert the lark tokenizer `->` greenery FSMs `->` tensors.  We can repurpose it to convert into a transducer from characters to token names.

In [21]:
def make_matrix_from_fsms(fsms, chars):
    """Pack greenery FSMs into dense transition and final-state tensors (from partenon).

    Arcs into dead (non-live) states are dropped, so a missing transition is
    simply an all-zero entry.

    Args:
        fsms: non-empty sequence of greenery FSMs (e.g. from `make_greenery_fsms`).
            Assumes each FSM's states are ints in range(len(fsm.states)) -- TODO confirm
            for all greenery versions.
        chars: sequence of single-character strings; position i is input index i.

    Returns:
        (m, finals, max_states):
          m[f, s, c, t] == 1 iff FSM f moves from state s to state t on chars[c]
            (int8 jax array, shape (len(fsms), max_states, len(chars), max_states));
          finals[f, s] == 1 iff s is a final state of FSM f
            (default-float jax array, shape (len(fsms), max_states));
          max_states: the largest state count among the FSMs.

    Raises:
        ValueError: if `fsms` is empty, or an arc uses a character not in `chars`.
    """
    # numpy / jax are not imported anywhere else in this notebook; import them
    # here so the function works on its own (jax is only needed when called).
    import numpy as np
    import jax.numpy as jnp

    max_states = max(len(fsm.states) for fsm in fsms)
    n_inputs = len(chars)

    # chars.index() is a linear scan per character. Precompute each character's
    # first position instead; first occurrence matches .index() semantics.
    char_index = {}
    for idx, char in enumerate(chars):
        char_index.setdefault(char, idx)

    m = np.zeros((len(fsms), max_states, n_inputs, max_states), dtype=np.int8)
    # Fill finals in numpy and convert once at the end. A per-element jax
    # `.at[...].set(...)` copies the whole array on every call.
    finals = np.zeros((len(fsms), max_states))
    for fsm_idx, fsm in enumerate(fsms):
        for final_state in fsm.finals:
            finals[fsm_idx, final_state] = 1
        # A set, because membership is tested once per arc below.
        rejection_states = {state for state in fsm.states if not fsm.islive(state)}
        for state in fsm.states:
            for input_char, next_state in fsm.map[state].items():
                if next_state in rejection_states:  # arc into a dead state
                    continue
                for char in input_char.get_chars():
                    try:
                        input_idx = char_index[char]
                    except KeyError:
                        raise ValueError(f"{char!r} is not in chars") from None
                    m[fsm_idx, state, input_idx, next_state] = 1

    # dtype=float gives jax's default float type, the same as jnp.zeros would.
    return jnp.array(m), jnp.array(finals, dtype=float), max_states

## Prefix Grammar

In [39]:
# Rule counts for the CNF grammar and its trimmed prefix grammar, plus the
# blow-up ratio. The trimmed prefix grammar is built once and reused.
g_cnf = g.cnf
n_cnf_rules = len(g_cnf.rules)
n_prefix_rules = len(g_cnf.prefix_grammar.trim().rules)
n_cnf_rules, n_prefix_rules, n_prefix_rules / n_cnf_rules

(304, 1220, 4.0131578947368425)

In [44]:
#N = G.nullaryremove()    # could be faster with SCC-based prioritization

In [45]:
# Display the CNF form of the trimmed prefix grammar.
g.cnf.prefix_grammar.trim().cnf

In [46]:
# Prefix weight of the token-type sequence. For this EOS-terminated query it
# matches the total weight printed for g([...]) earlier.
g.prefix_weight([tok.type for tok in tokens])

0.0034602076124567475