In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
from sklearn.neighbors import BallTree
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import from_current_file

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device('cpu')  # uncomment to force CPU execution
DEVICE  # bare expression so the notebook displays the selected device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [3]:
class BERTEmbedder:
    """Turn text into embedding vectors with a Hugging Face tokenizer/model pair."""

    def __init__(
        self, model_name="sentence-transformers/msmarco-bert-base-dot-v5", device=DEVICE
    ):
        """
        Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            device: Device the model and inputs are moved to; a falsy value
                selects "cpu".
        """
        self.device = device if device else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        loaded_model = AutoModel.from_pretrained(model_name)
        self.model = loaded_model.to(self.device)
        self.model.eval()

        # Mirrors the tokenizer's lower-casing flag; False when the tokenizer has none.
        self.do_lower_case = getattr(self.tokenizer, "do_lower_case", False)

    def text_to_embedding(self, texts, pooling="mean", normalize=False):
        """
        Embed one string or a sequence of strings.

        Args:
            texts: A single string or a list of strings.
            pooling: "mean" averages the last hidden states weighted by the
                attention mask; "cls" takes the hidden state at position 0.
            normalize: When true, L2-normalize each embedding row.

        Returns:
            A NumPy array with one row per input text, or just the first row
            when ``texts`` was a single string.

        Raises:
            ValueError: If ``pooling`` is neither "mean" nor "cls".
        """
        single_input = isinstance(texts, str)
        if single_input:
            texts = [texts]

        encoded = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        encoded = encoded.to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**encoded)

        if pooling == "cls":
            pooled = outputs.last_hidden_state[:, 0, :]
        elif pooling == "mean":
            # Zero out padded positions before averaging over the sequence axis.
            weights = encoded["attention_mask"].unsqueeze(-1)
            summed = (outputs.last_hidden_state * weights).sum(1)
            token_counts = weights.sum(1).clamp(min=1e-9)
            pooled = summed / token_counts
        else:
            raise ValueError("Invalid pooling method")

        if normalize:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)

        as_numpy = pooled.cpu().numpy()
        if single_input:
            return as_numpy[0]
        return as_numpy


def batch(iterable, batch_size):
    """
    Split ``iterable`` into consecutive lists of ``batch_size`` items.

    The final list may be shorter than ``batch_size``; an empty iterable
    produces an empty result. All batches are materialized and returned
    as a list of lists.
    """
    chunks = [[]]

    for element in iterable:
        open_chunk = chunks[-1]
        open_chunk.append(element)
        # Start a fresh chunk as soon as the open one reaches the target size.
        if len(open_chunk) == batch_size:
            chunks.append([])

    # Drop the trailing placeholder when the last chunk ended up empty.
    if not chunks[-1]:
        chunks.pop()

    return chunks


