Skip to content

Commit

Permalink
Merge b2264fc into ef0e02c
Browse files Browse the repository at this point in the history
  • Loading branch information
eliotwrobson authored Nov 7, 2023
2 parents ef0e02c + b2264fc commit 0205ada
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
61 changes: 44 additions & 17 deletions automata/fa/dfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,28 +499,45 @@ def _minify(
If the input DFA is partial, then the result is also a partial DFA
"""

# First, assemble backmap and equivalence class data structure
eq_classes = PartitionRefinement(reachable_states)
refinement = eq_classes.refine(reachable_final_states)

final_states_id = (
refinement[0][0] if refinement else next(iter(eq_classes.get_set_ids()))
)

# Per input-symbol backmap (tgt -> origin states)
transition_back_map: Dict[str, Dict[DFAStateT, List[DFAStateT]]] = {
symbol: {end_state: list() for end_state in reachable_states}
symbol: {end_state: [] for end_state in reachable_states}
for symbol in input_symbols
}

trap_state = None

for start_state, path in transitions.items():
if start_state in reachable_states:
for symbol, end_state in path.items():
symbol_dict = transition_back_map[symbol]
# If statement here needed to ignore certain transitions
# when minifying a partial DFA.
if end_state in symbol_dict:
symbol_dict[end_state].append(start_state)
for symbol in input_symbols:
end_state = path.get(symbol)
if end_state is not None:
# for symbol, end_state in path.items():
symbol_dict = transition_back_map[symbol]
# If statement here needed to ignore certain transitions
# for non-reachable states
if end_state in symbol_dict:
symbol_dict[end_state].append(start_state)
else:
# Add trap state if needed
if trap_state is None:
trap_state = next(
x for x in count(-1, -1) if x not in reachable_states
)
for trap_symbol in input_symbols:
transition_back_map[trap_symbol][trap_state] = []

reachable_states.add(trap_state)

transition_back_map[symbol][trap_state].append(start_state)

# Set up equivalence class data structure
eq_classes = PartitionRefinement(reachable_states)
refinement = eq_classes.refine(reachable_final_states)

final_states_id = (
refinement[0][0] if refinement else next(iter(eq_classes.get_set_ids()))
)

origin_dicts = tuple(transition_back_map.values())
processing = {final_states_id}
Expand Down Expand Up @@ -558,7 +575,12 @@ def _minify(
)

# need a backmap to prevent constant calls to index
back_map = {state: name for name, eq in eq_class_name_pairs for state in eq}
back_map = {
state: name
for name, eq in eq_class_name_pairs
for state in eq
if trap_state not in eq
}

new_input_symbols = input_symbols
new_states = frozenset(back_map.values())
Expand All @@ -567,12 +589,17 @@ def _minify(
new_transitions = {}

for name, eq in eq_class_name_pairs:
# For trap state, can just leave out
if trap_state in eq:
continue

eq_class_rep = next(iter(eq))

inner_transition_dict_old = transitions[eq_class_rep]
new_transitions[name] = {
letter: back_map[inner_transition_dict_old[letter]]
for letter in inner_transition_dict_old.keys()
if inner_transition_dict_old[letter] in reachable_states
if inner_transition_dict_old[letter] in back_map.keys()
}

allow_partial = any(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_dfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,28 @@ def test_minify_partial_dfa(self) -> None:
self.assertEqual(len(minified_partial_dfa.states), 4)
self.assertEqual(minified_partial_dfa, partial_dfa_extra_state)

def test_minify_partial_dfa_correctness(self) -> None:
"""Test correctness of minifying partial DFAs"""
input_symbols = {"a", "b", "c"}
dfa = DFA.from_finite_language(
language={"ab", "abcb"}, input_symbols=input_symbols, as_partial=True
)

self.assertEqual(dfa.minify(), dfa)

dfa2 = DFA.from_finite_language(
language={"ab", "abba", "cbab"},
input_symbols=input_symbols,
as_partial=True,
)

self.assertEqual(dfa2.minify(), dfa2)

self.assertEqual(dfa.union(dfa2, minify=False), dfa.union(dfa2, minify=True))
self.assertEqual(dfa.intersection(dfa2, minify=False), dfa.intersection(dfa2, minify=True))
self.assertEqual(dfa.symmetric_difference(dfa2, minify=False), dfa.symmetric_difference(dfa2, minify=True))
self.assertEqual(dfa.difference(dfa2, minify=False), dfa.difference(dfa2, minify=True))

def test_init_nfa_simple(self) -> None:
"""Should convert to a DFA a simple NFA."""
nfa = NFA(
Expand Down

0 comments on commit 0205ada

Please sign in to comment.