In [5]:
from collections import defaultdict

class PatternExtractor:
    """Build word-count trees around a masked key word.

    For every example sentence, the words before the key word are inserted
    into ``preceding_tree`` starting with the word nearest the key word and
    walking back to ``<START>``.  The words after it are inserted into
    ``following_tree`` starting with the nearest word and walking forward to
    ``<END>``.

    Every tree node is a dict with the keys:
        'count'    -- how many inserted word sequences pass through this node
        'children' -- the next level of the tree, keyed by word
        'id'       -- a unique integer taken from ``node_counter``
    """

    def __init__(self):
        # Separate trees for preceding and following words relative to <MASK>
        self.preceding_tree = defaultdict(dict)
        self.following_tree = defaultdict(dict)
        self.node_counter = 0  # next unused node ID, shared by both trees

        # Reserve IDs 0 and 1 for the special <BACKWARD> and <FORWARD> roots
        # that split the graph.  Only the IDs are reserved here; no tree
        # nodes are created for them in this class.
        self.backward_root_id = self.node_counter
        self.node_counter += 1
        self.forward_root_id = self.node_counter
        self.node_counter += 1

    def add_start_end_flags_lower(self, sentences):
        """Return each sentence lowercased and wrapped in ``<START> ... <END>``.

        Only the sentence text is lowercased; the flags keep their case.
        """
        return [f"<START> {sentence.lower()} <END>" for sentence in sentences]

    def add_to_tree(self, words, direction, count=1):
        """Insert ``words`` as one path into the tree selected by ``direction``.

        Args:
            words: sequence of tokens to insert.
            direction: ``'preceding'`` inserts into ``preceding_tree`` in
                reverse order (last word first, i.e. nearest the mask first).
                Any other value inserts into ``following_tree`` in order.
            count: amount added to the ``'count'`` of every node on the path.

        New nodes receive the next ID from ``node_counter``.
        """
        if direction == 'preceding':
            current_tree = self.preceding_tree
            ordered_words = reversed(words)  # walk outward from the mask
        else:
            current_tree = self.following_tree
            ordered_words = words
        for current_word in ordered_words:
            if current_word not in current_tree:
                current_tree[current_word] = {'count': 0, 'children': defaultdict(dict), 'id': self.node_counter}
                self.node_counter += 1
            current_tree[current_word]['count'] += count
            current_tree = current_tree[current_word]['children']

    def create_tree_mask_as_root(self, sentences_dict):
        """Populate both trees from a ``{key_word: [sentence, ...]}`` mapping.

        Each sentence is lowercased, wrapped in ``<START>``/``<END>``, and split
        on whitespace.  The first token equal to ``key_word.lower()`` acts as
        the mask.  Tokens before it go to ``preceding_tree`` and tokens after
        it go to ``following_tree``.

        Raises:
            ValueError: if a sentence does not contain its key word as a
                separate whitespace-delimited token.  All sentences are
                checked before any tree is modified, so a bad entry leaves
                the trees and ``node_counter`` untouched.
        """
        # First pass: split every sentence so that a malformed one cannot
        # leave the trees half-populated.
        splits = []
        for key_word, sentences in sentences_dict.items():
            target = key_word.lower()
            for sentence in self.add_start_end_flags_lower(sentences):
                words = sentence.split()
                try:
                    key_word_index = words.index(target)
                except ValueError:
                    raise ValueError(
                        f"key word {key_word!r} not found as a separate token "
                        f"in sentence {sentence!r}"
                    ) from None
                splits.append((words[:key_word_index], words[key_word_index + 1:]))

        # Second pass: insert in the original order so node IDs are assigned
        # exactly as they would be by a single interleaved pass.
        for words_before, words_after in splits:
            self.add_to_tree(words_before, 'preceding')
            self.add_to_tree(words_after, 'following')

    def _update_tree_node_id(self, word, old_node_id, new_node_id):
        """Replace every node ID equal to ``old_node_id`` with ``new_node_id``.

        Both trees are updated.  ``word`` is accepted for interface
        compatibility but is not used; matching is by ID only.
        """
        self._update_tree_node_id_recursive(self.preceding_tree, word, old_node_id, new_node_id)
        self._update_tree_node_id_recursive(self.following_tree, word, old_node_id, new_node_id)

    def _update_tree_node_id_recursive(self, tree, word, old_node_id, new_node_id):
        """Depth-first ID replacement within ``tree`` (``word`` is unused)."""
        for data in tree.values():
            if data['id'] == old_node_id:
                data['id'] = new_node_id  # Replace old node ID with new node ID
            self._update_tree_node_id_recursive(data['children'], word, old_node_id, new_node_id)

    def print_tree(self, tree, level=0):
        """Print ``tree`` with two spaces of indentation per depth level."""
        for word, data in tree.items():
            print('  ' * level + f"{word} (count: {data['count']}, id: {data['id']})")
            self.print_tree(data['children'], level + 1)

    def print_trees(self):
        """Print the preceding tree followed by the following tree."""
        print("Preceding Tree (before <MASK>):")
        self.print_tree(self.preceding_tree)
        print("\nFollowing Tree (after <MASK>):")
        self.print_tree(self.following_tree)


# Example usage: two key words, each paired with three sentences in which
# that key word plays the role of the masked position.
sentences_dict = dict(
    erinnere=[
        'Ich erinnere mich gut',
        'ich erinnere mich nicht',
        'nochmal erinnere ich mich nicht',
    ],
    erinnert=[
        'wie erinnert man sich nochmal',
        'wo erinnert man sich nochmal',
        'vielleicht erinnert man sich dann nochmal',
    ],
)

# Build the preceding/following trees from the examples, then dump them.
extractor = PatternExtractor()
extractor.create_tree_mask_as_root(sentences_dict)

print("\nInitial Trees:")
extractor.print_trees()


Initial Trees:
Preceding Tree (before <MASK>):
ich (count: 2, id: 2)
  <START> (count: 2, id: 3)
nochmal (count: 1, id: 9)
  <START> (count: 1, id: 10)
wie (count: 1, id: 15)
  <START> (count: 1, id: 16)
wo (count: 1, id: 21)
  <START> (count: 1, id: 22)
vielleicht (count: 1, id: 23)
  <START> (count: 1, id: 24)

Following Tree (after <MASK>):
mich (count: 2, id: 4)
  gut (count: 1, id: 5)
    <END> (count: 1, id: 6)
  nicht (count: 1, id: 7)
    <END> (count: 1, id: 8)
ich (count: 1, id: 11)
  mich (count: 1, id: 12)
    nicht (count: 1, id: 13)
      <END> (count: 1, id: 14)
man (count: 3, id: 17)
  sich (count: 3, id: 18)
    nochmal (count: 2, id: 19)
      <END> (count: 2, id: 20)
    dann (count: 1, id: 25)
      nochmal (count: 1, id: 26)
        <END> (count: 1, id: 27)
