In [23]:
%load_ext autoreload
%autoreload 2

In [12]:
import logging
import os
from http import HTTPStatus
from typing import Annotated
from uuid import UUID

from fastapi import Depends, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch

from core_api.src.auth import get_user_uuid
from redbox.llm.prompts.chat import (
    CONDENSE_QUESTION_PROMPT,
    STUFF_DOCUMENT_PROMPT,
    WITH_SOURCES_PROMPT,
)
from redbox.model_db import MODEL_PATH
from redbox.models import EmbeddingModelInfo, Settings
from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument
from redbox.storage import ElasticsearchStorageHandler

env = Settings(_env_file="../.env")
env.elastic.host = "localhost"
env.minio_host = "localhost"

embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder="../models/")

es = Elasticsearch(
    hosts=[
        {
            "host": "localhost",
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)

vector_store = ElasticsearchStore(
    es_connection=es,
    index_name="redbox-data-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_version=env.openai_api_version,
    api_base=env.azure_openai_endpoint,
    max_tokens=4_096,
)

storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps


# Summarisation scratch

In [28]:
from langchain.schema import Document
from redbox.models.file import Metadata
from functools import reduce
from langchain_core.runnables import RunnableLambda, Runnable, RunnablePassthrough
from langchain.schema import StrOutputParser
from redbox.llm.prompts.core import _core_redbox_prompt
from operator import itemgetter

[Using this for LCEL map reduce inspo.](https://qubitpi.github.io/langchain/docs/modules/chains/document/map_reduce/)

In [33]:
from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.runnables import make_stuff_document_runnable
from langchain.prompts import PromptTemplate

chunks = get_file_chunked_to_tokens(
    file_uuid=UUID("73d9d2e5-b3c2-459d-a69e-5688fb2d122f"),
    user_uuid=UUID("d48edb11-0f05-4d86-86fb-c91628ca2f28"), 
    storage_handler=storage_handler, 
    max_tokens=1_000
)

def make_map_reduce_runnable(
    system_prompt: str,
    llm: ChatLiteLLM,
) -> Runnable:
    """Takes a system prompt and LLM returns a map reduce runnable.

    Runnable takes input of a dict keyed to question, messages and documents.
    """
    map_runnable = make_stuff_document_runnable(
        system_prompt=system_prompt,
        llm=llm
    )
    map_chain = (
        {
            "question": itemgetter("question"),
            "messages": itemgetter("messages"),
            "documents": itemgetter("documents"),
        }
        | map_runnable
    )

    # def format_map_chain(summaries: list[dict[str, str]]):
    #     return "\n\n".join(summary["response"] for summary in summaries)

    # reduce_chain = (
    #     {"context": format_map_chain}
    #     | PromptTemplate.from_template("Combine these summaries: \n\n {context}")
    #     | llm
    #     | StrOutputParser()
    # )

    map_reduce_chain = (
        map_runnable.map()
    )

    # return (
    #     RunnablePassthrough()
    #     | {
    #         "response": map_chain.map() | reduce_chain,
    #         "sources": itemgetter("documents")
    #     }
    # )
    return map_reduce_chain


INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search?scroll=5m [status:200 duration:0.008s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.002s]
INFO:elastic_transport.transport:DELETE http://localhost:9200/_search/scroll [status:200 duration:0.001s]


In [34]:
chain = make_map_reduce_runnable(
    system_prompt="Your job is summarisation.",
    llm=llm
)

chain.invoke(
    input={
        "question": "Tell me some interesting questions I could ask about this document.",
        "documents": chunks,
        "messages": [],
    }
)

TypeError: string indices must be integers, not 'str'

In [25]:
from core_api.src.runnables import make_stuff_document_runnable

chain = make_stuff_document_runnable(
    system_prompt="Your job is summarisation.",
    llm=llm
)

chain.invoke(
    input={
        "question": "Tell me some interesting questions I could ask about this document.",
        "documents": chunks[:3],
        "messages": [],
    }
)

[92m09:05:43 - LiteLLM:INFO[0m: utils.py:1307 - [92m

POST Request Sent from LiteLLM:
curl -X POST \
https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/ \
-H 'Authorization: Bearer a87792838865********************' \
-d '{'model': 'gpt-4', 'messages': [{'role': 'system', 'content': 'Your job is summarisation.'}, {'role': 'user', 'content': 'Question: Tell me some interesting questions I could ask about this document.. \n\n Documents: \n\n <Doc73d9d2e5-b3c2-459d-a69e-5688fb2d122f>\n 03/10/2023, 19:52\n\nCabinet Ofﬁce - WikipediaCabinet Office\n\nCoordinates: 51°30′13″N 0°7′36″W\n\nThe Cabinet Office is a department of the UK Government responsible for supporting the prime minister and Cabinet.[3] It is composed of various units that support Cabinet committees and coordinate the delivery of government objectives via other departments. As of December 2021, it had over 10,200 staff, mostly civil servants, some of whom work in Whitehall. Staff working in th

{'response': "1. What are the core functions of the UK's Cabinet Office?\n2. Who were the responsible ministers for the Cabinet Office as of March 2023?\n3. What are some historical highlights about the Cabinet Office?\n4. How many staff members are employed by the Cabinet Office?\n5. How does the Cabinet Office support the Prime Minister and the Cabinet?\n6. When was the Cabinet Office formed and what was its preceding department?\n7. Can you explain the key roles and responsibilities of each Minister responsible for the Cabinet Office?\n8. What are the responsibilities of the Cabinet Office at the UK national level?\n9. Can you share any details on the architecture and history of the Cabinet Office's building at 70 Whitehall?\n10. Who are the Cabinet Office senior civil servants and what are their positions as of March 2023?\n11. What is the role of the Parliamentary Private Secretary to the Cabinet Office? \n12. How do Cabinet committees support the Cabinet Office?",
 'sources': [Ch

In [20]:
chain = make_stuff_document_runnable(
    system_prompt="Your job is summarisation.",
    llm=llm
)

chain.map().invoke(
    input=[
        {
            "question": "Tell me some interesting questions I could ask about this document.",
            "documents": [chunk],
            "messages": [],
        }
        for chunk in chunks
    ]
)

[92m08:42:57 - LiteLLM:INFO[0m: utils.py:1307 - [92m

POST Request Sent from LiteLLM:
curl -X POST \
https://oai-i-dot-ai-playground-sweden.openai.azure.com//openai/deployments/gpt-4/ \
-H 'Authorization: Bearer a87792838865********************' \
-d '{'model': 'gpt-4', 'messages': [{'role': 'system', 'content': 'Your job is summarisation.'}, {'role': 'user', 'content': 'Question: Tell me some interesting questions I could ask about this document.. \n\n Documents: \n\n <Doc73d9d2e5-b3c2-459d-a69e-5688fb2d122f>\n Management and delivery of the Government\'s legislative programme (through the House of Lords) and facilitating the passage of individual bills; Leading the House (in the Chamber and as a key member of domestic committees to do with procedure, conduct, and the internal governance of\n\n3/6\n\n03/10/2023, 19:52\n\nCabinet Ofﬁce - Wikipediathe House); Issues connected to the House of Lords and its governance; Speaking for the Government in the Chamber on a range of issues, in

["1. What is the main purpose of the Cabinet Office in the UK government?\n2. What are the core functions of the Cabinet Office?\n3. What is the history and formation details of the Cabinet Office?\n4. How many employees does the Cabinet Office have as of December 2021?\n5. What is the Cabinet Office's role in UK's cyber security and crisis response?\n6. How does the Cabinet Office contribute to efficiency and reform across other government departments?\n7. Who are the key ministers and executives in the Cabinet Office?\n8. What are the miscellaneous units within the Cabinet Office and what are their roles?\n9. What are the responsibilities of the Cabinet Office at the UK national level?\n10. How does the Cabinet Office influence policy based on government priorities? \n11. What is the Cabinet Office's role in government procurement policy?\n12. How does the Cabinet Office Support the National Security Council and the Joint Intelligence Organisation? \n13. What is the role of the Cabin

In [5]:

def get_file_as_documents(
    file_uuid: UUID,
    user_uuid: UUID,
    storage_handler: ElasticsearchStorageHandler = storage_handler,
    max_tokens: int | None = None
) -> list[Document]:
    """Gets a file as LangChain Documents, splitting it by max_tokens."""
    documents: list[Document] = []
    chunks_unsorted = storage_handler.get_file_chunks(parent_file_uuid=file_uuid, user_uuid=user_uuid)
    chunks = sorted(chunks_unsorted, key=lambda x: x.index)

    total_tokens = sum(chunk.token_count for chunk in chunks)
    print(total_tokens)

    token_count: int = 0
    n = max_tokens or float("inf")
    page_content: list[str] = []
    metadata: list[Metadata | None] = []

    for chunk in chunks:
        if token_count + chunk.token_count >= n:
            document = Document(
                page_content=" ".join(page_content),
                metadata=reduce(Metadata.merge, metadata),
            )
            documents.append(document)
            token_count = 0
            page_content = []
            metadata = []

        page_content.append(chunk.text)
        metadata.append(chunk.metadata)
        token_count += chunk.token_count

    if len(page_content) > 0:
        document = Document(
            page_content=" ".join(page_content),
            metadata=reduce(Metadata.merge, metadata),
        )
        documents.append(document)

    return documents


In [None]:
def summarise(
    chat_request: ChatRequest,
    file_uuid: UUID,
    user_uuid: UUID,
    llm: ChatLiteLLM,
    storage_handler: ElasticsearchStorageHandler,
) -> ChatResponse:
    question = chat_request.message_history[-1].text
    previous_history = list(chat_request.message_history[:-1])
    
    # get full doc from vector store
    documents = get_file_as_documents(
        file_uuid=file_uuid, 
        user_uuid=user_uuid, 
        storage_handler=storage_handler,
        max_tokens=20_000
    )
    
    # right now, can only handle a single document so we manually truncate
    document = documents[:1]
    if len(documents) > 1:
        print("Document was longer than 20k tokens. Truncating to the first 20k.")
    
    # stuff raw prompt
    chat_history = [
        ("system", _core_redbox_prompt),
        ("placeholder", "{messages}"),
        ("user", "Question: {question}. \n\n Content: \n\n<document> {content} </document> \n\n Answer: "),
    ]

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    chain = (
        {
            "question": itemgetter("question"),
            "messages": itemgetter("messages"),
            "content": itemgetter("content") | RunnableLambda(format_docs),
        }
        | ChatPromptTemplate.from_messages(chat_history)
        | llm
        | StrOutputParser()
    )

    # return
    return chain.invoke(
        input={
            "question": question,
            "content": document,
            "messages": [(msg.role, msg.text) for msg in previous_history]
        }
    )


chat_request_body = {
    "message_history": [
        # {"text": "Can you always refer to BEIS as the Department for Business, Energy and Industrial Strategy from now on?", "role": "user"},
        # {"text": "Of course. In future responses I will always expand the BEIS acronym.", "role": "ai"},
        {"text": "Please summarise all the key people in this document and who they work for.", "role": "user"},
    ]
}

res = summarise(
    chat_request=ChatRequest(**chat_request_body),
    file_uuid=UUID("35b3d95f-7f65-4cae-b159-22001ca19c88"),
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    llm=llm,
    storage_handler=storage_handler
)

print(res)

In [19]:
from redbox.models.file import Metadata, Chunk
from functools import partial, reduce
from uuid import UUID

chunks_unsorted = storage_handler.get_file_chunks(
    parent_file_uuid=UUID("35b3d95f-7f65-4cae-b159-22001ca19c88"), 
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5")
)
chunks = sorted(chunks_unsorted, key=lambda x: x.index)

def reduce_chunks_by_tokens(chunks: list[Chunk] | None, chunk: Chunk, max_tokens: int) -> list[Chunk]:
    """"""
    if not chunks:
        return [chunk]
    
    last_chunk = chunks[-1]

    if chunk.token_count + last_chunk.token_count <= max_tokens:
        chunks[-1] = Chunk(
            parent_file_uuid=last_chunk.parent_file_uuid,
            index=last_chunk.index,
            text=last_chunk.text + chunk.text,
            metadata=Metadata.merge(last_chunk.metadata, chunk.metadata),
            creator_user_uuid=last_chunk.creator_user_uuid,
        )
    else:
        chunk.index = last_chunk.index + 1
        chunks.append(chunk)
    
    return chunks

reduce_chunk_t300 = partial(reduce_chunks_by_tokens, max_tokens=300)

result = reduce(lambda cs, c: reduce_chunk_t300(cs, c), chunks, [])

INFO:elastic_transport.transport:POST http://localhost:9200/redbox-data-chunk/_search?scroll=5m [status:200 duration:0.089s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.102s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.008s]
INFO:elastic_transport.transport:POST http://localhost:9200/_search/scroll [status:200 duration:0.003s]
INFO:elastic_transport.transport:DELETE http://localhost:9200/_search/scroll [status:200 duration:0.002s]


In [16]:
import numpy as np
len(chunks), max(chunk.token_count for chunk in chunks), min(chunk.token_count for chunk in chunks), np.mean([chunk.token_count for chunk in chunks])

(2066, 301, 1, 84.15295256534365)

In [21]:
import numpy as np
len(result), max(chunk.token_count for chunk in result), min(chunk.token_count for chunk in result), np.mean([chunk.token_count for chunk in result])

(685, 301, 111, 253.76058394160583)

In [20]:
[chunk.index for chunk in result]

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
