In [7]:
%load_ext autoreload
%autoreload 2

import os
import sys
from dotenv import load_dotenv


# Put the parent of the current working directory at the front of sys.path so
# that project packages (e.g. feature_pipeline) can be imported from here.
notebook_dir = os.path.abspath("")
src_path = os.path.join(notebook_dir, "..")
sys.path.insert(0, src_path)

# Read environment variables from the feature pipeline subfolder's .env file.
# The call stays the last expression so the cell displays its return value.
env_path = os.path.join(src_path, "feature_pipeline/.env")
load_dotenv(env_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


True

In [10]:
from feature_pipeline.llm_components.prompt_templates import QueryExpansionTemplate

In [83]:
# LangChain helpers for the function-calling cells below.
from langchain_community.utils.openai_functions import (
    convert_pydantic_to_openai_function,
)
from langchain_core.utils.function_calling import convert_to_openai_function
# Prompt classes are taken from langchain_core.prompts.
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_openai import ChatOpenAI

## Function calling

In [67]:
# Schema for the metadata that is pulled out of a user query via function calling.
# NOTE(review): the docstring and the Field descriptions below are presumably
# forwarded to the model as the function/parameter descriptions, so treat them
# as prompt text rather than ordinary documentation — confirm before rewording.
class QueryMetadata(BaseModel):
    """Information to extract from the user query. 
    Dates should be transformed to yyyy-mm-dd and be relative to the given current date"""

    # Cryptocurrency named in the query, kept as a plain string.
    currency: str = Field(description="The cryptocurrency mentioned in the query.")
    # Date kept as a string; the description asks for yyyy-mm-dd, but no
    # validator in this class enforces that format.
    date: str = Field(description="date from the text in the format yyyy-mm-dd")


# Describe QueryMetadata as an OpenAI function definition. Uses
# convert_to_openai_function (imported above from langchain_core) in place of
# the deprecated convert_pydantic_to_openai_function.
openai_functions = [convert_to_openai_function(QueryMetadata)]
# Display the generated definition as the cell output.
openai_functions

In [87]:
from datetime import datetime

from langchain.globals import set_debug
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser

# Keep LangChain's global debug output switched off for this run.
set_debug(False)

# Current date as yyyy-mm-dd; it is injected into the system message below.
current_date = datetime.now().strftime("%Y-%m-%d")

# Two-message prompt: the system turn carries the date, the human turn the query.
template = ChatPromptTemplate.from_messages(
    [
        ("system", "Today´s date is {current_date}."),
        ("human", "{user_query}"),
    ]
)

# Pipeline: prompt -> chat model with the function definitions bound -> parser.
model = ChatOpenAI(temperature=0)
parser = JsonOutputFunctionsParser()
chain = template | model.bind(functions=openai_functions) | parser

query = "What was the price of btc 10 days ago?"
chain_inputs = {"current_date": current_date, "user_query": query}
chain.invoke(chain_inputs, config={"verbose": True})

{'currency': 'btc', 'date': '2024-05-11'}

In [80]:
# Same extraction as above, but calling the OpenAI client directly instead of
# going through the LangChain chain.

from openai import OpenAI

client = OpenAI(
  api_key=os.environ.get("OPENAI_API_KEY")
)

# Test the prompts with a small sample.
# The date context is sent as a "system" message, mirroring the LangChain
# prompt above; with role "assistant" it would be presented to the model as
# something it had said itself rather than as context for the request.
messages = [
    {
        "role": "system",
        "content": f"Today´s date is {current_date}"
    },
    {
        "role": "user",
        "content": "What date was 5 days ago?"
    }
]

# `functions` carries the definitions built from QueryMetadata earlier in the notebook.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=messages,
    functions=openai_functions
)
# Display the first choice's message as the cell output.
response.choices[0].message

## Filtering with extracted metadata

In [None]:
# Scratch fragment, not runnable as a standalone cell: `self`, `models`,
# `metadata_filter_value`, `generated_query` and `k` are not defined in this cell.
# It sketches a vector search on the "vector_posts" collection whose filter
# requires the "author_id" field to match `metadata_filter_value`; the query
# vector is the embedder's encoding of `generated_query` and at most `k`
# results are requested.
# NOTE(review): the VectorRetriever class further down stores its client as
# `self._client`, not `self._qdrant_client`.
self._qdrant_client.search(
      collection_name="vector_posts",
      query_filter=models.Filter(
          must=[
              models.FieldCondition(
                  key="author_id",
                  match=models.MatchValue(
                      value=metadata_filter_value,
                  ),
              )
          ]
      ),
      query_vector=self._embedder.encode(generated_query).tolist(),
      limit=k,
)

In [None]:
import concurrent.futures

