# In [None]:
import numpy as np
import faiss
from typing import List, Tuple
import scipy.sparse as sp
from sklearn.preprocessing import normalize

# Constants (module-level defaults consumed by santext_attack_pipeline)
K = 100  # number of nearest neighbours requested per word from compute_knn (example value)
VP_SIZE = 25000  # vocabulary size of public set; passed as vocab_size to build_candidate_vocab

def build_candidate_vocab(corpus_tokens: List[str], vocab_size: int) -> Tuple[List[str], dict]:
    """Build a frequency-ranked vocabulary from a token stream.

    Args:
        corpus_tokens: Tokens of the corpus (repetitions carry the counts).
        vocab_size: Maximum number of distinct tokens to keep.

    Returns:
        A ``(vocab, word2idx)`` pair: ``vocab`` lists the ``vocab_size`` most
        frequent tokens in the order produced by ``Counter.most_common``, and
        ``word2idx`` maps each of those tokens to its position in ``vocab``.
    """
    from collections import Counter

    vocab: List[str] = []
    word2idx: dict = {}
    for token, _count in Counter(corpus_tokens).most_common(vocab_size):
        # The next free slot in ``vocab`` is exactly this token's rank.
        word2idx[token] = len(vocab)
        vocab.append(token)
    return vocab, word2idx

