In [None]:
# IPython autoreload, mode 2: re-import changed modules (e.g. redbox) before each
# cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os

import dotenv
from langchain.chat_models import ChatAnthropic
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.output_parsers import RetryWithErrorOutputParser
from langchain.vectorstores import Chroma
from pyprojroot import here

from redbox.llm.llm_base import LLMHandler
from redbox.models.classification import Tag, TagGroup
from redbox.models.file import File

# Load secrets (e.g. ANTHROPIC_API_KEY) from the project-root .env file. The path is
# resolved via here() like the data paths below, rather than relative to the notebook's
# current working directory, so the notebook works whatever directory it is launched from.
ENV = dotenv.dotenv_values(os.path.join(here(), ".env"))

In [None]:
# Anthropic chat model, shared by the handler and by the manual classification cells below.
llm = ChatAnthropic(anthropic_api_key=ENV["ANTHROPIC_API_KEY"])

# All dev-data paths are rooted at the project root (here()), so they do not depend on
# the notebook's current working directory. Previously the vector-store path was the
# cwd-relative "../data/dev/db".
data_folder = os.path.join(here(), "data", "dev")
parsed_files_folder = os.path.join(data_folder, "file")
# os.listdir returns entries in arbitrary order. Sort them so that the positional
# indexing of parsed_files in later cells selects the same documents on every machine.
parsed_files = sorted(os.listdir(parsed_files_folder))
user_prefs_folder = os.path.join(data_folder, "user_preferences")

handler = LLMHandler(
    llm=llm,
    user_uuid="foo",
    vector_store=Chroma(
        embedding_function=SentenceTransformerEmbeddings(),
        persist_directory=os.path.join(data_folder, "db"),
    ),
)

# Tuning the classification

In [None]:
def _load_parsed_file(index):
    """Return the ``index``-th entry of ``parsed_files`` parsed into a ``File`` model."""
    path = os.path.join(parsed_files_folder, parsed_files[index])
    # JSON is UTF-8 by specification; don't rely on the platform default encoding.
    with open(path, encoding="utf-8") as f:
        return File(**json.load(f))


# The indices pick specific example documents out of the directory listing made in the
# set-up cell. Presumably they match the variable names; verify against the data folder.
email = _load_parsed_file(0)
speech = _load_parsed_file(1)
submission = _load_parsed_file(4)
minutes = _load_parsed_file(2)

In [None]:
# Build one TagGroup per JSON file in the user-preferences folder. The listing is sorted
# because os.listdir order is arbitrary, and later cells use user_taggroups[0].
user_taggroups = []
for user_pref in sorted(os.listdir(user_prefs_folder)):
    with open(os.path.join(user_prefs_folder, user_pref), encoding="utf-8") as f:
        user_taggroups.append(TagGroup(**json.load(f)))

In [None]:
# Classify the meeting minutes against the first tag group using the handler's helper.
# The call is left as the cell's final expression so the notebook displays its result.
handler.classify_to_tag(
    raw_text=minutes.text,
    group=user_taggroups[0],
)

In [None]:
# Manual walk-through of the classification step: prompt the LLM once, then fall back to
# LangChain's RetryWithErrorOutputParser if the reply cannot be parsed.
group = user_taggroups[0]
raw_text = minutes.text
attempt_count_max = 5  # maximum number of retry-parser attempts after the first failure

parser = group.get_parser()
prompt = group.get_classification_prompt_template()

input_prompt = prompt.format_prompt(raw_text=raw_text)

# to_messages() turns the formatted prompt into chat messages for the chat model.
# The previous code built a HumanMessage by hand, but HumanMessage is never imported here.
# The call stays outside the try, so only parse errors below are treated as retryable.
output = llm(input_prompt.to_messages())

try:
    detected_class = parser.parse(output.content)
except ValueError as parse_error:
    print(
        f"Encountered error with first metadata extraction attempt: {str(parse_error)}"
    )

    retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=llm)

    detected_class = None
    attempt_count = 0  # retries made so far (excludes the initial, failed attempt)

    while attempt_count < attempt_count_max:
        attempt_count += 1
        try:
            # Asks the LLM to correct the malformed completion, given the original prompt.
            detected_class = retry_parser.parse_with_prompt(
                completion=output.content, prompt_value=input_prompt
            )
            break
        except ValueError as parse_retry_error:
            print(f"Failed to rectify malformed data object: {str(parse_retry_error)}")

    # "+ 1" counts the initial attempt in the total.
    if detected_class is None:
        # Fail loudly here instead of with an AttributeError on None.letter below.
        raise RuntimeError(
            f"Failed extraction with {attempt_count + 1} attempt(s)"
        ) from parse_error
    print(f"Successful extraction with {attempt_count + 1} attempt(s)")

my_tag = group.get_tag(detected_class.letter)

Tag(letter=my_tag.letter, description=my_tag.description)