In [5]:
import logging
import os
from http import HTTPStatus
from typing import Annotated
from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch

from core_api.src.auth import get_user_uuid
from redbox.llm.prompts.chat import (
    CONDENSE_QUESTION_PROMPT,
    STUFF_DOCUMENT_PROMPT,
    WITH_SOURCES_PROMPT,
)
from redbox.model_db import MODEL_PATH
from redbox.models import EmbeddingModelInfo, Settings
from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument
from redbox.storage import ElasticsearchStorageHandler

env = Settings(_env_file="../.env")
env.elastic.host = "localhost"
env.minio_host = "localhost"

embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder="../models/")

es = Elasticsearch(
    hosts=[
        {
            "host": "localhost",
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)

vector_store = ElasticsearchStore(
    es_connection=es,
    index_name="redbox-data-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_version=env.openai_api_version,
    api_base=env.azure_openai_endpoint,
    max_tokens=4_096,
)

storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps


# RAG scratch

In [4]:
from langchain.schema import Document
from redbox.models.file import Metadata
from functools import reduce
from langchain_core.runnables import RunnableLambda
from langchain.schema import StrOutputParser
from redbox.llm.prompts.core import _core_redbox_prompt
from operator import itemgetter

[See here for lots of ideas.](https://python.langchain.com/v0.1/docs/integrations/retrievers/elasticsearch_retriever/#custom-document-mapper)

In [89]:
from langchain_elasticsearch import ElasticsearchRetriever
from typing import Any
from redbox.models import Chunk
from operator import itemgetter
from typing import TypedDict

class ESQuery(TypedDict):
    question: str
    file_uuids: list[UUID] | None = None
    user_uuid: UUID

def es_query(query: ESQuery) -> dict[str, Any]:
    vector = embedding_model.embed_query(query["question"])
    search_kwargs = {
        "query": {
            "bool": {
                "must": [
                    { "match": { "text": query["question"]} }
                ],
                "filter": [
                    {
                        "term": { 
                            "creator_user_uuid.keyword":  str(query["user_uuid"])
                        }
                    }
                ]
            }
        },
        "knn": {
            "field": "embedding",
            "query_vector": vector,
            "k": 5,
            "num_candidates": 10,
        }
    }

    if query["file_uuids"] is not None:
        search_kwargs["query"]["bool"]["filter"].append(
            {
                "terms": {
                    "parent_file_uuid.keyword": [str(uuid) for uuid in query["file_uuids"]]
                }
            }
        )
    
    return search_kwargs

def chunk_mapper(hit: dict[str, Any]) -> Chunk:
    return Chunk(**hit["_source"])

retriever = ElasticsearchRetriever(
    es_client=es,
    index_name="redbox-data-chunk",
    body_func=es_query,
    document_mapper=chunk_mapper
)

uuids = [
    # UUID("718dfb9c-3f0c-4942-a0c1-e0458a7a53c6"), 
    UUID("a28c04e2-8a1c-41b0-8d29-74ae41aa2e0f")
]

retriever.invoke(
    input={
        "query": "tell me about energy",
        "file_uuids": uuids,
        "user_uuid": UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5")
    }
)
# retriever.invoke("tell me about energy")

KeyError: 'question'

In [79]:
from core_api.src.format import format_chunks

res = retriever.invoke(
    input={
        "query": "tell me about energy",
        "file_uuids": uuids,
        "user_uuid": UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5")
    }
)

format_chunks(chunks=res)

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.015s]


'<Doca28c04e2-8a1c-41b0-8d29-74ae41aa2e0f>\n I hope that this briefing note continues to be of interest to you - please let me know if any changes are needed to the distribution list.\n\nThe next update will be issued on Thursday 27 October.\n\nKevin Harris Energy Statistics T: 0747 135 8194 E: kevin.harris@beis.gov.uk\n\nwww.gov.uk/beis | twitter.com/beis_stats\n\nenergy stats monthly brief September 2022.pdf 436K \n</Doca28c04e2-8a1c-41b0-8d29-74ae41aa2e0f>\n\n<Doca28c04e2-8a1c-41b0-8d29-74ae41aa2e0f>\n and we welcome any comments you may have about what you would like to see. At the moment it is only available for BEIS staff (i.e. owners of an @beis e-mail account), but we are looking to make it more widely available as we develop further. \n</Doca28c04e2-8a1c-41b0-8d29-74ae41aa2e0f>\n\n<Doca28c04e2-8a1c-41b0-8d29-74ae41aa2e0f>\n ---------- Forwarded message --------- From: Harris, Kevin (TIUA - Analysis Directorate) <Kevin.Harris@beis.gov.uk> Date: Thu, 29 Sept 2022 at 09:45 Subjec

In [93]:
from core_api.src.format import format_chunks
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter

chat_history = [
    ("system", "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know."),
    ("placeholder", "{messages}"),
    ("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
]

prompt = ChatPromptTemplate.from_messages(chat_history)

chain = (
    RunnablePassthrough()
    | {
        "question": itemgetter("question"),
        "messages": itemgetter("message_history"),
        "documents": retriever | format_chunks, 
        "sources": retriever,
    }
    | {
        "response": prompt | llm,
        "sources": itemgetter("sources"),
    }
)

In [95]:
chain.invoke(
    input={
        "question": "tell me about energy",
        "file_uuids": uuids,
        "user_uuid": UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
        "message_history": [
            ("user", "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?"),
            ("ai", "Of course. In future responses I will always expand the BEIS acronym.")
        ]
    }
)

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.011s]
INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search [status:200 duration:0.014s]
[92m07:37:43 - LiteLLM:INFO[0m: utils.py:1307 - [92m

POST Request Sent from LiteLLM:
curl -X POST \
https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/ \
-H 'Authorization: Bearer a87792838865********************' \
-d '{'model': 'gpt-4', 'messages': [{'role': 'system', 'content': "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know."}, {'role': 'user', 'content': 'Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?'}, {'role': 'assistant', 'content': 'Of course. In future responses I will always expand the BEIS acronym.'}, {'role': 'user', 'conten

{'response': AIMessage(content='Energy consumption in the UK has varied due to several factors. For the period between May to July 2022 compared to the previous year, primary energy consumption in the UK increased by 2.3% due to easing of lockdown restrictions, with petroleum consumption particularly notable. If adjusted for temperature, the consumption rose by 3.9%. Indigenous energy production rose significantly by 21%, fueled by strong growth in UKCS production.\n\nThe UK has also played a key role in supplying gas to Europe as it transitions away from Russian gas, which has led to a noteworthy increase in gas exports. For the second quarter of 2022, total final energy consumption was slightly lower than that of 2021 due to warmer temperatures reducing demand. However, increased activity in the transport sector led to a 23% increase in energy consumption, bringing petrol and diesel consumption close to pre-pandemic levels.\n\nIn contrast, energy consumption in the service sector fel

In [69]:
chain

{'documents': ElasticsearchRetriever(es_client=<Elasticsearch(['http://localhost:9200'])>, index_name='redbox-data-chunk', body_func=<function query at 0x17b906700>, document_mapper=<function chunk_mapper at 0x17b907240>)
 | RunnableLambda(format_chunks),
 'sources': operator.itemgetter('sources'),
 'response': ChatPromptTemplate(input_variables=['documents', 'question'], input_types={'messages': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}, partial_variables={'messages': []}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template="You're a helpful Q&A agent.")), MessagesPlaceholder(variable_name='messages', optional=True), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['documents', 'question']

In [70]:
previous_history = [
    {"text": "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?", "role": "user"},
    {"text": "Of course. In future responses I will always expand the BEIS acronym.", "role": "ai"},
    # {"text": "Please summarise all the key people in this document and who they work for.", "role": "user"},
]

chain.invoke(
    input={
        "question": "tell me about energy",
        "file_uuids": uuids,
        "user_uuid": UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
        "message_history": [(msg.role, msg.text) for msg in previous_history]
    }
)

AttributeError: 'dict' object has no attribute 'invoke'

In [None]:
def summarise(
    chat_request: ChatRequest,
    file_uuid: UUID,
    user_uuid: UUID,
    llm: ChatLiteLLM,
    storage_handler: ElasticsearchStorageHandler,
) -> ChatResponse:
    question = chat_request.message_history[-1].text
    previous_history = list(chat_request.message_history[:-1])
    
    # get full doc from vector store
    documents = get_file_as_documents(
        file_uuid=file_uuid, 
        user_uuid=user_uuid, 
        storage_handler=storage_handler,
        max_tokens=20_000
    )
    
    # right now, can only handle a single document so we manually truncate
    document = documents[:1]
    if len(documents) > 1:
        print("Document was longer than 20k tokens. Truncating to the first 20k.")
    
    # stuff raw prompt
    chat_history = [
        ("system", _core_redbox_prompt),
        ("placeholder", "{messages}"),
        ("user", "Question: {question}. \n\n Content: \n\n<document> {content} </document> \n\n Answer: "),
    ]

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    chain = (
        {
            "question": itemgetter("question"),
            "messages": itemgetter("messages"),
            "content": itemgetter("content") | RunnableLambda(format_docs),
        }
        | ChatPromptTemplate.from_messages(chat_history)
        | llm
        | StrOutputParser()
    )

    # return
    return chain.invoke(
        input={
            "question": question,
            "content": document,
            "messages": [(msg.role, msg.text) for msg in previous_history]
        }
    )


chat_request_body = {
    "message_history": [
        # {"text": "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?", "role": "user"},
        # {"text": "Of course. In future responses I will always expand the BEIS acronym.", "role": "ai"},
        {"text": "Please summarise all the key people in this document and who they work for.", "role": "user"},
    ]
}

res = summarise(
    chat_request=ChatRequest(**chat_request_body),
    file_uuid=UUID("35b3d95f-7f65-4cae-b159-22001ca19c88"),
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    llm=llm,
    storage_handler=storage_handler
)

print(res)

In [19]:
from redbox.models.file import Metadata, Chunk
from functools import partial, reduce
from uuid import UUID

chunks_unsorted = storage_handler.get_file_chunks(
    parent_file_uuid=UUID("35b3d95f-7f65-4cae-b159-22001ca19c88"), 
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5")
)
chunks = sorted(chunks_unsorted, key=lambda x: x.index)

def reduce_chunks_by_tokens(chunks: list[Chunk] | None, chunk: Chunk, max_tokens: int) -> list[Chunk]:
    """"""
    if not chunks:
        return [chunk]
    
    last_chunk = chunks[-1]

    if chunk.token_count + last_chunk.token_count <= max_tokens:
        chunks[-1] = Chunk(
            parent_file_uuid=last_chunk.parent_file_uuid,
            index=last_chunk.index,
            text=last_chunk.text + chunk.text,
            metadata=Metadata.merge(last_chunk.metadata, chunk.metadata),
            creator_user_uuid=last_chunk.creator_user_uuid,
        )
    else:
        chunk.index = last_chunk.index + 1
        chunks.append(chunk)
    
    return chunks

reduce_chunk_t300 = partial(reduce_chunks_by_tokens, max_tokens=300)

result = reduce(lambda cs, c: reduce_chunk_t300(cs, c), chunks, [])

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search?scroll=5m [status:200 duration:0.089s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.102s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.008s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.003s]
INFO:elastic_transport.transport:DELETE http://localhost:9200/_search/scroll [status:200 duration:0.002s]


In [16]:
import numpy as np
len(chunks), max(chunk.token_count for chunk in chunks), min(chunk.token_count for chunk in chunks), np.mean([chunk.token_count for chunk in chunks])

(2066, 301, 1, 84.15295256534365)

In [21]:
import numpy as np
len(result), max(chunk.token_count for chunk in result), min(chunk.token_count for chunk in result), np.mean([chunk.token_count for chunk in result])

(685, 301, 111, 253.76058394160583)

In [20]:
[chunk.index for chunk in result]

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
