from collections import ChainMap, defaultdict
from copy import copy
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Optional,
Set,
Tuple,
)
import interegular
import regex
from interegular.fsm import FSM, anything_else
from interegular.patterns import Unsupported
from lark import Lark, Token
from lark.exceptions import (
LexError,
UnexpectedCharacters,
UnexpectedEOF,
UnexpectedToken,
)
from lark.indenter import PythonIndenter
from lark.lexer import BasicLexer, ContextualLexer, LexerState, Scanner
from lark.parsers.lalr_analysis import Shift
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser import ParseConf, ParserState
from lark.utils import get_regexp_width

if TYPE_CHECKING:
    from lark.lexer import LexerThread

PartialParseState = Tuple[str, int]


class PartialTokenEOF(UnexpectedEOF):
    pass


class PartialScanner(Scanner):
def __init__(self, scanner: Scanner):
self.terminals = scanner.terminals
self.g_regex_flags = scanner.g_regex_flags
self.re_ = regex
self.use_bytes = scanner.use_bytes
self.match_whole = scanner.match_whole
self.allowed_types = scanner.allowed_types
self._mres = scanner._mres
def match(self, text, pos) -> Optional[Tuple[str, Optional[str], bool]]:
for mre in self._mres:
m = mre.match(text, pos=pos, partial=True)
if m: # and ((not m.partial) or m.endpos == len(text)):
return m.group(0), m.lastgroup, m.partial
        return None


class PartialBasicLexer(BasicLexer):
def __init__(self, basic_lexer: BasicLexer):
self.re = regex
self.newline_types = basic_lexer.newline_types
self.ignore_types = basic_lexer.ignore_types
self.terminals = basic_lexer.terminals
self.user_callbacks = basic_lexer.user_callbacks
self.g_regex_flags = basic_lexer.g_regex_flags
self.use_bytes = basic_lexer.use_bytes
self.terminals_by_name = basic_lexer.terminals_by_name
self.callback = getattr(basic_lexer, "callback", None)
if basic_lexer._scanner is not None:
self._scanner: Optional[PartialScanner] = PartialScanner(
basic_lexer._scanner
)
else:
self._scanner = None
# This is used to determine the token type for partial matches
self.terminal_to_regex = {}
for name, terminal in self.terminals_by_name.items():
self.terminal_to_regex[name] = self.re.compile(
terminal.pattern.to_regexp(), self.g_regex_flags
)
def _build_scanner(self):
super()._build_scanner()
self._scanner = PartialScanner(self._scanner)
def partial_matches(self, value, type_):
partial_matches = set()
# TODO: It's unfortunate that we have to do this costly search (again).
# It would be better if we could *not* short-circuit the first time we
# scan in the call to `self.match`.
for term_name, term_regex in self.terminal_to_regex.items():
if term_name == type_:
                # A standard lexed token result could actually indicate a
                # partial match
regex_min, regex_max = get_regexp_width(term_regex.pattern)
if regex_min <= len(value) < regex_max:
partial_matches.add(term_name)
else:
m = term_regex.match(value, partial=True)
if m:
partial_matches.add(term_name)
return partial_matches
def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
line_ctr = lex_state.line_ctr
while line_ctr.char_pos < len(lex_state.text):
res = self.match(lex_state.text, line_ctr.char_pos)
if not res:
allowed = self.scanner.allowed_types - self.ignore_types
if not allowed:
allowed = {"<END-OF-FILE>"}
raise UnexpectedCharacters(
lex_state.text,
line_ctr.char_pos,
line_ctr.line,
line_ctr.column,
allowed=allowed,
token_history=lex_state.last_token and [lex_state.last_token],
state=parser_state,
terminals_by_name=self.terminals_by_name,
)
value, type_, partial = res
# Don't advance the lexing state if we're at the end; there could
# be ambiguous token types that aren't finished.
if line_ctr.char_pos + len(value) >= len(lex_state.text):
partial_matches = self.partial_matches(value, type_)
if partial_matches or partial:
raise PartialTokenEOF(partial_matches)
assert isinstance(self.callback, Dict)
if type_ not in self.ignore_types:
t = Token(
type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
)
line_ctr.feed(value, type_ in self.newline_types)
t.end_line = line_ctr.line
t.end_column = line_ctr.column
t.end_pos = line_ctr.char_pos
if t.type in self.callback:
t = self.callback[t.type](t)
if not isinstance(t, Token):
raise LexError(
"Callbacks must return a token (returned %r)" % t
)
lex_state.last_token = t
return t
if type_ in self.callback:
t2 = Token(
type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
)
self.callback[type_](t2)
line_ctr.feed(value, type_ in self.newline_types)
        raise EOFError(self)


class PartialPythonIndenter(PythonIndenter):
"""An `Indenter` that doesn't reset its state every time `process` is called."""
def process(self, stream):
return self._process(stream)
def _process(self, stream):
for token in stream:
# These were previously *after* the `yield`, but that makes the
# state tracking unnecessarily convoluted.
if token.type in self.OPEN_PAREN_types:
self.paren_level += 1
elif token.type in self.CLOSE_PAREN_types:
self.paren_level -= 1
if self.paren_level < 0:
raise UnexpectedToken(token, [])
if token.type == self.NL_type:
yield from self.handle_NL(token)
else:
yield token
# while len(self.indent_level) > 1:
# self.indent_level.pop()
# yield Token(self.DEDENT_type, "")
def __copy__(self):
res = type(self)()
res.paren_level = self.paren_level
res.indent_level = copy(self.indent_level)
        return res


def copy_lexer_thread(lexer_thread: "LexerThread") -> "LexerThread":
res = copy(lexer_thread)
res.lexer = copy(res.lexer)
if getattr(res.lexer, "postlexer", None):
if isinstance(res.lexer.postlexer, PythonIndenter) and not isinstance(
res.lexer.postlexer, PartialPythonIndenter
):
            # Replace the post-lexer with one that keeps its state across calls
            # XXX: This won't really work in general.
postlexer = PartialPythonIndenter()
postlexer.paren_level = res.lexer.postlexer.paren_level
postlexer.indent_level = res.lexer.postlexer.indent_level
res.lexer.postlexer = postlexer
else:
res.lexer.postlexer = copy(res.lexer.postlexer)
# Patch/replace the lexer objects so that they support partial matches
context_lexer = res.lexer
if not isinstance(context_lexer, ContextualLexer):
# XXX: The layouts change with the grammars
context_lexer = context_lexer.lexer
assert isinstance(context_lexer, ContextualLexer)
if not isinstance(context_lexer.root_lexer, PartialBasicLexer):
context_lexer.root_lexer = PartialBasicLexer(context_lexer.root_lexer)
basic_lexers = context_lexer.lexers
for idx, lexer in basic_lexers.items():
basic_lexers[idx] = PartialBasicLexer(lexer)
    return res


def copy_parser_state(parser_state: ParserState) -> ParserState:
res = copy(parser_state)
res.lexer = copy_lexer_thread(res.lexer)
    return res


def copy_ip(ip: "InteractiveParser") -> "InteractiveParser":
res = copy(ip)
res.lexer_thread = copy_lexer_thread(res.lexer_thread)
    return res


def parse_to_end(parser_state: ParserState) -> Tuple[ParserState, Set[str]]:
"""Continue parsing from the current parse state and return partial next tokens.
.. warning::
The parse state `parser_state` is updated in-place and must be patched
to work with this function. Either patch it manually or use
`copy_parser_state` before calling this.
"""
expected_next_tokens: Set[str] = set()
try:
for token in parser_state.lexer.lex(parser_state):
parser_state.feed_token(token)
except PartialTokenEOF as e:
expected_next_tokens = e.expected
return parser_state, expected_next_tokens
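

# Illustrative usage sketch for `parse_to_end`; the grammar, input, and this
# helper are made up for demonstration, and the exact result depends on the
# grammar and the lark version in use:
def _example_parse_to_end():
    lp = Lark(
        r"""
        start: NAME "=" NUMBER
        NAME: /[a-z]+/
        NUMBER: /[0-9]+/
        """,
        parser="lalr",
    )
    # `parse_interactive` sets up the parser without consuming any input yet.
    ip = lp.parse_interactive("x=1")
    # Copy the state so the partial-matching lexers get patched in.
    parser_state = copy_parser_state(ip.parser_state)
    # `expected` names the terminal types that the trailing characters could
    # still extend (for "x=1" this would typically be {"NUMBER"}).
    parser_state, expected = parse_to_end(parser_state)
    return parser_state, expected

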
def find_partial_matches(
fsm: FSM, input_string: str, start_state: Optional[int] = None
) -> Set[Tuple[Optional[int], Tuple[int, ...]]]:
"""Find the states in the finite state machine `fsm` that accept `input_string`.
This will consider all possible states in the finite state machine (FSM)
that accept the beginning of `input_string` as starting points, unless a
specific `start_state` is provided.
Parameters
----------
fsm
The finite state machine.
input_string
The string for which we generate partial matches.
start_state
A single fixed starting state to consider. For example, if this value
is set to `fsm.initial`, it attempt to read `input_string` from the
beginning of the FSM/regular expression.
Returns
-------
A set of tuples corresponding to each valid starting state in the FSM. The
first element of each tuple contains either ``None`` or an integer
indicating the position in `input_string` at which the FSM terminated. The
second element is the tuple of states visited during execution of the FSM
plus the next, unvisited transition state.
"""
if len(input_string) == 0 or input_string[0] not in fsm.alphabet:
return set()
trans_key = fsm.alphabet[input_string[0]]
# TODO: We could probably reuse parts of the computed paths when computing
# results for multiple starting points.
def _partial_match(
trans: Dict[int, int]
) -> Tuple[Optional[int], Optional[Tuple[int, ...]]]:
fsm_map = ChainMap({fsm.initial: trans}, fsm.map)
state = fsm.initial
accepted_states: Tuple[int, ...] = ()
for i, symbol in enumerate(input_string):
if anything_else in fsm.alphabet and symbol not in fsm.alphabet:
symbol = anything_else
trans_key = fsm.alphabet[symbol]
if not (state in fsm_map and trans_key in fsm_map[state]):
if state in fsm.finals:
i -= 1
break
return None, None
state = fsm_map[state][trans_key]
accepted_states += (state,)
terminated = state in fsm.finals
if not terminated and state == fsm.initial:
return None, None
return None if not terminated else i, accepted_states
res = set()
transition_maps = (
fsm.map if start_state is None else {start_state: fsm.map[start_state]}
)
for state, trans in transition_maps.items():
if trans_key in trans:
n_matched, path = _partial_match(trans)
if path is not None:
res.add((n_matched, (state,) + path))
return res
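

# Illustrative sketch for `find_partial_matches` (toy pattern; the exact state
# numbers in the result depend on how interegular constructs the FSM):
def _example_find_partial_matches():
    name_fsm = interegular.parse_pattern(r"[a-z]+").to_fsm().reduce()
    # Each (end_index, state_path) pair describes one way the string "ab"
    # could start or continue a match of the pattern, per the docstring above.
    return find_partial_matches(name_fsm, "ab")

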
def terminals_to_fsms(lp: Lark) -> Dict[str, FSM]:
"""Construct a ``dict`` mapping terminal symbol names to their finite state machines."""
symbol_names_and_fsms = {}
for terminal in lp.terminals:
pattern = interegular.parse_pattern(terminal.pattern.to_regexp())
# TODO: Use `pyparser.terminals[0].pattern.flags`?
try:
fsm = pattern.to_fsm().reduce()
except Unsupported:
fsm = None
symbol_names_and_fsms[terminal.name] = fsm
return symbol_names_and_fsms
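

# Illustrative sketch: build the terminal-name-to-FSM map for a toy grammar
# (the grammar below is made up for demonstration):
def _example_terminals_to_fsms():
    lp = Lark(
        r"""
        start: NAME "=" NUMBER
        NAME: /[a-z]+/
        NUMBER: /[0-9]+/
        """,
        parser="lalr",
    )
    # The result has an FSM for each terminal (e.g. NAME and NUMBER, plus the
    # anonymous "=" terminal); unsupported patterns map to None.
    return lp, terminals_to_fsms(lp)

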
def map_partial_states_to_vocab(
vocabulary: Iterable[str],
terminals_to_fsms_map: Dict[str, FSM],
partial_match_filter: Callable[
[str, Optional[int], Tuple[int, ...]], bool
] = lambda *args: True,
final_state_string: Optional[str] = None,
) -> Tuple[
DefaultDict[PartialParseState, Set[int]], Dict[str, DefaultDict[int, Set[int]]]
]:
"""Construct a map from partial parse states to subsets of `vocabulary`.
The subsets of `vocabulary` consist of elements that are accepted by--or
transition to--the corresponding partial parse states.
Parameters
----------
vocabulary
The vocabulary composed of strings.
terminals_to_fsms_map
Terminal symbol names mapped to FSMs, as provided by `terminals_to_fsms`.
partial_match_filter
A callable that determines which partial matches to keep. The first
argument is the string being match, the rest are the unpacked partial
match return values of `find_partial_matches`.
final_state_string
A string from `vocabulary` that is to be added to all the final states
in the FSM.
"""
final_state_string_idx = None
# Partial parse states to the subsets of the vocabulary that accept them
pstate_to_vocab = defaultdict(set)
possible_paths = {}
for symbol_name, fsm in terminals_to_fsms_map.items():
terminal_possible_paths = defaultdict(set)
for i, vocab_string in enumerate(vocabulary):
if vocab_string == final_state_string:
final_state_string_idx = i
for end_idx, state_seq in find_partial_matches(fsm, vocab_string):
if partial_match_filter(vocab_string, end_idx, state_seq):
terminal_possible_paths[state_seq[0]].add(state_seq[-1])
pstate_to_vocab[(symbol_name, state_seq[0])].add(i)
possible_paths[symbol_name] = terminal_possible_paths
if final_state_string_idx is not None:
        # Allow transitions to EOS from every terminal's final FSM states
for symbol_name, fsm in terminals_to_fsms_map.items():
for state in fsm.finals:
pstate_to_vocab[(symbol_name, state)].add(final_state_string_idx)
return pstate_to_vocab, possible_paths
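

# Illustrative sketch for `map_partial_states_to_vocab`, reusing the toy
# grammar above; "<EOS>" stands in for whatever end-of-sequence string the
# caller's vocabulary uses:
def _example_map_partial_states_to_vocab():
    _, fsms = _example_terminals_to_fsms()
    vocabulary = ["a", "ab", "1", "<EOS>"]
    pstate_to_vocab, possible_paths = map_partial_states_to_vocab(
        vocabulary, fsms, final_state_string="<EOS>"
    )
    # `pstate_to_vocab` maps (terminal_name, fsm_state) pairs to the indices of
    # the vocabulary strings that can be read starting from that FSM state.
    return pstate_to_vocab, possible_paths

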
def terminals_to_lalr_states(lp: Lark) -> DefaultDict[str, Set[int]]:
terminals_to_states = defaultdict(set)
parse_table = lp.parser.parser.parser.parse_table
for state, tokens_to_ops in parse_table.states.items():
for token, op in tokens_to_ops.items():
if op[0] == Shift:
# `op[1]` is the state we shift to when `token` is observed
terminals_to_states[token].add(op[1])
return terminals_to_states
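

# Illustrative sketch for `terminals_to_lalr_states` (toy grammar; the state
# numbers depend on lark's LALR table construction):
def _example_terminals_to_lalr_states():
    lp, _ = _example_terminals_to_fsms()
    # Maps each terminal name to the set of parse-table states the parser
    # shifts into after reading that terminal.
    return terminals_to_lalr_states(lp)

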
def create_pmatch_parser_states(
lp: Lark,
terminals_to_states: Dict[str, Set[int]],
term_type: str,
ptoken: str,
pmatch: Tuple[int, Tuple[int, ...]],
) -> Tuple[ParserState, ...]:
parse_table = lp.parser.parser.parser.parse_table
# TODO: We need to effectively disable the callbacks that build the
# trees, because we aren't actually parsing a valid state that can, say,
# be reduced
def noop(*args, **kwargs):
pass
callbacks = {rule: noop for rule, cb in lp._callbacks.items()}
parse_conf = ParseConf(parse_table, callbacks, lp.options.start[0])
lexer_thread = lp.parser._make_lexer_thread(ptoken)
lexer_state = lexer_thread.state
lexer_state.line_ctr.char_pos = pmatch[0] + 1
lexer_state.last_token = Token(term_type, "")
res = tuple(
ParserState(parse_conf, lexer_thread, [state], None)
for state in terminals_to_states[term_type]
)
return res