def load_glove_embeddings(glove_path: str, vocab: List[str], emb_dim: int = 300) -> np.ndarray:
    """Load vectors for ``vocab`` from a whitespace-separated GloVe text file.

    Each non-blank line of the file is expected to be a word followed by
    ``emb_dim`` float components.

    Args:
        glove_path: Path to the UTF-8 encoded embedding file.
        vocab: Words to look up; row ``i`` of the result belongs to ``vocab[i]``.
        emb_dim: Number of components per vector (defaults to 300).

    Returns:
        A float32 array of shape ``(len(vocab), emb_dim)``. Rows of words that
        never appear in the file are left as zeros.

    Raises:
        ValueError: If a line for a vocabulary word does not carry exactly
            ``emb_dim`` components, or a component is not a valid float.
    """
    embeddings = np.zeros((len(vocab), emb_dim), dtype='float32')

    # Dict lookup instead of list.index() per hit (which made loading
    # quadratic in the vocabulary size). setdefault keeps the first position
    # of a repeated word, which is what list.index() would return.
    word2idx = {}
    for i, word in enumerate(vocab):
        word2idx.setdefault(word, i)

    found = 0
    with open(glove_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                # Blank line (e.g. a trailing newline at end of file).
                continue
            idx = word2idx.get(parts[0])
            if idx is None:
                continue
            if len(parts) - 1 != emb_dim:
                raise ValueError(
                    f'{glove_path}:{line_no}: expected {emb_dim} components '
                    f'for {parts[0]!r}, got {len(parts) - 1}'
                )
            embeddings[idx] = np.array([float(x) for x in parts[1:]], dtype='float32')
            found += 1
    print(f'Found {found} GloVe vectors for vocab words.')
    return embeddings

def compute_knn(embeddings: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return the ``k`` nearest neighbours of every row, excluding the row itself.

    The search is an exact one over a ``faiss.IndexFlatL2`` built from
    ``embeddings``; per the faiss documentation the reported distances are
    squared L2 distances.

    Args:
        embeddings: 2-D array with one vector per row.
        k: Number of neighbours to return per row (self not counted).

    Returns:
        ``(D, I)``, each of shape ``(n, k)``: distances and the matching row
        indices, in the order faiss reported them.
        NOTE(review): if ``k + 1`` exceeds the number of rows, faiss is
        expected to pad with index -1 -- confirm callers tolerate that.
    """
    # faiss expects C-contiguous float32 input.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    n = vectors.shape[0]

    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    D, I = index.search(vectors, k + 1)

    # Do not assume the query itself sits in column 0: with duplicate vectors
    # (e.g. several all-zero rows) another row can tie at distance 0 and be
    # listed first. Drop the self entry wherever it appears; if it is absent
    # from the k + 1 hits, drop the farthest hit so every row keeps k entries.
    is_self = I == np.arange(n)[:, None]
    self_missing = ~is_self.any(axis=1)
    is_self[self_missing, -1] = True
    keep = ~is_self
    return D[keep].reshape(n, k), I[keep].reshape(n, k)

def estimate_epsilon(D_knn: np.ndarray, freq_max: float) -> float:
    """Grid-search the epsilon whose model value is closest to ``freq_max``.

    For each of 500 evenly spaced candidates in [0.1, 30.0] the model value is
    ``1 / mean_i(sum_j exp(-0.5 * eps * sqrt(D_knn[i, j])))``. The candidate
    with the smallest absolute gap to ``freq_max`` is returned; on ties the
    smallest such candidate wins.

    NOTE(review): the row sums cover only the entries of ``D_knn``; if that
    array excludes the zero-distance self match, the self term is not part
    of the normaliser -- confirm this is intended.

    Args:
        D_knn: Per-row squared neighbour distances (square-rooted here).
        freq_max: Target value to match.

    Returns:
        The best-matching grid value.
    """
    # The distances do not depend on the candidate, so take the root once.
    neighbor_dist = np.sqrt(D_knn)

    best_eps = None
    best_gap = None
    for candidate in np.linspace(0.1, 30.0, 500):
        row_mass = np.exp(-0.5 * candidate * neighbor_dist).sum(axis=1)
        gap = abs(1.0 / np.mean(row_mass) - freq_max)
        # Strict "<" keeps the earliest candidate when gaps are equal.
        if best_gap is None or gap < best_gap:
            best_eps, best_gap = candidate, gap
    return best_eps

def build_likelihood_matrix(D_knn: np.ndarray, I_knn: np.ndarray, eps: float) -> sp.csr_matrix:
    """Build a sparse row-normalised likelihood matrix over kNN neighbours.

    Entry ``(i, j)`` is ``exp(-0.5 * eps * d_ij) / sum_j' exp(-0.5 * eps * d_ij')``
    where ``d_ij = sqrt(D_knn[i, c])`` for the column ``c`` with
    ``I_knn[i, c] == j``. Pairs that are not listed in ``I_knn`` are zero.

    Args:
        D_knn: ``(n, k)`` squared distances from each row to its neighbours.
        I_knn: ``(n, k)`` neighbour indices; negative entries are treated as
            "no neighbour" and skipped.
        eps: Scale applied to the distances.

    Returns:
        An ``(n, n)`` CSR matrix; every row with at least one valid neighbour
        sums to 1.
    """
    n = D_knn.shape[0]
    neighbors = np.asarray(I_knn)
    if neighbors.size == 0:
        return sp.csr_matrix((n, n), dtype=np.float64)

    # Work in float64 and clamp tiny negative squared distances before sqrt.
    distances = np.sqrt(np.maximum(np.asarray(D_knn, dtype=np.float64), 0.0))
    valid = neighbors >= 0
    logits = np.where(valid, -0.5 * eps * distances, -np.inf)

    # Shift by the row maximum before exponentiating: mathematically a no-op
    # after normalisation, but it prevents every weight underflowing to 0
    # (and the row turning into 0/0 = NaN) when eps * distance is large.
    row_max = logits.max(axis=1, keepdims=True)
    row_max[~np.isfinite(row_max)] = 0.0  # rows without any valid neighbour
    weights = np.exp(logits - row_max)
    totals = weights.sum(axis=1, keepdims=True)
    totals[totals == 0.0] = 1.0  # such rows contribute no entries anyway
    probs = weights / totals

    rows = np.broadcast_to(np.arange(n)[:, None], neighbors.shape)
    L = sp.coo_matrix(
        (probs[valid], (rows[valid], neighbors[valid])),
        shape=(n, n),
    )
    return L.tocsr()

def decode_indices(obf_indices: List[int], L: sp.csr_matrix) -> List[int]:
    """Map each observed index to the row with the largest entry in its column.

    For an observed index ``j`` the result is ``argmax_i L[i, j]``. Ties go to
    the smallest row index, and a column with no stored entries therefore
    decodes to row 0.

    Args:
        obf_indices: Column indices to decode.
        L: Sparse ``(n, n)`` matrix of scores, rows = candidates.

    Returns:
        One decoded row index (plain ``int``) per input index, in input order.
    """
    # Column slicing is cheap on CSC but scans the whole matrix on CSR, so
    # convert once instead of paying that cost for every token.
    columns = L.tocsc()

    # Tokens repeat heavily in text; decode each distinct index only once.
    best_row = {}
    decoded = []
    for idx in obf_indices:
        if idx not in best_row:
            scores = columns[:, idx].toarray().ravel()
            best_row[idx] = int(np.argmax(scores))
        decoded.append(best_row[idx])
    return decoded

def santext_attack_pipeline(corpus_tokens: List[str], glove_path: str, obf_indices: List[int], freq_max: float):
    """Run the full decode pipeline and return the decoded words and epsilon.

    Steps, in order: build the ``VP_SIZE``-word vocabulary from
    ``corpus_tokens``, load its GloVe vectors from ``glove_path``, compute
    ``K`` nearest neighbours per word, estimate epsilon from ``freq_max``,
    build the likelihood matrix and decode ``obf_indices`` against it.

    Args:
        corpus_tokens: Tokens used to build the candidate vocabulary.
        glove_path: Path to the GloVe text file.
        obf_indices: Indices (into the built vocabulary) to decode.
        freq_max: Target value handed to ``estimate_epsilon``.

    Returns:
        ``(decoded_words, eps_hat)``: the vocabulary words for the decoded
        indices and the estimated epsilon. The estimate is also printed.
    """
    # Only the ordered word list is needed here; the index map is discarded.
    vocab, _ = build_candidate_vocab(corpus_tokens, VP_SIZE)

    knn_sq_dists, knn_ids = compute_knn(load_glove_embeddings(glove_path, vocab), K)

    eps_hat = estimate_epsilon(knn_sq_dists, freq_max)
    print(f"Estimated epsilon: {eps_hat:.3f}")

    likelihood = build_likelihood_matrix(knn_sq_dists, knn_ids, eps_hat)
    decoded_words = [vocab[i] for i in decode_indices(obf_indices, likelihood)]
    return decoded_words, eps_hat
