<a href="https://colab.research.google.com/github/dingzhang2023/problem-solving-practice/blob/colab/Trie.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Trie

**04/07/2024**

#### Core idea

> - Group strings by their prefixes: strings with the same first letter share a subtree under the root, strings with the same first two letters share a deeper subtree, and so on.
> - Each letter of a string corresponds to a node in the tree, and the node for `s[i]` is the parent of the node for `s[i + 1]`.
> - Example: a trie containing the words `apple`, `app`, `banana`, `bat` and `cat`. Here is how the trie looks (`*` marks a node where a word ends):

            (root)
          /   |    \
         a    b     c
         |    |     |
         p    a     a
         |   / \    |
         p* n   t*  t*
         |  |
         l  a
         |  |
         e* n
            |
            a*

> - Store extra information in each node as the problem requires, such as an `is_word` flag or a `cnt` counter.
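The structure above can be sketched with nested Python dictionaries, one dictionary per node. This is a minimal illustration, not the template used below; the end-of-word key `'$'` is an assumption of this sketch that stands in for `is_word` and relies on `'$'` never appearing inside a word.

```python
def build_trie(words):
    """Build a trie as nested dicts; the key '$' marks the end of a word."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            # the node for ch is a child of the node for the previous character
            node = node.setdefault(ch, {})
        node['$'] = True  # plays the role of is_word
    return root


def walk(root, s):
    """Return the node reached by following s from the root, or None."""
    node = root
    for ch in s:
        if ch not in node:
            return None
        node = node[ch]
    return node


trie = build_trie(['apple', 'app', 'banana', 'bat', 'cat'])
print(sorted(trie))              # ['a', 'b', 'c']
print('$' in walk(trie, 'app'))  # True: 'app' is a stored word
print('$' in walk(trie, 'ba'))   # False: 'ba' is only a prefix
```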


#### Major use cases
The main problems solved by a trie include:

1. String problems involving prefixes
2. XOR problems on integers, where each number is inserted into the trie by its binary representation
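In the XOR use case, a number enters the trie as the sequence of its bits, read from the highest bit to the lowest, so numbers that agree on their high bits share a path. A minimal sketch, assuming non-negative integers that fit in `HIGH_BIT + 1` bits; `bit_path` is a hypothetical helper and `HIGH_BIT = 4` is chosen only for this illustration (the templates below use larger values):

```python
HIGH_BIT = 4  # assumption for this illustration: values fit in 5 bits


def bit_path(x, high_bit=HIGH_BIT):
    """Bits of x from the highest to the lowest: the path x takes in a 0-1 trie."""
    return [(x >> i) & 1 for i in range(high_bit, -1, -1)]


# 5 = 00101 and 25 = 11001 already differ at bit 4, so they fall into
# different subtrees of the root; that early split makes their XOR large.
print(bit_path(5))   # [0, 0, 1, 0, 1]
print(bit_path(25))  # [1, 1, 0, 0, 1]
```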


#### Trie code template
[208. Implement Trie (Prefix Tree)](https://leetcode.com/problems/implement-trie-prefix-tree/description/)
> Trie template problem
> - Insert
> - Query
> - Extra information saved in each trie node: for this problem, a flag `is_word` records whether a word ends at this node.

In [None]:
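from collections import defaultdict


class Node:
    __slots__ = 'son', 'is_word'

    def __init__(self):
        # alternative: a fixed array of children indexed by ord(c) - ord('a')
        # self.son = [None] * 26
        # defaultdict(Node) creates a child node on first access
        self.son = defaultdict(Node)
        self.is_word = False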


class Trie:

    def __init__(self):
        self.root = Node()


    def insert(self, word: str) -> None:
        cur = self.root
        for c in word:
            cur = cur.son[c]
        cur.is_word = True

    def search(self, word: str) -> bool:
        cur = self.root
        for c in word:
            if c not in cur.son:
                return False
            cur = cur.son[c]
        return cur.is_word

    def startsWith(self, prefix: str) -> bool:
        cur = self.root
        for c in prefix:
            if c not in cur.son:
                return False
            cur = cur.son[c]
        return True

#### Problem lists

[676. Implement Magic Dictionary](https://leetcode.com/problems/implement-magic-dictionary/)
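> Keywords: change exactly one character, then look the result up in a set of strings, over multiple queries. Essentially this is a counting question over a set of strings (is there at least one match?), and a trie is well suited to such string statistics.

> Concretely, we check on the trie whether some stored string matches the query after changing exactly one character. A DFS over the trie that tracks how many characters have been changed suffices.

> Time complexity: $O(L^2 \times C^2)$ per query in the worst case. The DFS follows the exactly matching path (at most $L$ nodes, trying $C$ children at each), may branch into up to $C$ mismatched children at each of those nodes, and after that single change each branch can only continue along one matching path of at most $L$ nodes, again trying $C$ children at each.

> Space complexity: $O(N \times L)$ nodes with the `dict`-based children below, or $O(N \times L \times C)$ if every node stores a fixed array of $C$ children.

> Here $N$ is the number of words stored in the trie, $L$ is the maximum word length ($L = 100$ in this problem), and $C = 26$ is the size of the character set.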

In [None]:
from collections import defaultdict
from typing import List


class Node:
    __slots__ = 'son', 'is_word'

    def __init__(self):
        self.son = defaultdict(Node)
        self.is_word = False


class MagicDictionary:

    def __init__(self):
        self.root = Node()

    def insert(self, word):
        cur = self.root
        for c in word:
            cur = cur.son[c]
        cur.is_word = True


    def buildDict(self, dictionary: List[str]) -> None:
        for word in dictionary:
            self.insert(word)


    def search(self, searchWord: str) -> bool:
        # brute force: walk the trie and allow exactly one character of s to change
        # cnt = number of characters changed so far on the current path
        def dfs(s, node, i, cnt):
            if i == len(s) and node.is_word and cnt == 1:
                return True

            if i == len(s) or cnt > 1:
                return False

            # follow every child; a child that differs from s[i] costs one change
            for c in node.son.keys():
                if dfs(s, node.son[c], i + 1, cnt + (c != s[i])):
                    return True
            return False

        return dfs(searchWord, self.root, 0, 0)
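When testing the trie-based `search` on small inputs, a reference that checks the definition directly is handy. A brute-force sketch (the function name is hypothetical) that compares the query with every dictionary word of the same length, $O(N \times L)$ per query:

```python
from typing import List


def magic_search_bruteforce(dictionary: List[str], search_word: str) -> bool:
    """Reference check: does some dictionary word differ from search_word
    in exactly one position (same length, exactly one mismatched character)?"""
    for word in dictionary:
        if len(word) != len(search_word):
            continue
        diff = sum(a != b for a, b in zip(word, search_word))
        if diff == 1:
            return True
    return False
```

The trie version avoids scanning every word: its DFS only explores paths that match the query with at most one change.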

[648. Replace Words](https://leetcode.com/problems/replace-words/description/)
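> Replacing each word with its shortest root (a dictionary word that is a prefix of it) is a prefix-matching task, well suited to a trie.

> Store each root's index in the dictionary at the node where the root ends, so the original root can be recovered. Walking down the trie along a word, the first node that carries an index gives the shortest root of that word.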

In [None]:
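from collections import defaultdict
from typing import List


class Node:
    __slots__ = 'son', 'idx'

    def __init__(self):
        self.son = defaultdict(Node)
        self.idx = -1  # index of the root that ends here, -1 if none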


class Trie:

    def __init__(self):
        self.root = Node()


    def insert(self, word, idx):
        cur = self.root
        for c in word:
            cur = cur.son[c]
        cur.idx = idx


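    def query(self, word):
        cur = self.root
        for c in word:
            # check membership first so that queries do not create new nodes
            if c not in cur.son:
                return -1
            cur = cur.son[c]
            # the first root met along the path is the shortest one
            if cur.idx != -1:
                return cur.idx
        return -1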


class Solution:
    def replaceWords(self, dictionary: List[str], sentence: str) -> str:
        trie = Trie()
        for i, word in enumerate(dictionary):
            trie.insert(word, i)

        ans = []
        for word in sentence.split(' '):
            idx = trie.query(word)
            if idx != -1:
                # replace by shorter word
                word = dictionary[idx]
            ans.append(word)

        return ' '.join(ans)
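For small inputs, a brute-force reference (hypothetical helper name) can be used to test the trie: sort the roots by length and take the first one that is a prefix of the word, which should agree with the first-hit rule of `query`.

```python
from typing import List


def replace_words_bruteforce(dictionary: List[str], sentence: str) -> str:
    """Reference: replace each word by the shortest dictionary root that is its prefix."""
    roots = sorted(dictionary, key=len)  # shortest roots are tried first
    out = []
    for word in sentence.split(' '):
        for root in roots:
            if word.startswith(root):
                word = root
                break
        out.append(word)
    return ' '.join(out)
```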

[1804. Implement Trie II (Prefix Tree)](https://leetcode.com/problems/implement-trie-ii-prefix-tree/description/)
> Trie template problem
> - Insert
> - Query
> - Erase: use lazy deletion, decrementing the counters along the path instead of removing nodes. `prefix_cnt` counts the words passing through a node, and `word_cnt` counts the words ending at it.

In [None]:
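from collections import defaultdict


class Node:
    __slots__ = 'son', 'prefix_cnt', 'word_cnt'

    def __init__(self):
        self.son = defaultdict(Node)
        self.prefix_cnt = 0  # number of words passing through this node
        self.word_cnt = 0  # number of words ending at this node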

class Trie:

    def __init__(self):
        self.root = Node()


    def insert(self, word: str) -> None:
        cur = self.root
        for c in word:
            cur = cur.son[c]
            cur.prefix_cnt += 1
        cur.word_cnt += 1


    def countWordsEqualTo(self, word: str) -> int:
        cur = self.root
        for c in word:
            if c not in cur.son:
                return 0
            cur = cur.son[c]
        return cur.word_cnt

    def countWordsStartingWith(self, prefix: str) -> int:
        cur = self.root
        for c in prefix:
            if c not in cur.son:
                return 0
            cur = cur.son[c]
        return cur.prefix_cnt

    def erase(self, word: str) -> None:
        cur = self.root
        for c in word:
            if c not in cur.son:
                return
            cur = cur.son[c]
            cur.prefix_cnt -= 1
        cur.word_cnt -= 1

[2416. Sum of Prefix Scores of Strings](https://leetcode.com/problems/sum-of-prefix-scores-of-strings/)

> Counting how many strings pass through a node `x` (that is, how many strings have the prefix that `x` represents) is a classic use of a trie: keep a counter on each node and increment it whenever a string passes through.

> Insert every word into the trie, then run a depth-first search (DFS) from the root that accumulates the counters along each path; the accumulated total at the node where a word ends is that word's score.

> Implementation trick: each node keeps an `ids` list with the indices of the words that end at it, so the DFS can write each word's score directly into the answer.

In [None]:
from typing import List


class Node:
    __slots__ = 'son', 'score', 'ids'

    def __init__(self):
        self.son = dict()
        # self.son = defaultdict(Node)
        self.score = 0
        self.ids = []


class Solution:
    def sumPrefixScores(self, words: List[str]) -> List[int]:

        root = Node()
        # Insert word to trie
        for i, word in enumerate(words):
            cur = root
            for c in word:
                # the if-check is unnecessary when son is a defaultdict(Node)
                if c not in cur.son:
                    cur.son[c] = Node()
                cur = cur.son[c]
                cur.score += 1
            cur.ids.append(i)  # word i ends at this node

        ans = [0] * len(words)
        def dfs(node, total_score):
            if node is None:
                return

            total_score += node.score
            for i in node.ids:
                ans[i] = total_score
            for child in node.son.values():
                dfs(child, total_score)

        dfs(root, 0)
        return ans
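An iterative alternative (a design swap, not the method above) avoids recursion, which may matter because the DFS can be as deep as the longest word and CPython's default recursion limit is 1000. Count prefixes in a first pass, then sum the counts along each word's path in a second pass. A sketch with nested dicts; the function name is hypothetical, and the `'#'` counter key assumes words contain only lowercase letters.

```python
from typing import List


def sum_prefix_scores_two_pass(words: List[str]) -> List[int]:
    """Alternative to the DFS: count prefixes in one pass, then sum along each word."""
    root = {}
    # First pass: the '#' entry of a node counts the words passing through it.
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {'#': 0})
            node['#'] += 1
    # Second pass: a word's score is the sum of the counts on its path.
    ans = []
    for word in words:
        node, total = root, 0
        for ch in word:
            node = node[ch]
            total += node['#']
        ans.append(total)
    return ans
```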

[3045. Count Prefix and Suffix Pairs II](https://leetcode.com/problems/count-prefix-and-suffix-pairs-ii/description/)

**Method I**
> Transform the check "is `words[i]` both a prefix and a suffix of `words[j]`?" into a pure prefix check:

> - Convert `s` to a list of pairs: [(s[0], s[n-1]), (s[1], s[n-2]),...,(s[n-1], s[0])]
> - Check whether the pair list associated with `words[i]` is a prefix of the pair list associated with `words[j]`
> - Use trie to check the prefix
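The equivalence can be checked by brute force on small strings: the first components of the pairs spell `s` (the prefix condition) and the second components spell `s` reversed (the suffix condition). A sketch with hypothetical helper names:

```python
from itertools import product


def pairs(s):
    """Pair the i-th character from the front with the i-th character from the back."""
    return list(zip(s, s[::-1]))


def is_prefix_and_suffix(a, b):
    """Direct definition: a is both a prefix and a suffix of b."""
    return b.startswith(a) and b.endswith(a)


def pairs_is_prefix(a, b):
    """Transformed check: pairs(a) is a prefix of pairs(b)."""
    pa, pb = pairs(a), pairs(b)
    return pb[:len(pa)] == pa


# Exhaustive check over all strings of length 1..4 over the alphabet {a, b}.
samples = [''.join(t) for n in range(1, 5) for t in product('ab', repeat=n)]
assert all(is_prefix_and_suffix(a, b) == pairs_is_prefix(a, b)
           for a in samples for b in samples)

print(pairs('aba'))  # [('a', 'a'), ('b', 'b'), ('a', 'a')]
```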

In [None]:
from typing import List


class Node:
    __slots__ = 'son', 'cnt'

    def __init__(self):
        self.son = dict()  # key is a pair (s[i], s[n-1-i]), value is the child node
        self.cnt = 0  # number of earlier words whose pair sequence ends at this node

class Solution:
    def countPrefixSuffixPairs(self, words: List[str]) -> int:
        ans = 0
        root = Node()
        for s in words:
            cur = root
            for p in zip(s, s[::-1]):
                # p = (s[i], s[n-1-i])
                if p not in cur.son:
                    cur.son[p] = Node()
                cur = cur.son[p]
                ans += cur.cnt
            # one more word whose pair sequence ends at this node
            cur.cnt += 1
        return ans

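**Method II**
> If `s` is both a prefix and a suffix of `t`, let $m = |s|$ and $n = |t|$: the length-$m$ prefix of `t` equals the length-$m$ suffix of `t`. With the Z-function of `t` defined as $z[i] = \mathrm{LCP}(t[i:], t)$, this holds exactly when $z[i] = n - i$ for $i = n - m$, i.e. $z[n - m] = m$.
> So walk `t` down a trie of the earlier words as in Method I (one character per level instead of one pair), and at depth $m$ add the node's count only when $z[n - m] = m$.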
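A sketch of Method II under the description above; `CharNode`, `z_function` and `count_prefix_suffix_pairs_z` are hypothetical names. The trie is keyed by single characters, and a node's count is added only when the prefix it represents is also a suffix of the current word.

```python
from typing import List


def z_function(t: str) -> List[int]:
    """z[i] = length of the longest common prefix of t[i:] and t, with z[0] = len(t)."""
    n = len(t)
    z = [0] * n
    if n:
        z[0] = n
    l = r = 0  # [l, r) is the rightmost window known to match a prefix of t
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])
        while i + z[i] < n and t[z[i]] == t[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z


class CharNode:
    __slots__ = 'son', 'cnt'

    def __init__(self):
        self.son = {}  # key is a character, value is the child node
        self.cnt = 0   # number of earlier words that end at this node


def count_prefix_suffix_pairs_z(words: List[str]) -> int:
    ans = 0
    root = CharNode()
    for t in words:
        n = len(t)
        z = z_function(t)
        cur = root
        for m, c in enumerate(t, 1):  # after this step cur represents t[:m]
            if c not in cur.son:
                cur.son[c] = CharNode()
            cur = cur.son[c]
            # t[:m] is also a suffix of t exactly when z[n - m] == m
            if z[n - m] == m:
                ans += cur.cnt
        cur.cnt += 1
    return ans
```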


[3093. Longest Common Suffix Queries](https://leetcode.com/problems/longest-common-suffix-queries/)

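> This problem matches suffixes only. After reversing the strings, suffixes become prefixes, so it turns into a prefix matching problem over many strings, a typical use case of a trie.

> Insert every reversed string of `wordsContainer` into the trie, then walk each reversed query down the trie as far as it matches.

> Implementation trick: each node stores the length `min_l` and the index `i` of the shortest container word passing through it. Only a strictly shorter length overwrites them, so ties keep the smaller index.

> Edge case: if no suffix matches at all (the longest common suffix is the empty string), the answer is read from the root, so the root must also record the shortest word.

Using a list for the children of each node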

In [None]:
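from math import inf
from typing import List


class Node:
    __slots__ = 'son', 'min_l', 'i'

    def __init__(self):
        self.son = [None] * 26  # children indexed by ord(c) - ord('a')
        self.min_l = inf  # length of the shortest word passing through this node
        self.i = -1  # index of that word in wordsContainer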

class Solution:
    def stringIndices(self, wordsContainer: List[str], wordsQuery: List[str]) -> List[int]:
        ord_a = ord('a')
        root = Node()
        for idx, s in enumerate(wordsContainer):
            l = len(s)
            cur = root
            if l < cur.min_l:
                cur.min_l, cur.i = l, idx
            for c in map(ord, reversed(s)):
                c -= ord_a
                if cur.son[c] is None:
                    cur.son[c] = Node()
                cur = cur.son[c]
                # update the current node with the shortest prefix
                if l < cur.min_l:
                    cur.min_l, cur.i = l, idx

        ans = []
        for s in wordsQuery:
            cur = root
            for c in map(ord, reversed(s)):
                c -= ord_a
                if cur.son[c] is None:
                    break
                cur = cur.son[c]
            ans.append(cur.i)
        return ans

Using a defaultdict for the children of each node

In [None]:
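from collections import defaultdict
from math import inf
from typing import List


class Node:
    __slots__ = 'son', 'min_l', 'i'

    def __init__(self):
        self.son = defaultdict(Node)
        self.min_l = inf  # length of the shortest word passing through this node
        self.i = -1  # index of that word in wordsContainer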

class Solution:
    def stringIndices(self, wordsContainer: List[str], wordsQuery: List[str]) -> List[int]:
        root = Node()
        for idx, s in enumerate(wordsContainer):
            l = len(s)
            cur = root
            if l < cur.min_l:
                cur.min_l, cur.i = l, idx
            for c in reversed(s):
                cur = cur.son[c]
                if l < cur.min_l:
                    cur.min_l, cur.i = l, idx

        ans = []
        for s in wordsQuery:
            cur = root
            for c in reversed(s):
                # prefix does not exist
                if c not in cur.son:
                    break
                cur = cur.son[c]
            ans.append(cur.i)
        return ans
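The stored `min_l` and `i` encode the tie-breaking rule. A brute-force reference (hypothetical helper name), handy for testing the trie versions on small inputs, makes the rule explicit: longest common suffix first, then the shortest container word, then the smallest index. An empty common suffix is allowed, which corresponds to reading the answer from the root.

```python
from typing import List


def string_indices_bruteforce(wordsContainer: List[str], wordsQuery: List[str]) -> List[int]:
    """Reference: for each query pick the container word with the longest common
    suffix; ties go to the shorter word, then to the smaller index."""
    def common_suffix(a: str, b: str) -> int:
        k = 0
        while k < len(a) and k < len(b) and a[-1 - k] == b[-1 - k]:
            k += 1
        return k

    ans = []
    for q in wordsQuery:
        best = min(range(len(wordsContainer)),
                   key=lambda i: (-common_suffix(wordsContainer[i], q),
                                  len(wordsContainer[i]), i))
        ans.append(best)
    return ans
```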

#### `0-1 Trie` for `XOR` problems



#### Template problem
[421. Maximum XOR of Two Numbers in an Array](https://leetcode.com/problems/maximum-xor-of-two-numbers-in-an-array/description/)
> Think of each number in binary. To maximize the XOR with a number, at each bit we want a partner whose bit is the opposite (a 1 where the number has a 0, and vice versa), and a higher bit outweighs all lower bits combined. So we need a data structure that groups the numbers by their bits, 0 or 1.

> This is where a trie comes into play: going from the highest bit to the lowest, the numbers with 0 and with 1 at the current bit are separated into the left and right subtrees, and each subtree is a smaller instance of the same problem.

> Numbers that share high-bit prefixes share trie nodes, which saves space.

> The key idea is to use the trie to **group** all numbers bit by bit, from the highest bit to the lowest.

In [None]:
from typing import List


class Node:
    __slots__ = 'son'

    def __init__(self):
        # child for bit 0 and child for bit 1
        self.son = [None, None]


class Trie:
    HIGH_BIT = 31

    def __init__(self):
        self.root = Node()

    def insert(self, val):
        cur = self.root
        # insert the number starting from the highest bit
        for i in range(Trie.HIGH_BIT, -1, -1):
            bit = (val >> i) & 1
            if cur.son[bit] is None:
                cur.son[bit] = Node()
            cur = cur.son[bit]

    # return the max XOR of val with any number in the trie
    # (the trie must contain at least one number)
    def max_xor(self, val):
        cur = self.root
        ans = 0
        for i in range(Trie.HIGH_BIT, -1, -1):
            bit = (val >> i) & 1
            # the opposite bit exists: take that path, so this bit of the XOR is 1
            if cur.son[bit ^ 1]:
                ans |= 1 << i
                bit ^= 1
            # otherwise follow the same bit, which must exist since the trie is non-empty
            cur = cur.son[bit]
        return ans


class Solution:
    def findMaximumXOR(self, nums: List[int]) -> int:
        t = Trie()
        ans = 0
        for x in nums:
            t.insert(x)
            ans = max(ans, t.max_xor(x))
        return ans
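A brute-force reference (hypothetical name) for testing the trie on small inputs. For `nums = [3, 10, 5, 25, 2, 8]` the best pair is `5 ^ 25 = 0b00101 ^ 0b11001 = 0b11100 = 28`: walking the trie with 5, the opposite bit is available at bits 4, 3 and 2 (through 25) but not at bits 1 and 0.

```python
from itertools import combinations_with_replacement
from typing import List


def max_xor_bruteforce(nums: List[int]) -> int:
    """O(n^2) reference: try every pair, including a number with itself (XOR 0)."""
    return max(x ^ y for x, y in combinations_with_replacement(nums, 2))
```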


[2935. Maximum Strong Pair XOR II](https://leetcode.com/problems/maximum-strong-pair-xor-ii/description/)

> The answer does not depend on the order of `nums`, so sort it first.

> After sorting, if $x \leq y$, then $|x - y| \leq \min(x, y)$ becomes $y - x \leq x$, that is, $2x \geq y$.

> So for each $y = \text{nums}[i]$, we choose $x$ among $y$ itself and the elements to its left that satisfy $2x \geq y$, and maximize $x \oplus y$. This can be implemented with a `0-1 trie`.

> The larger $y$ is, the larger the smallest admissible $x$ becomes, so the admissible elements form a **sliding window**. Each time an element slides out of the window, it is removed from the 0-1 trie.

> To remove an element, decrement a counter on each node along its path instead of deleting nodes (lazy deletion); a child whose counter is 0 is treated as absent.

In [None]:
from typing import List


class Node:
    __slots__ = 'son', 'cnt'

    def __init__(self):
        self.son = [None, None]  # children for bit 0 and bit 1
        self.cnt = 0  # number of values currently stored in this subtree


class Trie:
    HIGH_BIT = 19


    def __init__(self):
        self.root = Node()


    def insert(self, val):
        cur = self.root
        for i in range(Trie.HIGH_BIT, -1, -1):
            bit = (val >> i) & 1
            if cur.son[bit] is None:
                cur.son[bit] = Node()
            cur = cur.son[bit]
            cur.cnt += 1 # maintain size of subtree
        return cur


    def remove(self, val):
        cur = self.root
        for i in range(Trie.HIGH_BIT, -1, -1):
            cur = cur.son[(val >> i) & 1]
            cur.cnt -= 1
        return cur


    def max_xor(self, val):
        cur = self.root
        ans = 0
        for i in range(Trie.HIGH_BIT, -1, -1):
            bit = (val >> i) & 1
            if cur.son[bit ^ 1] and cur.son[bit ^ 1].cnt:
                ans |= 1 << i
                bit ^= 1
            cur = cur.son[bit]
        return ans


class Solution:
    def maximumStrongPairXor(self, nums: List[int]) -> int:
        nums.sort()
        t = Trie()
        ans = left = 0
        for y in nums:
            t.insert(y)
            # sliding window to remove x
            while nums[left] * 2 < y:
                t.remove(nums[left])
                left += 1
            ans = max(ans, t.max_xor(y))
        return ans
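A brute-force reference (hypothetical name) that checks the strong-pair condition directly, without the $2x \geq y$ rewrite, is useful for testing the sliding-window version on small inputs. A number may pair with itself, so the result is at least 0.

```python
from typing import List


def strong_pair_xor_bruteforce(nums: List[int]) -> int:
    """O(n^2) reference straight from the definition |x - y| <= min(x, y)."""
    return max(x ^ y for x in nums for y in nums if abs(x - y) <= min(x, y))
```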

[1707. Maximum XOR With an Element From Array](https://leetcode.com/problems/maximum-xor-with-an-element-from-array/description/)

>

[1803. Count Pairs With XOR in a Range](https://leetcode.com/problems/count-pairs-with-xor-in-a-range/description/)

>

[1938. Maximum Genetic Difference Query](https://leetcode.com/problems/maximum-genetic-difference-query/description/)

>

[2479. Maximum XOR of Two Non-Overlapping Subtrees](https://leetcode.com/problems/maximum-xor-of-two-non-overlapping-subtrees/description/)

>