In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn.neighbors import BallTree
from transformers import AutoTokenizer, AutoModel
import torch

class BERTEmbedder:
    """Turn text into dense vectors with a Hugging Face transformer encoder."""

    def __init__(self, model_name="sentence-transformers/stsb-bert-base", device=None):
        """
        Load the tokenizer and model and switch the model to eval mode.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string; defaults to ``'cpu'``.
        """
        self.device = device or 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference mode: disables dropout

        # Exposed for callers; not used inside this class.
        self.do_lower_case = getattr(self.tokenizer, 'do_lower_case', False)

    def text_to_embedding(self, texts, pooling='mean', normalize=False, max_length=128):
        """
        Embed a single string or a list of strings.

        Args:
            texts: A string or a list of strings.
            pooling: ``'mean'`` (attention-masked average of the last hidden
                states) or ``'cls'`` (last hidden state of the first token).
            normalize: If True, L2-normalise each embedding.
            max_length: Maximum number of tokens; longer inputs are truncated.

        Returns:
            A 1-D ``np.ndarray`` for a single string, otherwise a 2-D array
            with one row per input text.

        Raises:
            ValueError: If ``pooling`` is neither ``'mean'`` nor ``'cls'``.
        """
        # Validate up front so a typo does not cost a full forward pass.
        if pooling not in ('mean', 'cls'):
            raise ValueError("Invalid pooling method")

        is_single = isinstance(texts, str)
        batch = [texts] if is_single else texts

        inputs = self.tokenizer(
            batch,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state

        if pooling == 'mean':
            # Cast the integer attention mask to the hidden-state dtype so the
            # masked sum, the token count and the 1e-9 clamp are all computed
            # in floating point (as in the reference sentence-transformers pooling).
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
            embeddings = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            embeddings = hidden[:, 0, :]

        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        result = embeddings.cpu().numpy()
        return result[0] if is_single else result


class BERTBallTree:
    """Nearest-neighbour text search over BERT sentence embeddings.

    Texts are embedded with mean pooling and L2 normalisation, then indexed
    in a scikit-learn ``BallTree``.
    """

    def __init__(self, embedder=None, metric='euclidean', leaf_size=40):
        """
        Initialize the BERT Ball Tree.

        Args:
            embedder: Pre-initialized BERTEmbedder instance; a default
                ``BERTEmbedder()`` is created when omitted.
            metric: Distance metric forwarded to ``sklearn.neighbors.BallTree``.
                It must be a metric BallTree supports -- ``'cosine'`` is not
                one of them. Because the embeddings are L2-normalised,
                ``'euclidean'`` already ranks neighbours exactly as cosine
                similarity would (``||a - b||**2 == 2 - 2 * cos(a, b)``).
            leaf_size: BallTree leaf size; affects query speed and memory usage.
        """
        self.embedder = embedder or BERTEmbedder()
        self.metric = metric
        self.leaf_size = leaf_size
        self.tree = None   # BallTree, set by build_tree() or load_tree()
        self.texts = None  # np.ndarray of indexed texts, aligned with tree rows

    def build_tree(self, texts):
        """
        Build the Ball Tree from a list of texts.

        Args:
            texts: Non-empty list of strings to index.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        if len(texts) == 0:
            raise ValueError("Cannot build a Ball Tree from an empty list of texts.")
        self.texts = np.array(texts)  # kept so query() can map indices back to text
        embeddings = self.embedder.text_to_embedding(texts, pooling='mean', normalize=True)
        self.tree = BallTree(embeddings, metric=self.metric, leaf_size=self.leaf_size)

    def query(self, query_text, k=5, return_distances=False):
        """
        Query the Ball Tree for nearest neighbors.

        Args:
            query_text: The query text string
            k: Number of nearest neighbors to return. Values larger than the
                number of indexed texts return every indexed text.
            return_distances: Whether to return distances along with results

        Returns:
            If return_distances is False: array of nearest texts, closest first
            If return_distances is True: tuple of (texts, distances)

        Raises:
            ValueError: If the tree has not been built or ``k < 1``.
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        # BallTree.query rejects k larger than the number of indexed points;
        # on a small corpus a top-k request should simply return everything.
        k = min(k, len(self.texts))

        query_embedding = self.embedder.text_to_embedding(
            query_text, pooling='mean', normalize=True
        ).reshape(1, -1)  # BallTree.query expects a 2-D batch of points

        distances, indices = self.tree.query(query_embedding, k=k)

        # Row 0: the single query in the batch.
        results = self.texts[indices[0]]

        if return_distances:
            return results, distances[0]
        return results

    def save_tree(self, filepath):
        """
        Save the Ball Tree, indexed texts and tree settings with joblib.

        The embedder is not saved; pass an equivalent one to ``load_tree``.

        Raises:
            ValueError: If the tree has not been built yet.
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")
        import joblib
        data = {
            'texts': self.texts,
            'tree': self.tree,
            'metric': self.metric,
            'leaf_size': self.leaf_size
        }
        joblib.dump(data, filepath)

    @classmethod
    def load_tree(cls, filepath, embedder=None):
        """
        Load a Ball Tree written by ``save_tree``.

        ``embedder`` should wrap the same model that built the tree; otherwise
        query embeddings will not be comparable with the indexed ones. When it
        is None a default ``BERTEmbedder()`` is created.

        Security: joblib files are pickles and loading one can execute
        arbitrary code -- only load files from a trusted source.
        """
        import joblib
        data = joblib.load(filepath)
        instance = cls(embedder=embedder, metric=data['metric'], leaf_size=data['leaf_size'])
        instance.texts = data['texts']
        instance.tree = data['tree']
        return instance


# Example usage: index a few sentences and print the nearest neighbours of a query.
if __name__ == "__main__":
    # Sample texts
    texts = [
        "The quick brown fox jumps over the lazy dog",
        "Artificial intelligence is transforming industries",
        "Python is a popular programming language",
        "Machine learning requires large amounts of data",
        "Deep learning models use neural networks",
        "Natural language processing helps computers understand text",
        "The weather is nice today",
        "I enjoy reading books in my free time"
    ]
    TOP_K = 3  # number of neighbours to retrieve

    # Initialize and build the tree
    ball_tree = BERTBallTree()
    ball_tree.build_tree(texts)

    # Query the tree
    query = "computer understanding of human language"
    results, distances = ball_tree.query(query, k=TOP_K, return_distances=True)

    print(f"Query: {query}")
    # Report the number actually returned instead of a hard-coded count, so
    # the header stays correct when TOP_K changes.
    print(f"Top {len(results)} results:")
    for text, dist in zip(results, distances):
        print(f"- {text} (distance: {dist:.4f})")

  from .autonotebook import tqdm as notebook_tqdm


Query: computer understanding of human language
Top 3 results:
- Natural language processing helps computers understand text (distance: 0.7042)
- Machine learning requires large amounts of data (distance: 0.9758)
- Artificial intelligence is transforming industries (distance: 1.0927)


In [2]:
import math
import os
import re
import shutil
from collections import Counter
import numpy as np
import nltk
from gensim.models import Word2Vec
from annoy import AnnoyIndex
from nltk.corpus import stopwords
from tqdm import tqdm

from utils import from_current_file, load_json, round_float, save_json

# Fetch the NLTK resources used by this notebook; NLTK reports
# "already up-to-date" when a resource is cached locally.
for _nltk_resource in ("stopwords", "punkt_tab"):
    nltk.download(_nltk_resource)

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Kiaver\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Kiaver\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [10]:
class Word2VecIndexer:
    """Semantic search over a directory of ``.txt`` documents.

    A Word2Vec model is trained on the tokenized documents. Each document is
    embedded as the mean of its word vectors, and an Annoy index with the
    ``'angular'`` metric over those embeddings answers queries. The model,
    the Annoy index and the document-id -> name mapping are persisted under
    ``index_dir`` and loaded on construction.
    """

    _stop_words = set(stopwords.words("english"))
    # Number of trees Annoy builds; per the Annoy docs, more trees give more
    # accurate neighbours at the cost of build time and index size.
    _annoy_n_trees = 1000

    def __init__(
        self,
        index_dir: str = "../data/embedding_directory",
        documents_dir: str = "../data/scrapped/class_data_function__1_1",
        top_similar: int = 10,
        force: bool = False,
    ):
        """Load the persisted index, building it first if needed.

        Args:
            index_dir: Directory for the persisted artifacts, resolved with
                ``from_current_file``.
            documents_dir: Directory whose ``.txt`` files are indexed.
            top_similar: How many neighbours ``_get_similar_words`` requests.
            force: Delete ``index_dir`` and rebuild even if an index exists.
        """
        self._index_dir = from_current_file(index_dir)
        self._documents_dir = from_current_file(documents_dir)
        self.top_similar = top_similar

        self._word2vec_model_path = os.path.join(self._index_dir, "word2vec.model")
        self._annoy_index_path = os.path.join(self._index_dir, "doc_embeddings.ann")
        self._doc_id_path = os.path.join(self._index_dir, "documents.json")

        self.model: Word2Vec | None = None
        self.annoy_index: AnnoyIndex | None = None  # Annoy index over document embeddings
        # Document id -> embedding; filled by build_index(), not by load_index().
        self.doc_embeddings: dict[int, np.ndarray] = {}
        # Document id -> source file name without the ".txt" suffix.
        self.documents: dict[int, str] = {}

        artifacts = (self._word2vec_model_path, self._annoy_index_path, self._doc_id_path)
        # Rebuild when any artifact is missing, not only when the directory is
        # absent: an interrupted build leaves a directory load_index() fails on.
        if force or not all(os.path.exists(path) for path in artifacts):
            print("Index is not found, creating new...")
            if force:
                try:
                    shutil.rmtree(self._index_dir)
                except FileNotFoundError:
                    pass
            # makedirs creates missing parents; exist_ok tolerates a partial build.
            os.makedirs(self._index_dir, exist_ok=True)
            self.build_index()
            print("Complete!")

        self.load_index()

    def _tokenize(self, text: str) -> list[str]:
        """Lower-case ``text``, split it into word-character runs, drop English stop words."""
        return [w for w in re.findall(r"\w+", text.lower()) if w not in self._stop_words]

    def _word_vectors(self, words: list[str]) -> list[np.ndarray]:
        """Return the model vectors of those ``words`` present in the vocabulary."""
        return [self.model.wv[word] for word in words if word in self.model.wv]

    def _get_similar_words(self, word: str) -> set[tuple[str, float]]:
        """Return up to ``top_similar`` ``(word, similarity)`` pairs close to ``word``.

        Returns an empty set when no model is loaded or ``word`` is out of
        vocabulary. The search is gensim's exact search over the model's own
        vocabulary; this class builds no approximate word-level indexer.
        """
        if self.model is None or word not in self.model.wv:
            return set()
        return set(self.model.wv.most_similar(word, topn=self.top_similar))

    def build_index(self):
        """Train Word2Vec on the documents, embed them, and persist everything.

        Every ``.txt`` file in ``documents_dir`` becomes one document. Ids are
        assigned consecutively in sorted file-name order, so rebuilds are
        reproducible. Writes the Annoy index, the Word2Vec model and the
        id -> name mapping into ``index_dir``.
        """
        self.documents = {}
        sentences: list[list[str]] = []
        for filename in sorted(os.listdir(self._documents_dir)):
            if not filename.endswith(".txt"):
                continue
            with open(
                os.path.join(self._documents_dir, filename), "r", encoding="utf-8"
            ) as f:
                text = f.read()
            # The id must be the document's position in `sentences`, because the
            # embeddings below are keyed by that position. Numbering by position
            # in the directory listing instead misaligns ids whenever a non-.txt
            # file is present.
            self.documents[len(sentences)] = filename[:-4]
            sentences.append(self._tokenize(text))

        self.model = Word2Vec(sentences=sentences, min_count=1)
        vector_size = self.model.vector_size

        self.doc_embeddings = {}
        for doc_id, words in enumerate(sentences):
            vectors = self._word_vectors(words)
            # A document with no tokens would give np.mean([]) -- a NaN scalar,
            # not a vector -- so it gets a zero vector instead.
            self.doc_embeddings[doc_id] = (
                np.mean(vectors, axis=0) if vectors else np.zeros(vector_size, dtype=np.float32)
            )

        # Build Annoy index for documents
        self.annoy_index = AnnoyIndex(vector_size, "angular")
        for doc_id, embedding in self.doc_embeddings.items():
            self.annoy_index.add_item(doc_id, embedding)
        self.annoy_index.build(n_trees=self._annoy_n_trees)
        self.annoy_index.save(self._annoy_index_path)

        # Persist model and id -> name mapping
        self.model.save(self._word2vec_model_path)
        save_json(self._doc_id_path, self.documents)

    def load_index(self):
        """Load the id -> name mapping, the Word2Vec model and the Annoy index."""
        # JSON object keys are strings; restore the integer document ids.
        self.documents = {int(k): v for k, v in load_json(self._doc_id_path).items()}
        self.model = Word2Vec.load(self._word2vec_model_path)
        self.annoy_index = AnnoyIndex(self.model.vector_size, "angular")
        self.annoy_index.load(self._annoy_index_path)

    def find(self, query: str, top_k: int = 10) -> list:
        """Return up to ``top_k`` ``(document name, cosine similarity)`` pairs for ``query``.

        The query is embedded like the documents (mean of its in-vocabulary word
        vectors). Results are ordered by descending similarity, and scores are
        rounded to 5 decimals. Returns ``[]`` when no query word is in the
        vocabulary or ``top_k <= 0``.
        """
        query_vectors = self._word_vectors(self._tokenize(query))
        if not query_vectors or top_k <= 0:
            return []

        query_embedding = np.mean(query_vectors, axis=0)

        doc_ids, distances = self.annoy_index.get_nns_by_vector(
            query_embedding,
            top_k,
            include_distances=True,
        )

        # Annoy's angular distance is sqrt(2 * (1 - cos)), hence cos = 1 - d**2 / 2.
        scored = [
            (doc_id, 1 - (distance ** 2) / 2)
            for doc_id, distance in zip(doc_ids, distances)
        ]
        scored.sort(key=lambda item: -item[1])

        return [
            (self.documents[doc_id], round_float(score, 5))
            for doc_id, score in scored
        ]


# Build (or reload) the Word2Vec document index and list the ten closest files for a query.
indexer = Word2VecIndexer()
query_text = "url"
results = indexer.find(query_text, top_k=10)
for match_name, match_score in results:
    print(f"Score: {match_score}\tFile: {match_name}")

Score: 0.99148	File: math.tau
Score: 0.98984	File: csv.QUOTE_MINIMAL
Score: 0.9872	File: mimetypes.guess_extension
Score: 0.98673	File: io.BytesIO.getvalue
Score: 0.98517	File: urllib.parse.SplitResultBytes
Score: 0.98508	File: math.trunc
Score: 0.98434	File: ast.Interactive
Score: 0.98432	File: subprocess.STARTF_USESTDHANDLES
Score: 0.98393	File: os.sysconf_names
Score: 0.98296	File: functools.singledispatchmethod
