In [None]:
import logging
import os
from http import HTTPStatus
from operator import itemgetter
from typing import Annotated
from uuid import UUID

from elasticsearch import Elasticsearch
from fastapi import Depends, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore

from core_api.src.auth import get_user_uuid
from redbox.llm.prompts.chat import (
    CONDENSE_QUESTION_PROMPT,
    STUFF_DOCUMENT_PROMPT,
    WITH_SOURCES_PROMPT,
)
from redbox.model_db import MODEL_PATH
from redbox.models import EmbeddingModelInfo, Settings
from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument


# Project settings, read from the .env file one directory above this notebook.
env = Settings(_env_file="../.env")

# Embedding model named in the settings; cache_folder points at ../models/.
embedding_model = SentenceTransformerEmbeddings(
    model_name=env.embedding_model,
    cache_folder="../models/",
)

# Elasticsearch client for a node on localhost; port, scheme and credentials
# come from the settings object.
es = Elasticsearch(
    hosts=[dict(host="localhost", port=env.elastic.port, scheme=env.elastic.scheme)],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# Pick the retrieval strategy from the configured Elastic subscription level:
# hybrid retrieval is switched on only for "platinum" and "enterprise".
if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)
else:
    # Fail here with a clear message instead of leaving `strategy` unbound,
    # which would surface later as a NameError when the vector store is built.
    raise ValueError(
        f"Unsupported Elastic subscription level: {env.elastic.subscription_level!r} "
        "(expected 'basic', 'platinum' or 'enterprise')"
    )

# Vector store over the "redbox-data-chunk" index, reusing the client,
# embedding model and retrieval strategy created above.
vector_store = ElasticsearchStore(
    index_name="redbox-data-chunk",
    vector_query_field="embedding",
    es_connection=es,
    embedding=embedding_model,
    strategy=strategy,
)

# Chat model routed through LiteLLM; model name, key, API version and
# endpoint are all taken from the settings object.
llm = ChatLiteLLM(
    model=env.openai_model,
    api_base=env.azure_openai_endpoint,
    api_version=env.openai_api_version,
    azure_key=env.azure_openai_api_key,
    streaming=True,
    max_tokens=4_096,
)

# Summarisation scratch

In [None]:
from langchain_core.runnables import RunnablePassthrough
    
# setup = RunnablePassthrough.assign(documents=context | format_docs, sources=context)
#     runnable = setup | {
#         "response": prompt | llm,
#         "sources": itemgetter("sources"),
#     }

In [None]:
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.prompt_template import format_document
from redbox.llm.prompts.core import _core_redbox_prompt

# Template that renders a document as just its page_content. Only referenced
# by the commented-out format_document() experiment further down this cell.
doc_prompt = PromptTemplate.from_template("{page_content}")

# Message templates for the summarisation prompt: the core Redbox system
# prompt followed by a human turn carrying the text to summarise ({content}).
chat_history = [
    ("system", _core_redbox_prompt),
    # Placeholder for prior conversation turns; disabled for now.
    # ("placeholder", "{messages}")
    ("human", "Summarize the following content:\n\n{content}"),
]

