In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

import dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatAnthropic
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from pyprojroot import here

import redbox.llm.spotlight.spotlight as spotlight_formats
from redbox.llm.llm_base import LLMHandler
from redbox.llm.prompts.spotlight import SPOTLIGHT_COMBINATION_TASK_PROMPT
from redbox.models.file import File
from redbox.models.spotlight import Spotlight

# Key/value pairs parsed from the .env file one directory above the working directory.
# NOTE(review): the relative path depends on the kernel's current working
# directory — confirm the notebook is always launched from its own folder.
ENV = dotenv.dotenv_values("../.env")

# Sample user profile; passed as `user_info` when the chains are run further down.
test_user_info = {
    "name": "Liam Wilkinson",
    "email": "liam.wilkinson@cabinetoffice.gov.uk",
    "department": "Cabinet Office",
    "role": "Economic Policy",
    "preferred_language": "British English",
}

In [None]:
# Display the combination prompt via its `.json()` method (inspection only).
SPOTLIGHT_COMBINATION_TASK_PROMPT.json()

In [None]:
# Chat model client; the API key is read from the parsed .env mapping.
llm = ChatAnthropic(anthropic_api_key=ENV["ANTHROPIC_API_KEY"])

# Chroma vector store persisted under ../data/dev/db, using a
# SentenceTransformerEmbeddings instance built with its default arguments.
embedding_function = SentenceTransformerEmbeddings()
vector_store = Chroma(
    embedding_function=embedding_function,
    persist_directory="../data/dev/db",
)

# Project LLM handler wired to the model and vector store above.
handler = LLMHandler(llm=llm, user_uuid="foo", vector_store=vector_store)

In [None]:
# Load every File record stored as JSON under <project root>/data/dev/file.
# The directory is resolved from the project root via `here()` (as the
# spotlight folder lookup later in this notebook does) instead of a
# machine-specific absolute path, so the cell runs on any checkout.
file_dir = Path(here()) / "data" / "dev" / "file"

# `glob` already yields full paths, so they can be opened directly; sorting
# makes the load order reproducible between runs.
json_list = sorted(file_dir.glob("**/*.json"))

# JSON is read explicitly as UTF-8 rather than the platform default encoding.
files = [
    File(**json.loads(json_path.read_text(encoding="utf-8")))
    for json_path in json_list
]

In [None]:
spotlight.tasks[0].model_dump()

In [None]:
spotlight = Spotlight(
    files=files,
    file_hash="abc",
    formats=[
        spotlight_formats.email_format,
        spotlight_formats.meeting_format,
        spotlight_formats.briefing_format,
        spotlight_formats.proposal_format,
        spotlight_formats.other_format,
    ],
)
spotlight

In [None]:
# Run a task through a StuffDocumentsChain, feeding the spotlight documents
# into the prompt's "text" variable. The trailing `break` stops after the
# first task, so only one LLM call is made.
for task in spotlight.tasks:
    print(task.id)

    llm_chain = LLMChain(llm=handler.llm, prompt=task.prompt_template)
    chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="text",
    )

    # Inputs are assembled in the same order they were previously passed.
    run_inputs = {
        "user_info": test_user_info,
        "current_datetime": datetime.now().isoformat(),
        "input_documents": spotlight.to_documents(),
    }
    res = chain.run(**run_inputs)

    print(res)
    break

In [None]:
class MyHandler(BaseCallbackHandler):
    """Callback handler that counts `on_chain_end` invocations.

    Args:
        count_element: Starting value of the counter. Defaults to 0 so the
            handler can be constructed without arguments, as this notebook does.
    """

    def __init__(self, count_element: int = 0) -> None:
        self.count_element = count_element
        # NOTE(review): always 0 regardless of `count_element` and not read in
        # the visible cells — confirm whether it should record the start value.
        self.initial_count = 0

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Increment the counter by one each time this callback is invoked.

        `outputs` and `kwargs` are accepted to match the callback signature
        but are not used.
        """
        self.count_element += 1


# Counting callback for the map-reduce run. It gets its own name so the
# LLMHandler bound to `handler` in an earlier cell is not overwritten (that
# cell's consumers read `handler.llm`), and the start value is passed
# explicitly because `MyHandler.__init__` declares `count_element`.
callback_handler = MyHandler(count_element=0)

# Map-reduce summarisation: each document goes through the task prompt (map),
# then the partial results are combined with the combination prompt (reduce).
# The trailing `break` stops after the first task.
for task in spotlight.tasks:
    print(task.id)
    map_chain = LLMChain(llm=llm, prompt=task.prompt_template)
    reduce_chain = LLMChain(llm=llm, prompt=SPOTLIGHT_COMBINATION_TASK_PROMPT)

    # The same stuff-chain is used both to collapse intermediate results and
    # to produce the final combined output.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="text"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=100000,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="text",
        return_intermediate_steps=False,
    )
    res = map_reduce_chain.run(
        user_info=test_user_info,
        current_datetime=datetime.now().isoformat(),
        input_documents=spotlight.to_documents(),
        callbacks=[callback_handler],
    )
    print(res)
    break

In [None]:
# Exploratory: list the attribute names on the chain's `Config`.
dir(map_reduce_chain.Config)

In [None]:
# Exploratory: list the attribute names available on the map-reduce chain.
dir(map_reduce_chain)

In [None]:
# Exploratory: display the chain via its `.json()` method.
# NOTE(review): presumably the pydantic JSON serialisation — confirm.
map_reduce_chain.json()

In [None]:
# Locate <project root>/data/dev/spotlight and index the JSON files found
# beneath it (recursively) by file stem.
data_folder = os.path.join(here(), "data", "dev")
spotlight_data_folder = os.path.join(data_folder, "spotlight")
spotlight_dir = Path(spotlight_data_folder)

# When two files share a stem, the one visited last wins.
spotlight_completed = {
    json_path.stem: json_path for json_path in spotlight_dir.glob("**/*.json")
}

spotlight_completed