In [None]:
from uuid import UUID
from pathlib import Path
import tiktoken
import os
import logging
import sys

from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch

from redbox.models import Settings
from redbox.models.settings import ElasticLocalSettings
from redbox.storage import ElasticsearchStorageHandler

from core_api.callbacks import LoggerCallbackHandler

from dotenv import find_dotenv, load_dotenv

# Repository root, taken as the parent of the current working directory.
# NOTE(review): this assumes the kernel is started from a first-level
# subfolder of the repository — confirm before running from elsewhere.
ROOT = Path().resolve().parent

# Load the repo-level .env file; the return value is deliberately discarded.
_ = load_dotenv(find_dotenv(ROOT / '.env'))

# Send log records to stdout at INFO level so they show up in the cell output.
# The keyword is `stream`: the earlier `steam` spelling is rejected by
# logging.basicConfig with "ValueError: Unrecognised argument(s)" whenever the
# root logger has no handlers yet (for example on a fresh kernel).
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Root logger; it is passed to the LLM callback handler further down.
log = logging.getLogger()

# Settings for a stack running on the local machine: values are read from the
# repo-level .env file, with the MinIO and Elasticsearch hosts set to localhost.
env = Settings(
    _env_file=(ROOT / '.env'),
    minio_host="localhost", 
    object_store="minio",
    elastic=ElasticLocalSettings(host="localhost"),
)

# Embedding model named in the settings; "../models/" is passed as the cache
# folder, so it is resolved relative to the working directory.
embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder="../models/")

