In [None]:
import logging
import os
from http import HTTPStatus
from typing import Annotated
from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch

from core_api.src.auth import get_user_uuid
from redbox.llm.prompts.chat import (
    CONDENSE_QUESTION_PROMPT,
    STUFF_DOCUMENT_PROMPT,
    WITH_SOURCES_PROMPT,
)
from redbox.model_db import MODEL_PATH
from redbox.models import EmbeddingModelInfo, Settings
from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument


# --- Configuration -----------------------------------------------------------
# Settings are read from the .env file one directory above the notebook.
env = Settings(_env_file="../.env")

# Sentence-transformer embeddings; model weights are cached under ../models/.
embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder="../models/")

# --- Elasticsearch -----------------------------------------------------------
# NOTE(review): the host is hard-coded to "localhost" while port/scheme come
# from settings — confirm this is intended for local scratch use only.
es = Elasticsearch(
    hosts=[
        {
            "host": "localhost",
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# Hybrid retrieval is only switched on for platinum/enterprise subscriptions.
if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)
else:
    # Fail here with a clear message instead of a NameError on `strategy` below.
    raise ValueError(
        f"Unsupported Elasticsearch subscription level: {env.elastic.subscription_level!r}. "
        "Expected 'basic', 'platinum' or 'enterprise'."
    )

vector_store = ElasticsearchStore(
    es_connection=es,
    index_name="redbox-data-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

# --- LLM ---------------------------------------------------------------------
# Streaming chat model routed through LiteLLM using the Azure OpenAI settings.
llm = ChatLiteLLM(
    model=env.openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_version=env.openai_api_version,
    api_base=env.azure_openai_endpoint,
    max_tokens=4_096,
)

# Summarisation scratch

In [None]:
from langchain.schema import Document
from redbox.models.file import Metadata
from functools import reduce
from redbox.storage import ElasticsearchStorageHandler
from langchain_core.runnables import RunnableLambda
from langchain.schema import StrOutputParser
from redbox.llm.prompts.core import _core_redbox_prompt
from operator import itemgetter

In [None]:
# Shared storage handler built on the Elasticsearch client created above,
# rooted at the index name taken from settings.
storage_handler = ElasticsearchStorageHandler(
    es_client=es,
    root_index=env.elastic_root_index,
)

def _build_document(texts: list[str], metadata: list[Metadata | None]) -> Document:
    """Joins buffered chunk texts with spaces and folds their metadata into one Document."""
    return Document(
        page_content=" ".join(texts),
        metadata=reduce(Metadata.merge, metadata),
    )


def get_file_as_documents(
    file_uuid: UUID,
    user_uuid: UUID,
    storage_handler: ElasticsearchStorageHandler = storage_handler,
    max_tokens: int | None = None
) -> list[Document]:
    """Gets a file as LangChain Documents, splitting it by max_tokens.

    The file's chunks are fetched from the storage handler, ordered by their
    ``index`` and packed greedily into Documents. A new Document is started
    whenever adding the next chunk would bring the running token count to
    ``max_tokens`` or above. Chunk texts are joined with a single space and
    chunk metadata is combined with ``Metadata.merge``.

    Args:
        file_uuid: UUID of the parent file whose chunks are fetched.
        user_uuid: UUID of the user the chunks are fetched for.
        storage_handler: Handler providing ``get_file_chunks``. Defaults to the
            notebook-level handler.
        max_tokens: Token budget per Document. ``None`` (or ``0``) disables
            splitting, so all chunks end up in a single Document.

    Returns:
        The Documents in chunk order; an empty list when the file has no
        chunks. A single chunk whose own token count reaches ``max_tokens``
        is not split further and becomes a Document on its own.
    """
    documents: list[Document] = []
    chunks_unsorted = storage_handler.get_file_chunks(parent_file_uuid=file_uuid, user_uuid=user_uuid)
    chunks = sorted(chunks_unsorted, key=lambda x: x.index)

    total_tokens = sum(chunk.token_count for chunk in chunks)
    print(total_tokens)

    token_count: int = 0
    # A falsy max_tokens means "no limit".
    n = max_tokens or float("inf")
    page_content: list[str] = []
    metadata: list[Metadata | None] = []

    for chunk in chunks:
        # Only flush a non-empty buffer: if the very first chunk already meets
        # the budget there is nothing to emit yet, and reduce() over an empty
        # metadata list would raise TypeError.
        if page_content and token_count + chunk.token_count >= n:
            documents.append(_build_document(page_content, metadata))
            token_count = 0
            page_content = []
            metadata = []

        page_content.append(chunk.text)
        metadata.append(chunk.metadata)
        token_count += chunk.token_count

    # Emit whatever is left in the buffer after the final chunk.
    if page_content:
        documents.append(_build_document(page_content, metadata))

    return documents


In [None]:
def summarise(
    chat_request: ChatRequest,
    file_uuid: UUID,
    user_uuid: UUID,
    llm: ChatLiteLLM,
    storage_handler: ElasticsearchStorageHandler,
) -> str:
    """Answers the latest chat message using the content of a single file.

    The last entry of ``chat_request.message_history`` is treated as the
    question; all earlier entries are passed to the prompt as prior
    conversation. The file is loaded in pieces of up to 20,000 tokens and only
    the first piece is placed in the prompt, so longer files are truncated.

    Args:
        chat_request: Request whose ``message_history`` holds the conversation;
            it must contain at least one message.
        file_uuid: UUID of the file to summarise.
        user_uuid: UUID of the user the file chunks are fetched for.
        llm: Chat model the prompt is sent to.
        storage_handler: Handler used to fetch the file's chunks.

    Returns:
        The model's answer as plain text (the chain ends in ``StrOutputParser``).
    """
    question = chat_request.message_history[-1].text
    previous_history = list(chat_request.message_history[:-1])

    # get full doc from vector store
    documents = get_file_as_documents(
        file_uuid=file_uuid,
        user_uuid=user_uuid,
        storage_handler=storage_handler,
        max_tokens=20_000
    )

    # right now, can only handle a single document so we manually truncate
    # (the slice keeps this a list; it is empty when the file has no chunks)
    truncated_documents = documents[:1]
    if len(documents) > 1:
        print("Document was longer than 20k tokens. Truncating to the first 20k.")

    # stuff raw prompt
    chat_history = [
        ("system", _core_redbox_prompt),
        ("placeholder", "{messages}"),
        ("user", "Question: {question}. \n\n Content: \n\n<document> {content} </document> \n\n Answer: "),
    ]

    def format_docs(docs):
        """Concatenates the documents' text, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    # Map the input dict onto the prompt variables, then prompt -> llm -> text.
    chain = (
        {
            "question": itemgetter("question"),
            "messages": itemgetter("messages"),
            "content": itemgetter("content") | RunnableLambda(format_docs),
        }
        | ChatPromptTemplate.from_messages(chat_history)
        | llm
        | StrOutputParser()
    )

    # return
    return chain.invoke(
        input={
            "question": question,
            "content": truncated_documents,
            "messages": [(msg.role, msg.text) for msg in previous_history]
        }
    )


# Example request. Earlier turns can be enabled by adding these entries
# ahead of the final user message:
# {"text": "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?", "role": "user"},
# {"text": "Of course. In future responses I will always expand the BEIS acronym.", "role": "ai"},
latest_message = {
    "text": "Please summarise all the key people in this document and who they work for.",
    "role": "user",
}
chat_request_body = {"message_history": [latest_message]}

# Hard-coded identifiers of the file to summarise and the user who owns it.
summarise_kwargs = dict(
    chat_request=ChatRequest(**chat_request_body),
    file_uuid=UUID("35b3d95f-7f65-4cae-b159-22001ca19c88"),
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    llm=llm,
    storage_handler=storage_handler,
)
res = summarise(**summarise_kwargs)

print(res)