class BERTBallTree:
    """
    Nearest-neighbour index over embedded text files, backed by a BallTree.

    Every ``.txt`` file in the documents directory is embedded with the
    configured embedder. Queries return document names, i.e. the file name
    without its ``.txt`` suffix.
    """

    def __init__(
        self,
        index_dir: str = "../data/embedding_directory",
        documents_dir: str = "../data/scrapped/class_data_function__1_1",
        top_similar: int = 10,
        force: bool = False,
        embedder=None,
        metric="euclidean",
        leaf_size=1,
    ):
        """
        Initialize the BERT Ball Tree.

        Args:
            index_dir: Index directory, resolved through ``from_current_file``
                (presumably relative to this file -- TODO confirm). It is
                stored but not read by any method of this class.
            documents_dir: Directory holding the ``.txt`` documents to index,
                resolved the same way.
            top_similar: Stored on the instance only; ``query`` uses its own
                ``k`` argument and does not read this value.
            force: Accepted but currently unused.
            embedder: Pre-initialized BERTEmbedder instance; a default
                ``BERTEmbedder()`` is created when omitted.
            metric: Distance metric name handed to BallTree (e.g. 'euclidean',
                'manhattan'). 'cosine' is not a BallTree metric (check
                ``BallTree.valid_metrics``). Because embeddings are
                L2-normalized here, euclidean distance ranks neighbours the
                same way cosine similarity would.
            leaf_size: Passed to BallTree; affects the speed of queries and
                memory usage.
        """
        self._index_dir = from_current_file(index_dir)
        self._documents_dir = from_current_file(documents_dir)
        self.top_similar = top_similar
        self.embedder = embedder or BERTEmbedder()
        self.metric = metric
        self.leaf_size = leaf_size
        self.tree = None
        # Never populated by this class; kept so older saved files round-trip.
        self.texts = None
        # Maps embedding row index -> document name (file name without ".txt").
        self.documents: dict[int, str] = {}

    def _read_documents(self):
        """Read all ``.txt`` documents, recording each name under its row index."""
        self.documents = {}  # reset so a rebuild never keeps stale entries
        texts = []
        # os.listdir order is arbitrary; sorting makes the ids reproducible.
        for filename in sorted(os.listdir(self._documents_dir)):
            if not filename.endswith(".txt"):
                continue
            path = os.path.join(self._documents_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                text = f.read()
            # The id must be the position of this text in the embedding
            # matrix. Counting skipped non-.txt entries as well would shift
            # every later name away from its embedding row.
            self.documents[len(texts)] = filename[:-4]
            texts.append(text)
        return texts

    def build_tree(self):
        """
        Build the Ball Tree from the ``.txt`` files in the documents directory.

        Texts are embedded in batches of 64 with mean pooling and L2
        normalization. The number of embeddings is printed once done.

        Raises:
            ValueError: If the directory contains no ``.txt`` documents.
        """
        sentences = self._read_documents()
        if not sentences:
            raise ValueError(f"No .txt documents found in {self._documents_dir!r}")

        embeddings = []
        for b in tqdm(batch(sentences, 64)):
            batch_embeddings = self.embedder.text_to_embedding(
                b, pooling="mean", normalize=True
            )
            embeddings.extend(batch_embeddings)
        print(len(embeddings))
        self.tree = BallTree(embeddings, metric=self.metric, leaf_size=self.leaf_size)

    def query(self, query_text, k=5, return_distances=False):
        """
        Query the Ball Tree for nearest neighbors.

        Args:
            query_text: The query text string
            k: Number of nearest neighbors to return
            return_distances: Whether to return distances along with results

        Returns:
            If return_distances is False: list of nearest document names
            If return_distances is True: tuple of (document names, distances)

        Raises:
            ValueError: If the tree has not been built or loaded yet.
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")

        # Embed the query exactly like the indexed documents.
        query_embedding = self.embedder.text_to_embedding(
            query_text, pooling="mean", normalize=True
        ).reshape(1, -1)

        # Query the tree
        distances, indices = self.tree.query(query_embedding, k=k)

        # Resolve embedding rows back to document names.
        print(indices[0])
        results = [self.documents[int(row)] for row in indices[0]]

        if return_distances:
            return results, distances[0]
        return results

    def save_tree(self, filepath):
        """Save the Ball Tree, the row-to-name mapping and settings to disk."""
        import joblib

        data = {
            "texts": self.texts,
            "tree": self.tree,
            # Without this mapping a loaded tree cannot name its results.
            "documents": self.documents,
            "metric": self.metric,
            "leaf_size": self.leaf_size,
        }
        joblib.dump(data, filepath)

    @classmethod
    def load_tree(cls, filepath, embedder=None):
        """
        Load a saved Ball Tree from disk.

        NOTE(security): joblib files are pickle-based; only load files from a
        trusted source.

        Files written before the document mapping was persisted load with an
        empty mapping and must be rebuilt before they can be queried.
        """
        import joblib

        data = joblib.load(filepath)
        instance = cls(
            embedder=embedder, metric=data["metric"], leaf_size=data["leaf_size"]
        )
        instance.texts = data["texts"]
        instance.tree = data["tree"]
        instance.documents = data.get("documents") or {}
        return instance


# Initialize and build the tree
ball_tree = BERTBallTree()
ball_tree.build_tree()

# Query the tree
query = "sin"
results, distances = ball_tree.query(query, k=10, return_distances=True)

print(f"Query: {query}")
# Report the real number of hits; the header used to claim 3 while 10 were listed.
print(f"Top {len(results)} results:")
for text, dist in zip(results, distances):
    print(f"- {text} (distance: {dist:.4f})")

100%|██████████| 90/90 [00:35<00:00,  2.55it/s]


5711
[2704  388 2705 2702 2655 4090  389 4091 2666 2664]
Query: sin
Top 3 results:
- math.pow (distance: 0.4052)
- cmath.sin (distance: 0.4420)
- math.prod (distance: 0.4420)
- math.perm (distance: 0.4646)
- marshal.loads (distance: 0.4661)
- stat.ST_ATIME (distance: 0.4671)
- cmath.sinh (distance: 0.4679)
- stat.ST_CTIME (distance: 0.4727)
- math.comb (distance: 0.4744)
- math.cbrt (distance: 0.4775)


In [4]:
# Look up the nearest documents for a module-level query term.
query = "math"
results, distances = ball_tree.query(query, k=10, return_distances=True)

print(f"Query: {query}")
# Report the real number of hits; the header used to claim 3 while 10 were listed.
print(f"Top {len(results)} results:")
for text, dist in zip(results, distances):
    print(f"- {text} (distance: {dist:.4f})")

[2702 2666 2668 2704 2699 2691 2679 2706 2664 2693]
Query: math
Top 3 results:
- math.perm (distance: 0.4053)
- math.comb (distance: 0.4101)
- math.cos (distance: 0.4238)
- math.pow (distance: 0.4260)
- math.modf (distance: 0.4261)
- math.isqrt (distance: 0.4286)
- math.floor (distance: 0.4314)
- math.radians (distance: 0.4337)
- math.cbrt (distance: 0.4351)
- math.ldexp (distance: 0.4401)


In [8]:
# Look up the nearest documents for a function-name query term.
query = "sqrt"
results, distances = ball_tree.query(query, k=10, return_distances=True)

print(f"Query: {query}")
# Report the real number of hits; the header used to claim 3 while 10 were listed.
print(f"Top {len(results)} results:")
for text, dist in zip(results, distances):
    print(f"- {text} (distance: {dist:.4f})")

[2706  390 1133 1187 2687 2660 2681 5187 3052 5220]
Query: sqrt
Top 3 results:
- math.radians (distance: 0.4383)
- cmath.sqrt (distance: 0.4817)
- decimal.Context.sqrt (distance: 0.4834)
- decimal.Decimal.sqrt (distance: 0.4943)
- math.isclose (distance: 0.5119)
- math.asinh (distance: 0.5162)
- math.frexp (distance: 0.5241)
- unittest.IsolatedAsyncioTestCase.enterAsyncContext (distance: 0.5290)
- os.CLONE_NEWIPC (distance: 0.5312)
- unittest.mock.mock_open (distance: 0.5315)


In [6]:
# Look up the nearest documents for a module-name query term.
query = "turtle"
results, distances = ball_tree.query(query, k=10, return_distances=True)

print(f"Query: {query}")
# Report the real number of hits; the header used to claim 3 while 10 were listed.
print(f"Top {len(results)} results:")
for text, dist in zip(results, distances):
    print(f"- {text} (distance: {dist:.4f})")

[5059 5042 5049 5030 4988 5026 4993 5068 4986 5069]
Query: turtle
Top 3 results:
- turtle.title (distance: 0.3896)
- turtle.setup (distance: 0.4083)
- turtle.shapetransform (distance: 0.4142)
- turtle.resetscreen (distance: 0.4162)
- turtle.forward (distance: 0.4169)
- turtle.radians (distance: 0.4235)
- turtle.getturtle (distance: 0.4242)
- turtle.up (distance: 0.4262)
- turtle.fillcolor (distance: 0.4293)
- turtle.update (distance: 0.4297)


In [9]:
# Look up the nearest documents for a constant-name query term.
query = "pi"
results, distances = ball_tree.query(query, k=10, return_distances=True)

print(f"Query: {query}")
# Report the real number of hits; the header used to claim 3 while 10 were listed.
print(f"Top {len(results)} results:")
for text, dist in zip(results, distances):
    print(f"- {text} (distance: {dist:.4f})")

[2699  385 3153 3320 3321 2702  739 3362 3351 2666]
Query: pi
Top 3 results:
- math.modf (distance: 0.4330)
- cmath.pi (distance: 0.4586)
- os.getpriority (distance: 0.4877)
- os.P_NOWAIT (distance: 0.4890)
- os.P_NOWAITO (distance: 0.4949)
- math.perm (distance: 0.4960)
- curses.ACS_PI (distance: 0.4992)
- os.SCHED_RESET_ON_FORK (distance: 0.5058)
- os.scandir (distance: 0.5064)
- math.comb (distance: 0.5071)
