In [22]:
import os

import torch
from sklearn.neighbors import BallTree
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import from_current_file, load

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEVICE

device(type='cuda')

In [23]:
class BERTEmbedder:
    def __init__(
        self, model_name="sentence-transformers/msmarco-bert-base-dot-v5", device=DEVICE
    ):
        self.device = device or "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        self.do_lower_case = getattr(self.tokenizer, "do_lower_case", False)

    def text_to_embedding(self, texts, pooling="mean", normalize=False):
        is_single = isinstance(texts, str)
        texts = [texts] if is_single else texts

        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        if pooling == "mean":
            mask = inputs["attention_mask"].unsqueeze(-1)
            embeddings = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(
                min=1e-9
            )
        elif pooling == "cls":
            embeddings = outputs.last_hidden_state[:, 0, :]
        else:
            raise ValueError("Invalid pooling method")

        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()[0] if is_single else embeddings.cpu().numpy()


def batch(iterable, batch_size):
    all_batches = []
    current_batch = []

    for item in iterable:
        current_batch.append(item)
        if len(current_batch) == batch_size:
            all_batches.append(current_batch)
            current_batch = []

    if current_batch:  # Add the last partial batch
        all_batches.append(current_batch)

    return all_batches


class BERTBallTree:
    def __init__(
        self,
        index_dir: str = "../data/embedding_directory",
        documents_dir: str = "../data/scrapped/class_data_function__1_1",
        embedder=None,
        metric="euclidean",
        leaf_size=1,
        tree_name="tree",
        force=False,
    ):
        self._index_dir = from_current_file(index_dir)
        self._documents_dir = from_current_file(documents_dir)
        self.embedder = embedder or BERTEmbedder()
        self.metric = metric
        self.leaf_size = leaf_size
        self.tree = None
        self.tree_name = tree_name
        self.documents: dict[int, str] = {}

        if force or not os.path.exists(self._index_dir):
            print("Tree is not found, creating new...")
            if force:
                try:
                    shutil.rmtree(self._index_dir)
                except FileNotFoundError:
                    pass
            os.mkdir(path=self._index_dir)
            self.build_tree()
            print("Complete!")

        self.load_tree()

    def build_tree(self):
        sentences = []
        for document_id, filename in enumerate(os.listdir(self._documents_dir)):
            if filename.endswith(".txt"):
                with open(
                    os.path.join(self._documents_dir, filename), "r", encoding="utf-8"
                ) as f:
                    text = f.read()
                    self.documents[document_id] = filename[:-4]
                    sentences.append(text)

        embeddings = []
        for b in tqdm(batch(sentences, 64)):
            batch_embeddings = self.embedder.text_to_embedding(
                b, pooling="mean", normalize=True
            )
            embeddings.extend(batch_embeddings)
        self.tree = BallTree(embeddings, metric=self.metric, leaf_size=self.leaf_size)
        self.save_tree()

    def query(self, query_text, k=5, return_distances=False):
        """
        Query the Ball Tree for nearest neighbors.

        Args:
            query_text: The query text string
            k: Number of nearest neighbors to return
            return_distances: Whether to return distances along with results

        Returns:
            If return_distances is False: list of nearest texts
            If return_distances is True: tuple of (texts, distances)
        """
        if self.tree is None:
            raise ValueError("Ball Tree has not been built yet. Call build_tree() first.")

        # Get embedding for the query text
        query_embedding = self.embedder.text_to_embedding(
            query_text, pooling="mean", normalize=True
        ).reshape(1, -1)

        # Query the tree
        distances, indices = self.tree.query(query_embedding, k=k)

        # Get the corresponding texts
        results = []
        for indice in indices[0]:
            results.append(self.documents[indice])

        if return_distances:
            return results, distances[0]
        return results

    def save_tree(self):
        """Save the Ball Tree and associated data to disk."""
        import joblib

        data = {
            "tree": self.tree,
            "documents": self.documents,
        }
        joblib.dump(data, os.path.join(self._index_dir, self.tree_name))

    def load_tree(self):
        """Load a saved Ball Tree from disk."""
        import joblib

        data = joblib.load(os.path.join(self._index_dir, self.tree_name))
        self.tree = data["tree"]
        self.documents = data["documents"]


# Initialize and build the tree
ball_tree = BERTBallTree()

# Query the tree
# query = "sin"
# results, distances = ball_tree.query(query, k=10, return_distances=True)

# print(f"Query: {query}")
# print("Top 3 results:")
# for text, dist in zip(results, distances):
# print(f"- {text} (distance: {dist:.4f})")

### ADD INDEXER/

In [47]:
import json

import torch
from g4f.client import Client


