In [1]:
import os

import torch
from sklearn.neighbors import BallTree
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import from_current_file, load

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEVICE

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
class BERTEmbedder:
    def __init__(
        self, model_name="sentence-transformers/msmarco-bert-base-dot-v5", device=DEVICE
    ):
        self.device = device or "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        self.do_lower_case = getattr(self.tokenizer, "do_lower_case", False)

    def text_to_embedding(self, texts, pooling="mean", normalize=False):
        is_single = isinstance(texts, str)
        texts = [texts] if is_single else texts

        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        if pooling == "mean":
            mask = inputs["attention_mask"].unsqueeze(-1)
            embeddings = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(
                min=1e-9
            )
        elif pooling == "cls":
            embeddings = outputs.last_hidden_state[:, 0, :]
        else:
            raise ValueError("Invalid pooling method")

        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()[0] if is_single else embeddings.cpu().numpy()


def batch(iterable, batch_size):
    all_batches = []
    current_batch = []

    for item in iterable:
        current_batch.append(item)
        if len(current_batch) == batch_size:
            all_batches.append(current_batch)
            current_batch = []

    if current_batch:  # Add the last partial batch
        all_batches.append(current_batch)

    return all_batches


class BERTBallTree:
    def __init__(
        self,
        index_dir: str = "../data/embedding_directory",
        documents_dir: str = "../data/scrapped/class_data_function__1_1",
        embedder=None,
        metric="euclidean",
        leaf_size=1,
        tree_name="tree",
        force=False,
    ):
        self._index_dir = from_current_file(index_dir)
        self._documents_dir = from_current_file(documents_dir)
        self.embedder = embedder or BERTEmbedder()
        self.metric = metric
        self.leaf_size = leaf_size
        self.tree = None
        self.tree_name = tree_name
        self.documents: dict[int, str] = {}

        if force or not os.path.exists(self._index_dir):
            print("Tree is not found, creating new...")
            if force:
                try:
                    shutil.rmtree(self._index_dir)
                except FileNotFoundError:
                    pass
            os.mkdir(path=self._index_dir)
            self.build_tree()
            print("Complete!")

        self.load_tree()

    def build_tree(self):
        sentences = []
        for document_id, filename in enumerate(os.listdir(self._documents_dir)):
            if filename.endswith(".txt"):
                with open(
                    os.path.join(self._documents_dir, filename), "r", encoding="utf-8"
                ) as f:
                    text = f.read()
                    self.documents[document_id] = filename[:-4]
                    sentences.append(text)

        embeddings = []
        for b in tqdm(batch(sentences, 64)):
            batch_embeddings = self.embedder.text_to_embedding(
                b, pooling="mean", normalize=True
            )
            embeddings.extend(batch_embeddings)
        self.tree = BallTree(embeddings, metric=self.metric, leaf_size=self.leaf_size)
        self.save_tree()

    def query(self, query_text, k=5, return_distances=False):
        """
        Query the Ball Tree for nearest neighbors.

        Args:
            query_text: The query text string
            k: Number of nearest neighbors to return
            return_distances: Whether to return distances along with results

        Returns:
            If return_distances is False: list of nearest texts
            If return_distances is True: tuple of (texts, distances)
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")

        # Get embedding for the query text
        query_embedding = self.embedder.text_to_embedding(
            query_text, pooling="mean", normalize=True
        ).reshape(1, -1)

        # Query the tree
        distances, indices = self.tree.query(query_embedding, k=k)

        # Get the corresponding texts
        results = []
        for indice in indices[0]:
            results.append(self.documents[indice])

        if return_distances:
            return results, distances[0]
        return results

    def save_tree(self):
        """Save the Ball Tree and associated data to disk."""
        import joblib

        data = {
            "tree": self.tree,
            "documents": self.documents,
        }
        joblib.dump(data, os.path.join(self._index_dir, self.tree_name))

    def load_tree(self):
        """Load a saved Ball Tree from disk."""
        import joblib

        data = joblib.load(os.path.join(self._index_dir, self.tree_name))
        self.tree = data["tree"]
        self.documents = data["documents"]


# Initialize and build the tree
ball_tree = BERTBallTree()

### ADD INDEXER/

In [3]:
import json

import torch
from g4f.client import Client
from transformers import AutoModelForCausalLM


class RAG:
    folder_path = "../data/scrapped/class_data_function__1_1"

    def __init__(self):
        self.client = Client()
        model_id = "microsoft/Phi-3-mini-4k-instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="cuda", torch_dtype="auto", trust_remote_code=True
        )

    def generate_stream(self, question: str, model: str, k: int = 10):
        res, context = self._retrive_docs(question, k)
        prompt = (
            "You're a Python expert. Suppose all the documentation information you know is provided in context section. "
            "Answer on the question as usual but take technical information only from context. "
            'If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\n'
            "Highlight cited passages or provide “show sources” toggles ONLY FROM CONTEXT\n"
            "NO PYTHON CODE EXAMPLES!\n"
            "NO PROMPT REPETITION IN ANSWER!\n"
            "NO EXAMPLES!\n"
            "NO ADDITIONAL EXPLANATIONS!\n"
            'If question is unrealated to python documentation, just answer "Question is unrelated"\n'
            "\nContext:\n"
            f"{'\n\n'.join([f'{idx + 1}. {c}' for idx, c in enumerate(context)])}\n"
            f"\nQuestion: {question}\n"
            f"\nResponse (with reference to the source [1-{k}]):\n"
        )
        messages = [{"role": "user", "content": prompt}]

        yield (
            json.dumps(
                {
                    "type": "proposals",
                    "data": [{"document": x[0], "score": x[1]} for x in res],
                }
            )
            + "\n\n"
        )

        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                stream=True,
                verbose=False,
                max_tokens=200,
            )

            for chunk in response:
                if chunk.choices[0].delta.content:
                    yield (
                        json.dumps(
                            {"type": "chunk", "data": chunk.choices[0].delta.content}
                        )
                        + "\n\n"
                    )
        except BaseException as e:
            yield json.dumps({"type": "error", "data": str(e)}) + "\n\n"

    def _retrive_docs(self, query_m: str, k):
        results, distances = ball_tree.query(query_m, k, return_distances=True)
        print(results, distances)

        contents = []

        for name in results:
            content = load(os.path.join(self.folder_path, name + ".txt"))

            contents.append(name + "\n" + content)

        return [results, distances], contents

    def get_answer(self, question: str, model: str, local: bool = True, k: int = 10):
        _, context = self._retrive_docs(question, k)

        prompt = (
            "You're a Python expert. Suppose all the documentation information you know is provided in context section. "
            "Answer on the question as usual but take technical information only from context. "
            'If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\n'
            "Highlight cited passages or provide “show sources” toggles ONLY FROM CONTEXT\n"
            "NO PYTHON CODE EXAMPLES!\n"
            "NO PROMPT REPETITION IN ANSWER!\n"
            "NO EXAMPLES!\n"
            "NO ADDITIONAL EXPLANATIONS!\n"
            'If question is unrealated to python documentation, just answer "Question is unrelated"\n'
            "\nContext:\n"
            f"{'\n\n'.join([f'{idx + 1}. {c}' for idx, c in enumerate(context)])}\n"
            f"\nQuestion: {question}\n"
            f"\nResponse (with reference to the source [1-{k}]):\n"
        )
        messages = [{"role": "user", "content": prompt}]

        try:
            if not local:
                response = self.client.chat.completions.create(
                    model=model, messages=messages, web_search=False, stream=False
                )
                return response.choices[0].message.content, ""
            else:
                inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda")
                outputs = self.model.generate(
                    **inputs, max_new_tokens=200, use_cache=False
                )
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True), ""
        except BaseException as e:
            return "", str(e)

In [4]:
rag_model = RAG()

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.81s/it]


In [5]:
rag_model.get_answer("How to find goth gf", model="evil", k=5, local=False)

['functools.singledispatchmethod', 'weakref.proxy', 'gc.enable', 'gc.collect', 'weakref.getweakrefcount'] [0.48139615 0.50007839 0.50009877 0.50047324 0.50269447]


('Question is unrelated', '')

In [10]:
print(rag_model.get_answer("sin", model="evil", k=5, local=False)[0])

['math.pow', 'cmath.sin', 'math.prod', 'math.perm', 'marshal.loads'] [0.40519656 0.44195419 0.44203956 0.46457629 0.46607111]
The sine of z is returned by the cmath.sin function.

Show sources: [2]


In [9]:
print(rag_model.get_answer("sin", model="evil", k=5, local=True)[0])

['math.pow', 'cmath.sin', 'math.prod', 'math.perm', 'marshal.loads'] [0.40519656 0.44195419 0.44203956 0.46457629 0.46607111]
You're a Python expert. Suppose all the documentation information you know is provided in context section. Answer on the question as usual but take technical information only from context. If there is no answer in the context, say, "I can't find the answer in the Python documentation"
Highlight cited passages or provide “show sources” toggles ONLY FROM CONTEXT
NO PYTHON CODE EXAMPLES!
NO PROMPT REPETITION IN ANSWER!
NO EXAMPLES!
NO ADDITIONAL EXPLANATIONS!
If question is unrealated to python documentation, just answer "Question is unrelated"

Context:
1. math.pow
FUNCTION

math.pow FROM math

PARAMETERS
x, y

DESCRIPTION
Return x raised to the power y.  Exceptional cases follow
the IEEE 754 standard as far as possible.  In particular,
pow(1.0, x) and pow(x, 0.0) always return 1.0, even
when x is a zero or a NaN.  If both x and y are finite,
x is negative, and y 

- command-r
- evil
- qwen-2-72b