In [14]:
from collections import defaultdict

class Node:
    def __init__(self, word, id):
        self.word = word
        self.id = id
        self.count = 0
        self.children = {}

class PatternExtractor:
    def __init__(self):
        self.preceding_tree = Node('<ROOT>', 0)
        self.following_tree = Node('<ROOT>', 1)
        self.node_counter = 2
        self.word_to_ids = defaultdict(list)
        self.id_to_node = {0: self.preceding_tree, 1: self.following_tree}
        self.child_to_parents = defaultdict(list)

    def add_start_end_flags_lower(self, sentences):
        return [f"<START> {sentence.lower()} <END>" for sentence in sentences]

    def get_or_create_node(self, current_tree, word, parent_id=None):
        if word not in current_tree.children:
            new_node = Node(word, self.node_counter)
            current_tree.children[word] = new_node
            self.word_to_ids[word].append(self.node_counter)
            self.id_to_node[self.node_counter] = new_node
            if parent_id is not None and parent_id not in self.child_to_parents[self.node_counter]:
                self.child_to_parents[self.node_counter].append(parent_id)
            self.node_counter += 1
        else:
            child_id = current_tree.children[word].id
            if parent_id is not None and parent_id not in self.child_to_parents[child_id]:
                self.child_to_parents[child_id].append(parent_id)
        return current_tree.children[word]

    def add_to_tree(self, words, direction, count=1):
        current_tree = self.preceding_tree if direction == 'preceding' else self.following_tree
        if direction == 'preceding':
            word_range = range(len(words) - 1, -1, -1)
        else:
            word_range = range(len(words))

        parent_id = current_tree.id

        for i in word_range:
            current_word = words[i]
            current_tree = self.get_or_create_node(current_tree, current_word, parent_id)
            current_tree.count += count
            parent_id = current_tree.id

    def create_tree_mask_as_root(self, sentences_dict):
        for key_word, sentences in sentences_dict.items():
            sentences_with_flags = self.add_start_end_flags_lower(sentences)

            for sentence in sentences_with_flags:
                words = sentence.split()
                key_word_index = words.index(key_word.lower())
                words_before = words[:key_word_index]
                words_after = words[key_word_index + 1:]

                self.add_to_tree(words_before, 'preceding')
                self.add_to_tree(words_after, 'following')

    def print_tree(self, node, level=0):
        print('  ' * level + f"{node.word} (count: {node.count}, id: {node.id})")
        for child in node.children.values():
            self.print_tree(child, level + 1)

    def print_trees(self):
        print("Preceding Tree (before <MASK>):")
        self.print_tree(self.preceding_tree)
        print("\nFollowing Tree (after <MASK>):")
        self.print_tree(self.following_tree)

    def get_nodes_by_word(self, word):
        ids = self.word_to_ids.get(word, [])
        return {self.id_to_node[id] for id in ids}

    def get_node_by_id(self, id):
        return self.id_to_node.get(id)

    def get_parents_by_id(self, id):
        parent_ids = self.child_to_parents.get(id, [])
        return [self.id_to_node[parent_id] for parent_id in parent_ids]

    def optimize_tree(self, word):
        all_nodes = self.get_nodes_by_word(word)
        groups_with_same_children = defaultdict(list)

        # Group nodes with identical children structures (use frozenset of (child_word, child_id))
        for node in all_nodes:
            child_structure = frozenset((child.word, child.id) for child in node.children.values())
            groups_with_same_children[child_structure].append(node)

        # Process node groups that have identical children
        for node_group in [i for i in groups_with_same_children.values() if len(i) > 1]:
            node_with_smallest_id = min(node_group, key=lambda node: node.id)

            # Remove other nodes and merge children into the node with the smallest ID
            parents = []
            for node in node_group:
                if node != node_with_smallest_id:
                    parents.extend(self.get_parents_by_id(node.id))
                    node_with_smallest_id.children.update(node.children)
                    node_group.remove(node)

            # Update all parent nodes to reference the merged node
            for parent_node in parents:
                parent_node.children[word] = node_with_smallest_id

            # Recursively optimize the parent nodes
            self.optimize_tree(parent_node.word)

# Example usage
sentences_dict = {
    'erinnere': [
        'Ich erinnere mich gut',
        'ich erinnere mich nicht',
        'nochmal erinnere ich mich nicht'
    ],
    'erinnert': [
        'wie erinnert man sich nochmal',
        'wo erinnert man sich nochmal',
        'vielleicht erinnert man sich dann nochmal'
    ]
}

# Create a PatternExtractor instance
extractor = PatternExtractor()
extractor.create_tree_mask_as_root(sentences_dict)
extractor.optimize_tree('<START>')
extractor.optimize_tree('<END>')
print("\nOptimized Trees:")
extractor.print_trees()



Optimized Trees:
Preceding Tree (before <MASK>):
<ROOT> (count: 0, id: 0)
  ich (count: 2, id: 2)
    <START> (count: 2, id: 3)
  nochmal (count: 1, id: 9)
    <START> (count: 2, id: 3)
  wie (count: 1, id: 15)
    <START> (count: 1, id: 16)
  wo (count: 1, id: 21)
    <START> (count: 2, id: 3)
  vielleicht (count: 1, id: 23)
    <START> (count: 2, id: 3)

Following Tree (after <MASK>):
<ROOT> (count: 0, id: 1)
  mich (count: 2, id: 4)
    gut (count: 1, id: 5)
      <END> (count: 1, id: 6)
    nicht (count: 1, id: 7)
      <END> (count: 1, id: 6)
  ich (count: 1, id: 11)
    mich (count: 1, id: 12)
      nicht (count: 1, id: 7)
        <END> (count: 1, id: 6)
  man (count: 3, id: 17)
    sich (count: 3, id: 18)
      nochmal (count: 2, id: 19)
        <END> (count: 2, id: 20)
      dann (count: 1, id: 25)
        nochmal (count: 1, id: 26)
          <END> (count: 1, id: 6)