def format_docs(docs):
    """Join the ``page_content`` of every document into one string.

    Documents are separated by a blank line; an empty iterable gives ``""``.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

# Summarisation chain. Input is a mapping whose "content" value is a list of
# Documents; output is the model's reply as a plain string.
chain = (
    # Replace the Documents under "content" with their concatenated text so
    # the prompt's {content} variable receives a string.
    RunnablePassthrough.assign(content=lambda inputs: format_docs(inputs["content"]))
    # Render the system + human messages. This must happen before the LLM
    # call: the model needs prompt messages, not the raw input mapping.
    | ChatPromptTemplate.from_messages(chat_history)
    | llm
    | StrOutputParser()
)

In [None]:
# Example prior turns. They are passed under "messages", but the
# ("placeholder", "{messages}") entry in chat_history is commented out, so
# nothing in the prompt references that key at the moment.
conversation = [
    ("human", "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?"),
    ("ai", "Of course. In future responses I will always expand the BEIS acronym.")
]

# NOTE: `docs` is assigned in a later cell (the get_file_as_documents call),
# so that cell has to be run before this one.
# The chain is invoked once; the second, identical invocation was a
# copy-paste duplicate that only repeated the LLM call.
chain.invoke(
    input={
        "content": docs,
        "messages": conversation
    }
)

In [None]:
# Show the chain's serialised form. The previous bare `dir(chain)` was dropped:
# it was not the last expression in the cell, so its result was never displayed.
# chain.input_schema
chain.to_json()

In [None]:
# Fetch one stored file as Documents, split using max_tokens=1_000.
# NOTE: get_file_as_documents is defined in a later cell of this notebook,
# so that cell has to be run first.
docs = get_file_as_documents(
    file_uuid=UUID("d6cdd5a8-5ea5-4f7b-b09c-80d4bb64c266"),
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    max_tokens=1_000,
)

# Run the summarisation chain over those Documents.
res = chain.invoke({"content": docs})

print(res)

In [None]:
def get_summary_runnable(
    llm: ChatLiteLLM,
    init_messages: list[tuple[str, str]] | None = None,
) -> Runnable:
    """Build a runnable that retrieves documents and answers with sources.

    The runnable expects an input mapping with a ``"report_request"`` key.
    That value is sent to the retriever; the retrieved documents are added to
    the input twice, as ``"documents"`` (concatenated text from
    ``format_docs``) and as ``"sources"`` (the raw retriever output).

    Args:
        llm: Chat model that receives the rendered prompt.
        init_messages: ``(role, template)`` pairs for the prompt. Defaults to
            the core Redbox system prompt followed by the with-sources human
            template.

    Returns:
        A runnable producing ``{"response": <llm output>, "sources": <retrieved docs>}``.
    """
    # NOTE(review): `_with_sources_template` and `retriever` are not defined or
    # imported anywhere in the visible notebook; calling this function raises
    # NameError until they are provided. TODO confirm their intended origin.
    if init_messages is None:
        init_messages = [
            ("system", _core_redbox_prompt),
            ("human", _with_sources_template),
        ]

    prompt = ChatPromptTemplate.from_messages(init_messages)

    # Pull the request text out of the input mapping and feed it to the retriever.
    context = itemgetter("report_request") | retriever

    # Keep the original input keys and add the formatted text plus raw sources.
    setup = RunnablePassthrough.assign(documents=context | format_docs, sources=context)

    runnable = setup | {
        "response": prompt | llm,
        "sources": itemgetter("sources"),
    }

    return runnable

In [None]:
from langchain.schema import Document
from redbox.models.file import Metadata
from functools import reduce
from redbox.storage import ElasticsearchStorageHandler

# Storage handler built on the shared Elasticsearch client and the root index
# from settings; used as the default `storage_handler` argument by the
# functions defined below.
storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

def get_file_as_documents(
    file_uuid: UUID,
    user_uuid: UUID,
    storage_handler: ElasticsearchStorageHandler = storage_handler,
    max_tokens: int | None = None
) -> list[Document]:
    """Load a file's chunks and regroup them into LangChain Documents.

    Chunks are fetched through ``storage_handler``, ordered by their ``index``
    and joined with single spaces. A new Document is started whenever adding
    the next chunk would bring the running token count to ``max_tokens`` or
    beyond.

    Args:
        file_uuid: UUID of the parent file whose chunks are fetched.
        user_uuid: UUID of the user, passed through to the storage handler.
        storage_handler: Handler used to fetch the chunks; defaults to the
            module-level ``storage_handler``.
        max_tokens: Token budget per Document. ``None`` (or ``0``) disables
            splitting, so all chunks end up in a single Document.

    Returns:
        Documents in chunk order, or an empty list if the file has no chunks.
        A chunk that reaches ``max_tokens`` on its own becomes a Document by
        itself.
    """
    chunks = sorted(
        storage_handler.get_file_chunks(parent_file_uuid=file_uuid, user_uuid=user_uuid),
        key=lambda chunk: chunk.index,
    )
    # A falsy max_tokens means "no limit".
    limit = max_tokens or float("inf")

    documents: list[Document] = []
    page_content: list[str] = []
    metadata: list[Metadata | None] = []
    token_count: int = 0

    def flush() -> None:
        """Emit the buffered chunks as one Document and reset the buffers."""
        # Nothing buffered: this happens when the very first chunk already
        # reaches the limit. Skipping avoids reduce() raising TypeError on an
        # empty sequence and avoids emitting an empty Document.
        if not page_content:
            return
        documents.append(
            Document(
                page_content=" ".join(page_content),
                metadata=reduce(Metadata.merge, metadata),
            )
        )
        page_content.clear()
        metadata.clear()

    for chunk in chunks:
        # Close the current Document before this chunk would reach the limit.
        if token_count + chunk.token_count >= limit:
            flush()
            token_count = 0

        page_content.append(chunk.text)
        metadata.append(chunk.metadata)
        token_count += chunk.token_count

    # Emit whatever is left after the last chunk.
    flush()

    return documents

# Smoke test: max_tokens is omitted, so the token limit is infinite and the
# chunks are not split across Documents.
get_file_as_documents(file_uuid=UUID("d6cdd5a8-5ea5-4f7b-b09c-80d4bb64c266"), user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"))

In [None]:
from elasticsearch.helpers import scan


# Minimal example payload: a system message followed by one user question.
chat_request_body = {
    "message_history": [
        {"text": "You are a helpful AI Assistant", "role": "system"},
        {"text": "What is AI?", "role": "user"},
    ]
}

# Build the ChatRequest model from the example payload.
chat_request = ChatRequest(**chat_request_body)

def summarise(
    chat_request: ChatRequest,
    file_uuid: UUID,
    user_uuid: UUID = UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    llm: ChatLiteLLM = llm,
    storage_handler: ElasticsearchStorageHandler = storage_handler,
) -> ChatResponse:
    """Unfinished stub for summarising a stored file.

    Currently it only fetches the file's Documents. ``chat_request`` and
    ``llm`` are not used yet, and there is no ``return`` statement, so the
    function returns ``None`` despite the ``ChatResponse`` annotation.

    Args:
        chat_request: Chat request to summarise against (unused so far).
        file_uuid: UUID of the file to summarise.
        user_uuid: UUID of the owning user; defaults to a fixed test UUID.
        llm: Chat model intended for the summary (unused so far).
        storage_handler: Handler used to fetch the file's chunks.
    """
    # Fetch the whole file as Documents; no max_tokens is passed, so no
    # splitting is requested.
    document = get_file_as_documents(file_uuid=file_uuid, user_uuid=user_uuid, storage_handler=storage_handler)
    # TODO: stuff raw vanilla prompt

    # TODO: return

# Run the stub on the example request. summarise() has no return statement
# yet, so x is None.
x = summarise(
    chat_request=chat_request,
    file_uuid=UUID("d6cdd5a8-5ea5-4f7b-b09c-80d4bb64c266")
)

In [None]:
# Display the value returned by summarise() in the previous cell.
x