class RAG:
    folder_path = "../data/scrapped/class_data_function__1_1"

    def __init__(self):
        self.client = Client()
        # model_id = "microsoft/Phi-3-mini-4k-instruct"
        # self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # self.model = AutoModelForCausalLM.from_pretrained(
        #     model_id,
        #     device_map="cuda",
        #     torch_dtype="auto",
        #     trust_remote_code=True
        # )

    def generate_stream(self, question: str, model: str, k: int = 10):
        res, context = self._retrive_docs(question, k)
        prompt = (
            "You're a Python expert. Suppose all the documentation information you know is provided in context section. "
            "Answer on the question as usual but take technical information only from context. "
            'If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\n'
            "Highlight cited passages or provide “show sources” toggles ONLY FROM CONTEXT\n"
            "NO PYTHON CODE EXAMPLES!\n"
            "NO PROMPT REPETITION IN ANSWER!\n"
            "NO EXAMPLES!\n"
            "NO ADDITIONAL EXPLANATIONS!\n"
            'If question is unrealated to python documentation, just answer "Question is unrelated"\n'
            "\nContext:\n"
            f"{'\n\n'.join([f'{idx + 1}. {c}' for idx, c in enumerate(context)])}\n"
            f"\nQuestion: {question}\n"
            f"\nResponse (with reference to the source [1-{k}]):\n"
        )
        messages = [{"role": "user", "content": prompt}]

        yield (
            json.dumps(
                {
                    "type": "proposals",
                    "data": [{"document": x[0], "score": x[1]} for x in res],
                }
            )
            + "\n\n"
        )

        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                stream=True,
                verbose=False,
                max_tokens=200,
            )

            for chunk in response:
                if chunk.choices[0].delta.content:
                    yield (
                        json.dumps(
                            {"type": "chunk", "data": chunk.choices[0].delta.content}
                        )
                        + "\n\n"
                    )
        except BaseException as e:
            yield json.dumps({"type": "error", "data": str(e)}) + "\n\n"

    def _retrive_docs(self, query_m: str, k):
        results, distances = ball_tree.query(query_m, k, return_distances=True)
        print(results, distances)

        contents = []

        for name in results:
            content = load(os.path.join(self.folder_path, name + ".txt"))

            contents.append(name + "\n" + content)

        return [results, distances], contents

    def get_answer(self, question: str, model: str, local: bool = True, k: int = 10):
        _, context = self._retrive_docs(question, k)

        prompt = (
            "You're a Python expert. Suppose all the documentation information you know is provided in context section. "
            "Answer on the question as usual but take technical information only from context. "
            'If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\n'
            "Highlight cited passages or provide “show sources” toggles ONLY FROM CONTEXT\n"
            "NO PYTHON CODE EXAMPLES!\n"
            "NO PROMPT REPETITION IN ANSWER!\n"
            "NO EXAMPLES!\n"
            "NO ADDITIONAL EXPLANATIONS!\n"
            'If question is unrealated to python documentation, just answer "Question is unrelated"\n'
            "\nContext:\n"
            f"{'\n\n'.join([f'{idx + 1}. {c}' for idx, c in enumerate(context)])}\n"
            f"\nQuestion: {question}\n"
            f"\nResponse (with reference to the source [1-{k}]):\n"
        )
        # print(prompt)
        messages = [{"role": "user", "content": prompt}]

        try:
            if not local:
                response = self.client.chat.completions.create(
                    model=model, messages=messages, web_search=False, stream=False
                )
                return response.choices[0].message.content, ""
            else:
                inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda")
                outputs = self.model.generate(
                    **inputs, max_new_tokens=200, use_cache=False
                )
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True), ""
        except BaseException as e:
            return "", str(e)

In [48]:
rag_model = RAG()
question = "How to find goth"

# res, err = rag_model.get_answer(
#     question=question, model="evil", client=client, local=False, k=3
# )
# print(res)
# print(err)

In [56]:
for res in rag_model.generate_stream("Write bublesort algorithm", model="command-r", k=5):
    print(res)

['wave.Wave_write.setcomptype', 'wave.Wave_write.setsampwidth', 'wave.Wave_write.setframerate', 'wave.Wave_read', 'wave.Wave_read.setpos'] [0.47908269 0.49011799 0.49325316 0.49424956 0.49997578]
{"type": "proposals", "data": [{"document": "wave.Wave_write.setcomptype", "score": "wave.Wave_write.setsampwidth"}, {"document": 0.47908268763074147, "score": 0.4901179886101874}]}


{"type": "error", "data": ""}




In [53]:
rag_model.get_answer("turtle", model="qwen-2-72b", k=5, local=False)

['turtle.title', 'turtle.setup', 'turtle.shapetransform', 'turtle.resetscreen', 'turtle.forward'] [0.38961707 0.40827205 0.41424621 0.41622499 0.41691331]


("I can't find the answer in the Python documentation. However, the context provided includes several functions related to the turtle module: `turtle.title` [1], `turtle.setup` [2], `turtle.shapetransform` [3], `turtle.resetscreen` [4], and `turtle.forward` [5]. These functions allow you to manipulate the turtle graphics window, such as setting its title, size, and position, transforming the turtle shape, resetting the screen, and moving the turtle forward.",
 '')

- command-r
- evil
- qwen-2-72b

In [7]:
res

'You\'re a Python expert. Suppose all the documentation information you know is provided in context section. Answer on the question as usual but take technical information only from context. If there is no answer in the context, say, "I can\'t find the answer in the Python documentation"\nHighlight cited passages ONLY FROM CONTEXT\nNO PYTHON CODE EXAMPLES!\nNO PROMPT REPETITION IN ANSWER!\n\nContext:\n1. math.pow\nFUNCTION\n\nmath.pow FROM math\n\nPARAMETERS\nx, y\n\nDESCRIPTION\nReturn x raised to the power y.  Exceptional cases follow\nthe IEEE 754 standard as far as possible.  In particular,\npow(1.0, x) and pow(x, 0.0) always return 1.0, even\nwhen x is a zero or a NaN.  If both x and y are finite,\nx is negative, and y is not an integer then pow(x, y)\nis undefined, and raises ValueError.\nUnlike the built-in ** operator, math.pow() converts both\nits arguments to type float.  Use ** or the built-in\npow() function for computing exact integer powers.\nChanged in version 3.11: The 

In [8]:
# from transformers import AutoModelForCausalLM, AutoTokenizer
# import torch

# model_id = "microsoft/Phi-3-mini-4k-instruct"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     device_map="cuda",
#     torch_dtype="auto",
#     trust_remote_code=True
# )

# # Disable cache (if not needed)
# inputs = tokenizer("Explain RAG.", return_tensors="pt").to("cuda")
# outputs = model.generate(**inputs, max_new_tokens=200, use_cache=False)  # ← Critical!
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))