In [None]:
import re
from collections import defaultdict
from typing import List, Set, Dict, Tuple

def tokenize(text: str) -> List[str]:
    """Split *text* into lowercase word tokens.

    After lowercasing, every run of characters that is not an ASCII letter,
    digit or apostrophe acts as a separator. Punctuation and whitespace are
    therefore dropped, while contractions such as "don't" stay intact.
    """
    cleaned = re.sub(r"[^a-z0-9']+", " ", text.lower())
    # filter(None, ...) drops any empty pieces (split() yields none, but this
    # mirrors the explicit emptiness check).
    return list(filter(None, cleaned.split()))

def kgrams(term: str, k: int = 2, use_boundaries: bool = False) -> Set[str]:
    """Return the set of distinct length-*k* substrings of *term*.

    When use_boundaries is true, the term is first wrapped as '^' + term + '$'
    so that prefix and suffix grams can be told apart. If the (possibly
    padded) string is shorter than k, it is returned whole as the only gram.
    """
    padded = term
    if use_boundaries:
        padded = '^' + term + '$'
    if len(padded) >= k:
        last_start = len(padded) - k
        return {padded[start:start + k] for start in range(last_start + 1)}
    return {padded}

def build_kgram_index_from_docs(documents: List[str], k: int = 2, use_boundaries: bool = False) -> Tuple[Dict[str, Set[str]], Set[str]]:
    """Build a k-gram inverted index over the vocabulary of *documents*.

    Returns (index, vocab):
      - vocab is the set of every token that tokenize() yields across all
        documents;
      - index is a plain dict that maps each k-gram (from kgrams() with the
        same k / use_boundaries) to the set of vocabulary terms containing it.
    """
    vocabulary: Set[str] = {token for text in documents for token in tokenize(text)}
    postings: Dict[str, Set[str]] = {}
    for word in vocabulary:
        for gram in kgrams(word, k, use_boundaries):
            postings.setdefault(gram, set()).add(word)
    return postings, vocabulary

def generate_candidates_from_index(query_word: str, index: Dict[str, Set[str]], k: int = 2, use_boundaries: bool = False, max_candidates: int = 1000) -> Set[str]:
    """Collect vocabulary terms that share at least one k-gram with *query_word*.

    The posting lists of the query's k-grams are unioned one at a time. The
    scan stops once at least max_candidates terms have been gathered; the
    last posting list is added whole, so the total may exceed the cap.
    query_word itself is never included in the result.

    The k-grams are visited in sorted order. Iterating the raw set would
    depend on string-hash randomization, so a truncated scan could return a
    different candidate set on each run.
    """
    candidates: Set[str] = set()
    for gram in sorted(kgrams(query_word, k, use_boundaries)):
        candidates.update(index.get(gram, ()))
        if len(candidates) >= max_candidates:
            break
    candidates.discard(query_word)
    return candidates

def jaccard_score(q_k: Set[str], c_k: Set[str]) -> float:
    """Return the Jaccard similarity |A & B| / |A | B| of two k-gram sets.

    Two empty sets score 0.0 instead of raising ZeroDivisionError.
    """
    union_size = len(q_k | c_k)
    if not union_size:
        return 0.0
    return len(q_k & c_k) / union_size

def correct_query(documents: List[str], query: str, k: int = 2, use_boundaries: bool = False, jaccard_threshold: float = 0.1, top_k_candidates: int = 1) -> Tuple[str, Dict[str, List[Tuple[str,float]]]]:
    """Spell-correct each token of *query* against the vocabulary of *documents*.

    A fresh k-gram index is built from documents on every call.

    - A token already in the vocabulary is kept as-is, with diagnostic
      [(token, 1.0)].
    - Any other token is scored against every candidate that shares a k-gram
      with it, using Jaccard similarity of their k-gram sets.
    - Candidates scoring >= jaccard_threshold are ranked by descending score,
      then shorter length, then alphabetically.
    - The first top_k_candidates are kept, and the token is replaced by the
      best one, or left unchanged if none qualify.

    Returns (corrected, diagnostics). corrected is the resulting tokens joined
    by single spaces. diagnostics maps each token to its ranked
    (candidate, score) list.
    """
    index, vocab = build_kgram_index_from_docs(documents, k, use_boundaries)
    output_tokens: List[str] = []
    report: Dict[str, List[Tuple[str, float]]] = {}
    for token in tokenize(query):
        if token in vocab:
            report[token] = [(token, 1.0)]
            output_tokens.append(token)
        else:
            token_grams = kgrams(token, k, use_boundaries)
            pool = generate_candidates_from_index(token, index, k, use_boundaries)
            scored = [
                (cand, jaccard_score(token_grams, kgrams(cand, k, use_boundaries)))
                for cand in pool
            ]
            ranked = sorted(
                (pair for pair in scored if pair[1] >= jaccard_threshold),
                key=lambda pair: (-pair[1], len(pair[0]), pair[0]),
            )[:top_k_candidates]
            report[token] = ranked
            output_tokens.append(ranked[0][0] if ranked else token)
    return " ".join(output_tokens), report

if __name__ == '__main__':
    # Small demo corpus; its tokens form the vocabulary for correction.
    corpus = [
        "The cat sat on the mat",
        "The dog barked at the cat",
        "The quick brown fox jumped over the lazy dog",
        "Cats and dogs are common pets",
        "I love my dog and my cat",
        "Python is a popular programming language",
        "Machine learning is a subset of artificial intelligence",
        "I ate an apple every day",
        "She likes to eat a pea and an ape sometimes",
        "He bought an applet and an apple",
        "We visited the chapel near the apple orchard",
        "They drank a frappe at the cafe",
    ]
    # Queries containing deliberate misspellings.
    misspelled_queries = [
        "I ate an appe every day",
        "She saw an aple and an ape",
        "He visited the chapell nearby",
        "I like machne learning",
        "I want a frape drink",
    ]
    for query_text in misspelled_queries:
        # NOTE: correct_query rebuilds the k-gram index on every call.
        fixed, per_token = correct_query(
            corpus,
            query_text,
            k=2,
            use_boundaries=False,
            jaccard_threshold=0.2,
            top_k_candidates=1,
        )
        print("Query:   ", query_text)
        print("Corrected:", fixed)
        print("Candidates per token:")
        for token, ranked in per_token.items():
            print("  ", token, "->", ranked)
        print("-" * 60)