import feature_pipeline.utils
from qdrant_client import QdrantClient, models
from feature_pipeline.rag.query_expanison import QueryExpansion
from feature_pipeline.rag.reranking import Reranker
from feature_pipeline.rag.self_query import SelfQuery
from sentence_transformers.SentenceTransformer import SentenceTransformer
from feature_pipeline.settings import settings

import feature_pipeline.logger_utils as logger_utils


# Module-level logger obtained from the project's logger_utils helper, keyed by __name__.
# NOTE(review): the calls below pass keyword fields to logger.info, so get_logger
# presumably returns a structured logger — confirm.
logger = logger_utils.get_logger(__name__)


class VectorRetriever:
    """
    Retrieve documents from a Qdrant vector store for a RAG query.

    The user query is expanded into several generated queries and an author id
    is extracted from it; each generated query is then searched in three
    collections with a filter on that id. The collected hits can afterwards be
    reranked against the original query.
    """

    # (collection name, field compared with the extracted id), listed in the
    # order in which the collections are searched and their hits returned.
    _COLLECTION_FILTER_KEYS = (
        ("vector_posts", "author_id"),
        ("vector_articles", "author_id"),
        ("vector_repositories", "owner_id"),
    )

    def __init__(self, query: str):
        """
        Store the query and build the client and helper components.

        Args:
            query: The user query that is expanded, searched and reranked against.
        """
        self._client = QdrantClient(
            host=settings.QDRANT_DATABASE_HOST, port=settings.QDRANT_DATABASE_PORT
        )
        self.query = query
        self._embedder = SentenceTransformer(settings.EMBEDDING_MODEL_ID)
        self._query_expander = QueryExpansion()
        self._metadata_extractor = SelfQuery()
        self._reranker = Reranker()

    def _search_single_query(
        self, generated_query: str, metadata_filter_value: str, k: int
    ):
        """
        Search every collection for one generated query.

        Args:
            generated_query: Query text to embed and search with.
            metadata_filter_value: Value the collection's id field must match.
            k: Total budget of results; it is split evenly across the
                collections with integer division.

        Returns:
            The result of ``feature_pipeline.utils.flatten`` applied to the
            per-collection search results (presumably one flat list of hits —
            confirm against the helper).

        Raises:
            AssertionError: If ``k`` is not greater than 3.
        """
        assert k > 3, "k should be greater than 3"

        # Embed once and reuse the vector for every collection.
        query_vector = self._embedder.encode(generated_query).tolist()
        per_collection_limit = k // len(self._COLLECTION_FILTER_KEYS)

        vectors = [
            self._client.search(
                collection_name=collection_name,
                query_filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key=filter_key,
                            match=models.MatchValue(
                                value=metadata_filter_value,
                            ),
                        )
                    ]
                ),
                query_vector=query_vector,
                limit=per_collection_limit,
            )
            for collection_name, filter_key in self._COLLECTION_FILTER_KEYS
        ]

        # `import feature_pipeline.utils` binds only the name `feature_pipeline`,
        # so the helper has to be reached through the package.
        return feature_pipeline.utils.flatten(vectors)

    def retrieve_top_k(self, k: int, to_expand_to_n_queries: int) -> list:
        """
        Expand the stored query and search the vector store with every variant.

        Args:
            k: Result budget passed to each single-query search.
            to_expand_to_n_queries: Number of queries requested from the query expander.

        Returns:
            The flattened hits of all searches. Results are gathered in task
            completion order, so the order across generated queries is not fixed.
        """
        generated_queries = self._query_expander.generate_response(
            self.query, to_expand_to_n=to_expand_to_n_queries
        )
        logger.info(
            "Successfully generated queries for search.",
            num_queries=len(generated_queries),
        )

        author_id = self._metadata_extractor.generate_response(self.query)
        logger.info(
            "Successfully extracted the author_id from the query.",
            author_id=author_id,
        )

        # One search task per generated query, run concurrently on a thread pool.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search_single_query, query, author_id, k)
                for query in generated_queries
            ]

            hits = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]
            hits = feature_pipeline.utils.flatten(hits)

        logger.info("All documents retrieved successfully.", num_documents=len(hits))

        return hits

    def rerank(self, hits: list, keep_top_k: int) -> list[str]:
        """
        Rerank retrieved hits against the stored query.

        Args:
            hits: Search hits; each must expose ``payload["content"]``.
            keep_top_k: Number of passages the reranker is asked to keep.

        Returns:
            Whatever the reranker returns for the extracted passages.
        """
        content_list = [hit.payload["content"] for hit in hits]
        rerank_hits = self._reranker.generate_response(
            query=self.query, passages=content_list, keep_top_k=keep_top_k
        )

        logger.info("Documents reranked successfully.", num_documents=len(rerank_hits))

        return rerank_hits

    def set_query(self, query: str):
        """Replace the query used by later retrieve and rerank calls."""
        self.query = query
