In [None]:
from g4f.client import Client

In [2]:
import torch
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class BERTEmbedder:
    """Turn text into fixed-size vectors with a Hugging Face encoder model.

    The tokenizer and model are loaded once in ``__init__``. The model is put in
    eval mode, so repeated calls to :meth:`text_to_embedding` reuse them.
    """

    def __init__(self, model_name="sentence-transformers/stsb-bert-base", device=None):
        """Load the tokenizer and encoder for ``model_name``.

        Args:
            model_name: Hugging Face hub id or local path, passed to both
                ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
            device: torch device string; falls back to ``"cpu"`` when falsy.
        """
        self.device = device or "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        # Inference only: eval() disables dropout so embeddings are deterministic.
        self.model.eval()

        # Copied from the tokenizer (False when the tokenizer lacks the attribute);
        # not used inside this class.
        self.do_lower_case = getattr(self.tokenizer, "do_lower_case", False)

    def text_to_embedding(self, texts, pooling="mean", normalize=False, max_length=128):
        """Embed one string or a list of strings.

        Args:
            texts: a single ``str`` or a list of strings.
            pooling: ``"mean"`` averages token vectors over non-padding tokens;
                ``"cls"`` takes the first token's vector.
            normalize: if True, L2-normalize each embedding.
            max_length: tokens kept per text; longer inputs are truncated.

        Returns:
            A numpy array. For a single string it is 1-D (one embedding). For a
            list it is 2-D, with one row per input text.

        Raises:
            ValueError: if ``pooling`` is not ``"mean"`` or ``"cls"``.
        """
        is_single = isinstance(texts, str)
        batch = [texts] if is_single else texts

        inputs = self.tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state

        if pooling == "mean":
            # The attention mask may be an integer tensor. Cast it to the hidden
            # dtype so the per-text token count is floating point before clamping.
            # Padding positions get weight 0.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            # Guard against division by zero for a fully padded row.
            counts = mask.sum(dim=1).clamp(min=1e-9)
            embeddings = summed / counts
        elif pooling == "cls":
            embeddings = hidden[:, 0, :]
        else:
            raise ValueError("Invalid pooling method")

        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        result = embeddings.cpu().numpy()
        return result[0] if is_single else result

### TODO: Add an indexer (build the vector index and document list assigned to `RAG.indexer` and `RAG.docs`)

In [None]:
class RAG:
    """Retrieval-augmented question answering over a pre-built document index.

    Before calling :meth:`get_answer`, assign two class attributes:

    * ``indexer``: an object whose ``search(vectors, k)`` returns
      ``(distances, indices)``.
    * ``docs``: a sequence indexable by the returned ids.
    """

    # Must be populated externally before querying; see the class docstring.
    indexer = None
    # Created when the class body executes, so the encoder is loaded once and
    # shared by every RAG instance.
    embeder = BERTEmbedder()
    # Documents addressed by the ids the indexer returns; must be set externally.
    docs = None

    def __init__(self, model: str):
        """Create a chat client that will query ``model``."""
        self.client = Client()
        self.model = model

    def _retrive_docs(self, query: str, k: int = 10):
        """Return up to ``k`` documents most similar to ``query``.

        Raises:
            RuntimeError: if ``indexer`` or ``docs`` has not been set.
        """
        if self.indexer is None or self.docs is None:
            raise RuntimeError(
                "RAG.indexer and RAG.docs must be set before retrieving documents"
            )

        query_vector = self.embeder.text_to_embedding(query)
        # A single-string embedding is 1-D. Index search takes a 2-D
        # (n_queries, dim) batch; float32 is assumed to match the index (FAISS-style).
        query_batch = query_vector.reshape(1, -1).astype("float32")

        _distances, indices = self.indexer.search(query_batch, k)

        # FAISS-style indexes pad missing results with -1. Skip them instead of
        # letting docs[-1] silently return the last document.
        return [self.docs[i] for i in indices[0] if i >= 0]

    def get_answer(self, question: str, k: int = 10):
        """Answer ``question`` using the ``k`` most relevant documents as context.

        Returns:
            The text content of the first completion choice.
        """
        context_docs = self._retrive_docs(question, k)
        # Number the passages so the model can cite them as [1]..[n].
        context = "\n\n".join(
            f"[{number}] {doc}" for number, doc in enumerate(context_docs, start=1)
        )
        n_sources = max(len(context_docs), 1)

        prompt = (
            "You're a Python expert. Answer strictly according to the documentation "
            "which is marked as 'Context' below.\n"
            "If there is no answer in the context, say, \"I can't find the answer "
            "in the Python documentation\".\n"
            f"Context:\n{context}\n"
            f"Question: {question}\n\n"
            f"Response (with reference to the source [1-{n_sources}]):"
        )

        # Send the full prompt, not the bare question, so the retrieved context
        # actually reaches the model.
        messages = [{"role": "user", "content": prompt}]

        response = self.client.chat.completions.create(
            model=self.model, messages=messages, web_search=False
        )

        return response.choices[0].message.content

In [None]:
# RAG.indexer and RAG.docs must be populated (see the indexer TODO) before
# asking questions; retrieval cannot run without them.
rag_model = RAG("gpt-4o-mini")

question = "Sin and Cos"

rag_model.get_answer(question=question)