diff --git a/automata/base/automaton.py b/automata/base/automaton.py index 5929ce3e..2675fe6c 100644 --- a/automata/base/automaton.py +++ b/automata/base/automaton.py @@ -151,6 +151,10 @@ def __repr__(self) -> str: ) return f"{self.__class__.__qualname__}({values})" - def __contains__(self, input_str: str) -> bool: + def __contains__(self, item: Any) -> bool: """Returns whether the word is accepted by the automaton.""" - return self.accepts_input(input_str) + + if not isinstance(item, str): + return False + + return self.accepts_input(item) diff --git a/tests/test_dfa.py b/tests/test_dfa.py index 8677a165..976af89e 100644 --- a/tests/test_dfa.py +++ b/tests/test_dfa.py @@ -178,6 +178,7 @@ def test_accepts_input_false(self) -> None: """Should return False if DFA input is rejected.""" self.assertFalse(self.dfa.accepts_input("011")) self.assertNotIn("011", self.dfa) + self.assertNotIn(1, self.nfa) def test_read_input_step(self) -> None: """Should return validation generator if step flag is supplied.""" diff --git a/tests/test_nfa.py b/tests/test_nfa.py index d21bd24c..951c7230 100644 --- a/tests/test_nfa.py +++ b/tests/test_nfa.py @@ -233,6 +233,7 @@ def test_accepts_input_false(self) -> None: """Should return False if NFA input is rejected.""" self.assertFalse(self.nfa.accepts_input("abba")) self.assertNotIn("abba", self.nfa) + self.assertNotIn(1, self.nfa) def test_cyclic_lambda_transitions(self) -> None: """Should traverse NFA containing cyclic lambda transitions."""