In [None]:
from g4f.client import Client

In [None]:
import os

import torch
from sklearn.neighbors import BallTree
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import from_current_file, load

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEVICE

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class BERTEmbedder:
    def __init__(
        self, model_name="sentence-transformers/msmarco-bert-base-dot-v5", device=DEVICE
    ):
        self.device = device or "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        self.do_lower_case = getattr(self.tokenizer, "do_lower_case", False)

    def text_to_embedding(self, texts, pooling="mean", normalize=False):
        is_single = isinstance(texts, str)
        texts = [texts] if is_single else texts

        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        if pooling == "mean":
            mask = inputs["attention_mask"].unsqueeze(-1)
            embeddings = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(
                min=1e-9
            )
        elif pooling == "cls":
            embeddings = outputs.last_hidden_state[:, 0, :]
        else:
            raise ValueError("Invalid pooling method")

        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()[0] if is_single else embeddings.cpu().numpy()


def batch(iterable, batch_size):
    all_batches = []
    current_batch = []

    for item in iterable:
        current_batch.append(item)
        if len(current_batch) == batch_size:
            all_batches.append(current_batch)
            current_batch = []

    if current_batch:  # Add the last partial batch
        all_batches.append(current_batch)

    return all_batches


class BERTBallTree:
    def __init__(
        self,
        index_dir: str = "../data/embedding_directory",
        documents_dir: str = "../data/scrapped/class_data_function__1_1",
        top_similar: int = 10,
        force: bool = False,
        embedder=None,
        metric="euclidean",
        leaf_size=1,
    ):
        """
        Initialize the BERT Ball Tree.

        Args:
            embedder: Pre-initialized BERTEmbedder instance
            metric: Distance metric for BallTree ('euclidean', 'cosine', etc.)
            leaf_size: Affects the speed of queries and memory usage
        """
        self._index_dir = from_current_file(index_dir)
        self._documents_dir = from_current_file(documents_dir)
        self.top_similar = top_similar
        self.embedder = embedder or BERTEmbedder()
        self.metric = metric
        self.leaf_size = leaf_size
        self.tree = None
        self.texts = None
        self.documents: dict[int, str] = {}

    def build_tree(self):
        """
        Build the Ball Tree from a list of texts.

        Args:
            texts: List of strings to index
        """
        sentences = []
        for document_id, filename in enumerate(os.listdir(self._documents_dir)):
            if filename.endswith(".txt"):
                with open(
                    os.path.join(self._documents_dir, filename), "r", encoding="utf-8"
                ) as f:
                    text = f.read()
                    self.documents[document_id] = filename[:-4]
                    sentences.append(text)

        embeddings = []
        for b in tqdm(batch(sentences, 64)):
            batch_embeddings = self.embedder.text_to_embedding(
                b, pooling="mean", normalize=True
            )
            embeddings.extend(batch_embeddings)
        print(len(embeddings))
        self.tree = BallTree(embeddings, metric=self.metric, leaf_size=self.leaf_size)

    def query(self, query_text, k=5, return_distances=False):
        """
        Query the Ball Tree for nearest neighbors.

        Args:
            query_text: The query text string
            k: Number of nearest neighbors to return
            return_distances: Whether to return distances along with results

        Returns:
            If return_distances is False: list of nearest texts
            If return_distances is True: tuple of (texts, distances)
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")

        # Get embedding for the query text
        query_embedding = self.embedder.text_to_embedding(
            query_text, pooling="mean", normalize=True
        ).reshape(1, -1)

        # Query the tree
        distances, indices = self.tree.query(query_embedding, k=k)

        # Get the corresponding texts
        print(indices[0])
        results = []
        for indice in indices[0]:
            results.append(self.documents[indice])

        if return_distances:
            return indices[0], results, distances
        return results

    def save_tree(self, filepath):
        """Save the Ball Tree and associated data to disk."""
        import joblib

        data = {
            "texts": self.texts,
            "tree": self.tree,
            "metric": self.metric,
            "leaf_size": self.leaf_size,
        }
        joblib.dump(data, filepath)

    @classmethod
    def load_tree(cls, filepath, embedder=None):
        """Load a saved Ball Tree from disk."""
        import joblib

        data = joblib.load(filepath)
        instance = cls(
            embedder=embedder, metric=data["metric"], leaf_size=data["leaf_size"]
        )
        instance.texts = data["texts"]
        instance.tree = data["tree"]
        return instance


# Initialize and build the tree
ball_tree = BERTBallTree()
ball_tree.build_tree()

# Query the tree
# query = "sin"
# results, distances = ball_tree.query(query, k=10, return_distances=True)

# print(f"Query: {query}")
# print("Top 3 results:")
# for text, dist in zip(results, distances):
#     print(f"- {text} (distance: {dist:.4f})")

### ADD INDEXER/

In [None]:
class RAG:
    folder_path = "../data/scrapped/class_data_function__1_1"

    def __init__(self, model: str):
        self.client = Client()

    def _retrive_docs(self, query_m: str, k: int = 5):
        index, results, _ = ball_tree.query(query_m, k)

        contents = []

        for name in results:
            content = load(os.path.join(self.folder_path, name))

            contents.append(content)

        return contents

    def get_answer(self, question: str, model: str, k: int = 5):
        self.model = model

        context = self._retrive_docs(question)

        prompt = (
            f"You're a Python expert. Answer strictly according to the documentation wich is marked as 'Context' below. "
            'If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\n'
            "\nContext:\n"
            f"{'\n\n'.join([f'{idx + 1}. {c}' for idx, c in enumerate(context)])}\n"
            "\nQuestion: \n"
            "\nResponse (with reference to the source [1-{2}]):\n"
        )

        messages = [{"role": "user", "content": prompt}]

        response = self.client.chat.completions.create(
            model=self.model, messages=messages, web_search=False
        )

        return response.choices[0].message.content

In [None]:
rag_model = RAG()

question = "Sin and Cos"

rag_model.get_answer(question=question, model="gpt4gpt-4o-mini")

In [3]:
context = ["fskdfj", "fs;ldgkjlk"]

In [15]:
print(prompt)

You're a Python expert. Answer strictly according to the documentation wich is marked as 'Context' below. If there is no answer in the context, say, "I can't find the answer in the Python documentation"

Context:
1. fskdfj

2. fs;ldgkjlk

Question: 

Response (with reference to the source [1-{2}]):