# Elasticsearch client. The host is hard-coded to localhost here, while the
# port, scheme and basic-auth credentials come from the settings.
es = Elasticsearch(
    hosts=[
        {
            "host": "localhost",
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# Pick the retrieval strategy from the Elasticsearch subscription level:
# hybrid retrieval is switched on only for platinum/enterprise.
if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)
else:
    # Fail here with a clear message instead of leaving `strategy` unbound,
    # which would only surface as a NameError when the vector store is built.
    message = f"Unsupported Elasticsearch subscription level: {env.elastic.subscription_level!r}"
    raise ValueError(message)

# Vector store over the chunk index, reusing the client, embedding model and
# retrieval strategy created above.
# NOTE(review): the index name is a literal here, whereas the retriever cell
# derives it from env.elastic_root_index — confirm they point at the same index.
vector_store = ElasticsearchStore(
    es_connection=es,
    index_name="redbox-data-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

# See core_api.dependencies for details on this hack
os.environ["AZURE_API_VERSION"] = env.openai_api_version

# Callback handler that is given the root logger; attached to the LLM below.
logger_callback = LoggerCallbackHandler(logger=log)

# Streaming chat model configured with the Azure OpenAI model name, key and
# endpoint from the settings.
# NOTE(review): max_tokens is the literal 1_024 here, while the chain builders
# in later cells budget prompts with env.llm_max_tokens — confirm the two agree.
llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_base=env.azure_openai_endpoint,
    max_tokens=1_024,
    callbacks=[logger_callback]
)

# Storage handler bound to the same Elasticsearch client and the configured root index.
storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

# Tokeniser passed to the prompt builders in later cells.
tokeniser = tiktoken.get_encoding("cl100k_base")

# RAG scratch

In [None]:
from core_api.retriever import ParameterisedElasticsearchRetriever
from langchain_core.runnables import ConfigurableField

def get_parameterised_retriever(env, es):
    """Build a configurable Elasticsearch retriever runnable.

    The runnable takes a dict keyed to question, file_uuids and user_uuid as
    input and returns a list of Chunks.

    Args:
        env: Settings object; ``env.ai.rag_k``, ``env.ai.rag_num_candidates``
            and ``env.elastic_root_index`` are read from it.
        es: Elasticsearch client handed to the retriever.

    Returns:
        The retriever with its ``params`` exposed as a configurable field, so
        the defaults set here can be overridden through the runnable config.

    Note:
        The embedding model is not a parameter: the notebook-level
        ``embedding_model`` variable is used, so it must already be defined.
    """
    # Defaults used whenever the caller does not override "params".
    default_params = dict(
        size=env.ai.rag_k,
        num_candidates=env.ai.rag_num_candidates,
        match_boost=1,
        knn_boost=1,
        similarity_threshold=0,
    )
    chunk_index = f"{env.elastic_root_index}-chunk"

    base_retriever = ParameterisedElasticsearchRetriever(
        es_client=es,
        index_name=chunk_index,
        params=default_params,
        embedding_model=embedding_model,
    )
    params_field = ConfigurableField(
        id="params",
        name="Retriever parameters",
        description="A dictionary of parameters to use for the retriever.",
    )
    return base_retriever.configurable_fields(params=params_field)

# Notebook-level retriever, built from the settings and client defined in the
# setup cell; the chain cells below take it as an argument.
retriever = get_parameterised_retriever(env, es)

In [None]:
# Exercise the retriever on its own: query "KAN" restricted to a single file
# for one user. The result is the cell's displayed value.
# NOTE(review): the UUIDs are presumably specific to the author's local index.
retriever.invoke(
    input={
        "question": "KAN",
        "file_uuids": [
            "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038", # KAN paper
            # "1a9d18a7-9499-47b6-abcc-4e82370028ee" # MAMBA paper
        ],
        "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1"
    }
)

In [15]:
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
    chain,
)
from langchain.schema import StrOutputParser
from operator import itemgetter
from redbox.models import ChatRoute
from redbox.models.chain import ChainInput

from core_api.format import format_documents
from core_api.runnables import make_chat_prompt_from_messages_runnable


def build_retrieval_chain(
    llm,
    retriever,
    tokeniser,
    env,
) -> Runnable:
    """Assemble the retrieval-augmented answering chain.

    The input dict is passed through three stages:

    1. ``documents`` is added by running ``retriever`` on the input.
    2. ``formatted_documents`` is added by applying ``format_documents`` to
       those documents.
    3. The result is fanned out into the final output dict.

    Args:
        llm: Chat model that answers the prompt.
        retriever: Runnable that produces the documents for the input.
        tokeniser: Tokeniser handed to the prompt builder.
        env: Settings object; the retrieval prompts, ``env.ai.context_window_size``
            and ``env.llm_max_tokens`` are read from it.

    Returns:
        A runnable yielding a dict with keys ``response`` (the LLM answer as a
        string), ``source_documents`` (the retrieved documents) and
        ``route_name`` (always ``ChatRoute.search.value``).
    """
    attach_documents = RunnablePassthrough.assign(documents=retriever)
    attach_formatted_documents = RunnablePassthrough.assign(
        formatted_documents=(RunnablePassthrough() | itemgetter("documents") | format_documents)
    )

    # Prompt budget is the context window minus the space reserved for the completion.
    answer_prompt = make_chat_prompt_from_messages_runnable(
        system_prompt=env.ai.retrieval_system_prompt,
        question_prompt=env.ai.retrieval_question_prompt,
        input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
        tokeniser=tokeniser,
    )
    outputs = {
        "response": answer_prompt | llm | StrOutputParser(),
        "source_documents": itemgetter("documents"),
        "route_name": RunnableLambda(lambda _: ChatRoute.search.value),
    }

    return attach_documents | attach_formatted_documents | outputs

# Baseline chain: retrieval runs on the question exactly as asked, with no
# condensing step.
rag = build_retrieval_chain(llm, retriever, tokeniser, env)

# A follow-up question ("Give the full citation.") that only makes sense
# together with the preceding chat history.
params = ChainInput(
    question="Give the full citation.",
    file_uuids=[
        "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038", # KAN paper
        "1a9d18a7-9499-47b6-abcc-4e82370028ee" # MAMBA paper
    ],
    user_uuid="5c37bf4c-002c-458d-9e68-03042f76a5b1",
    chat_history=[
        {"text": "What is the fastest attention that the authors are aware of?", "role": "user"},
        {"text": "The fastest implementation of attention, according to the authors, is **FlashAttention-2 (Dao 2024)** with a causal mask. It's stated that this version of FlashAttention-2 is approximately **1.7× faster** than the version without a causal mask because roughly half of the attention entries are computed.", "role": "ai"},
    ],
)

# The chain is invoked with the dict produced by model_dump(); the result is
# the cell's displayed value.
rag.invoke(params.model_dump())

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.018s]
INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"
INFO:root:LLM end: generations=[[ChatGeneration(text="I'm sorry, but I can't provide the full citation as it's not given in the provided document excerpts.", message=AIMessage(content="I'm sorry, but I can't provide the full citation as it's not given in the provided document excerpts.", id='run-2263eeba-8e71-41fc-aee5-860ab1587e32-0'))]] llm_output=None run=None


