In [None]:
# trying KeyBERT for keyword extraction
from keybert import KeyBERT

In [None]:
class DocKeywordIndex:
    """
    Per-document keyphrase and embedding index.

    Build once per epoch (or cache to disk).
    Produces doc-level keyphrases and a doc embedding for quick re-scoring.

    All work happens eagerly in the constructor: every document in
    ``doc_texts`` is run through KeyBERT and through the sentence encoder.

    Args:
        doc_texts: Mapping of doc_id -> full document text.
        keybert_model: Name of the sentence-transformers model; used both
            for keyphrase extraction and for the document embeddings.
        top_n: Maximum number of keyphrases requested per document.
        ngram_range: (min_n, max_n) word n-gram lengths for candidates.
        use_mmr: Forwarded to KeyBERT's ``use_mmr`` option.
        diversity: Forwarded to KeyBERT's ``diversity`` option.

    Attributes:
        embedder: The SentenceTransformer instance shared with KeyBERT.
        kb: The KeyBERT instance wrapping ``embedder``.
        doc_keywords: doc_id -> list of keyphrase strings (scores dropped),
            in the order KeyBERT returned them.
        doc_embs: doc_id -> document embedding, encoded with
            ``normalize_embeddings=True``.
        ngram_range, top_n, use_mmr, diversity: The extraction settings
            this index was built with.
    """
    def __init__(self,
                 doc_texts: Dict[str, str],  # doc_id -> full text
                 keybert_model: str = "all-MiniLM-L6-v2",
                 top_n: int = 40,
                 ngram_range=(1,3),
                 use_mmr: bool = True,
                 diversity: float = 0.6):
        # Load the model once and hand the same instance to KeyBERT;
        # passing the model *name* to both constructors would load the
        # same weights twice.
        self.embedder = SentenceTransformer(keybert_model)
        self.kb = KeyBERT(model=self.embedder)
        self.doc_keywords: Dict[str, List[str]] = {}
        self.doc_embs: Dict[str, np.ndarray] = {}

        # Keep the extraction settings so a cached index can be checked
        # against the configuration it was built with.
        self.ngram_range = ngram_range
        self.top_n = top_n
        self.use_mmr = use_mmr
        self.diversity = diversity

        for doc_id, text in doc_texts.items():
            # KeyBERT over the full doc
            kws = self.kb.extract_keywords(
                text,
                keyphrase_ngram_range=ngram_range,
                stop_words="english",
                use_mmr=use_mmr,
                diversity=diversity,
                top_n=top_n,
            )
            # extract_keywords yields (phrase, score) pairs; keep the phrase only
            self.doc_keywords[doc_id] = [phrase for phrase, _ in kws]
            # store doc embedding for quick cosine scoring later
            # (encode takes a batch; [0] picks the single row back out)
            self.doc_embs[doc_id] = self.embedder.encode(
                [text], normalize_embeddings=True
            )[0]