{'response': "I'm sorry, but I can't provide the full citation as it's not given in the provided document excerpts.",
 'source_documents': [Document(page_content='Details of the fused kernel and recomputation are in Appendix D. The full Selective SSM layer and algorithm is illustrated in Figure 1.', metadata={'parent_file_uuid': '1a9d18a7-9499-47b6-abcc-4e82370028ee', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'index': 120, 'page_number': 7, 'languages': ['eng'], 'link_texts': None, 'link_urls': None, 'links': None, 'created_datetime': '2024-06-28T07:23:52.885216+00:00', 'token_count': 29}),
  Document(page_content='The two challenges are the sequential nature of recurrence, and the large memory usage. To address the latter, just like the convolutional mode, we can attempt to not actually materialize the full state ℎ.', metadata={'parent_file_uuid': '1a9d18a7-9499-47b6-abcc-4e82370028ee', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'index': 112, 'page

In [14]:
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
    chain,
)
from langchain.schema import StrOutputParser
from operator import itemgetter
from redbox.models import ChatRoute
from redbox.models.chain import ChainInput

from core_api.format import format_documents
from core_api.runnables import make_chat_prompt_from_messages_runnable

# Scratch prompts for turning a follow-up question plus chat history into a
# single standalone question.
# NOTE(review): in this cell they are referenced only by the commented-out
# experiment; build_condense_retrieval_chain reads its prompts from env.ai.
CONDENSE_SYSTEM_PROMPT = " ".join(
    [
        "Given the following conversation and a follow up question, generate a follow up question to be a standalone question.",
        "You are only allowed to generate one question in response.",
        "Include sources from the chat history in the standalone question created, when they are available.",
        "If you don't know the answer, just say that you don't know, don't try to make up an answer.",
    ]
) + " \n"

# The user's question, a separator line, then the cue for the model's answer.
CONDENSE_QUESTION_PROMPT = "{question}" + "\n=========\n" + " Standalone question: "


def build_condense_retrieval_chain(
    llm,
    retriever,
    tokeniser,
    env,
) -> Runnable:
    """Build a retrieval chain that condenses the question using chat history first.

    When the input carries a non-empty ``chat_history``, the ``question`` field
    is replaced by the LLM's output for the condense prompts before retrieval.
    Otherwise the input is passed through unchanged. The retrieval and
    answering stages are the same in both cases.

    Args:
        llm: Chat model used for both the condense step and the final answer.
        retriever: Runnable that produces the documents for the input.
        tokeniser: Tokeniser handed to the prompt builders.
        env: Settings object; the condense and retrieval prompts,
            ``env.ai.context_window_size`` and ``env.llm_max_tokens`` are read.

    Returns:
        A runnable yielding a dict with keys ``response`` (the LLM answer as a
        string), ``source_documents`` (the retrieved documents) and
        ``route_name`` (always ``ChatRoute.search.value``).
    """
    # Prompt budget is the context window minus the space reserved for the completion.
    prompt_token_budget = env.ai.context_window_size - env.llm_max_tokens

    # Both branches are built once here instead of on every invocation of `route`.
    condense_question = RunnablePassthrough.assign(
        question=make_chat_prompt_from_messages_runnable(
            system_prompt=env.ai.condense_system_prompt,
            question_prompt=env.ai.condense_question_prompt,
            input_token_budget=prompt_token_budget,
            tokeniser=tokeniser,
        )
        | llm
        | StrOutputParser()
    )
    keep_question = RunnablePassthrough()

    def route(input_dict: dict):
        """Return the runnable that should process ``input_dict`` next."""
        # Truthiness covers an empty list as well as a missing or None
        # chat_history; in all of those cases there is nothing to condense.
        if input_dict.get("chat_history"):
            return condense_question
        return keep_question

    return (
        # A Runnable returned from the lambda is run on the same input.
        RunnableLambda(route)
        | RunnablePassthrough.assign(documents=retriever)
        | RunnablePassthrough.assign(
            formatted_documents=(RunnablePassthrough() | itemgetter("documents") | format_documents)
        )
        | {
            "response": make_chat_prompt_from_messages_runnable(
                system_prompt=env.ai.retrieval_system_prompt,
                question_prompt=env.ai.retrieval_question_prompt,
                input_token_budget=prompt_token_budget,
                tokeniser=tokeniser,
            )
            | llm
            | StrOutputParser(),
            "source_documents": itemgetter("documents"),
            "route_name": RunnableLambda(lambda _: ChatRoute.search.value),
        }
    )

# crag = make_chat_prompt_from_messages_runnable(
#     system_prompt=CONDENSE_SYSTEM_PROMPT,
#     question_prompt=CONDENSE_QUESTION_PROMPT,
#     input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
#     tokeniser=tokeniser,
# ) | llm

# Condense-then-retrieve chain: with chat history present, the question is
# rewritten by the LLM before the retriever runs.
crag = build_condense_retrieval_chain(llm, retriever, tokeniser, env)

# Same follow-up question and history as the baseline run, so the two chains
# can be compared on identical input.
params = ChainInput(
    question="Give the full citation.",
    file_uuids=[
        "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038", # KAN paper
        "1a9d18a7-9499-47b6-abcc-4e82370028ee" # MAMBA paper
    ],
    user_uuid="5c37bf4c-002c-458d-9e68-03042f76a5b1",
    chat_history=[
        {"text": "What is the fastest attention that the authors are aware of?", "role": "user"},
        {"text": "The fastest implementation of attention, according to the authors, is **FlashAttention-2 (Dao 2024)** with a causal mask. It's stated that this version of FlashAttention-2 is approximately **1.7× faster** than the version without a causal mask because roughly half of the attention entries are computed.", "role": "ai"},
    ],
)

# The chain is invoked with the dict produced by model_dump(); the result is
# the cell's displayed value.
crag.invoke(params.model_dump())

INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"
INFO:root:LLM end: generations=[[ChatGeneration(text="According to the author's knowledge, which version of FlashAttention is considered the fastest implementation of attention?", message=AIMessage(content="According to the author's knowledge, which version of FlashAttention is considered the fastest implementation of attention?", id='run-5b25fb65-c689-41b6-a6f9-dd755590fabd-0'))]] llm_output=None run=None
INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.030s]
INFO:httpx:HTTP Request: POST https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/chat/completions?api-version=2024-02-01 "HTTP/1.1 200 OK"
INFO:root:LLM end: generations=[[ChatGeneration(text="The speediest implementation of attention as per the authors' knowledge is **Flas

{'response': "The speediest implementation of attention as per the authors' knowledge is **FlashAttention-2 (Dao 2024) with a causal mask**. It is reported that it's around **1.7× faster** in comparison to the version without a causal mask.",
 'source_documents': [Document(page_content='For attention, we compare against the fastest implementation that we are aware of (FlashAttention-2 (Dao 2024)), with causal mask. Note that FlashAttention-2 with causal mask is about 1.7× faster than without causal mask, since approximately only half of the attention entries are computed.', metadata={'parent_file_uuid': '1a9d18a7-9499-47b6-abcc-4e82370028ee', 'creator_user_uuid': '5c37bf4c-002c-458d-9e68-03042f76a5b1', 'index': 647, 'page_number': 36, 'languages': ['eng'], 'link_texts': None, 'link_urls': None, 'links': None, 'created_datetime': '2024-06-28T07:23:52.921093+00:00', 'token_count': 62}),
  Document(page_content='memory-efficient. We evaluate the speed of our scan implementation compared t