In [None]:
# IPython autoreload: re-import edited modules (e.g. the local `redbox` package) before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import functools
import itertools
import json
import os
import re
import string
from enum import Enum
from typing import List, Optional

import dotenv
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatAnthropic
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.output_parsers import (
    CommaSeparatedListOutputParser,
    PydanticOutputParser,
    RetryWithErrorOutputParser,
    XMLOutputParser,
)
from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.schema import HumanMessage, SystemMessage
from langchain.vectorstores import Chroma
from pydantic import (
    BaseModel,
    Field,
    ValidationError,
    computed_field,
    create_model,
    field_validator,
)
from pyprojroot import here

from redbox.llm.llm_base import LLMHandler
from redbox.models.classification import Tag, TagGroup
from redbox.models.file import File

# Key/value pairs from the project .env file (path is relative to the current working directory).
ENV = dotenv.dotenv_values(dotenv_path="../.env")

In [None]:
def alphabet_iterator(alpha=string.ascii_lowercase):
    """Yield the symbols of ``alpha`` forever, wrapping back to the start.

    Args:
        alpha: Sequence of symbols to cycle through; defaults to ``a``-``z``.

    Yields:
        One symbol at a time, in order, restarting after the last one.
        Yields nothing at all when ``alpha`` is empty.
    """
    # itertools.cycle is the standard-library version of the manual index-and-wrap loop.
    yield from itertools.cycle(alpha)

In [None]:
# Endless generator over the lowercase letters a..z.
alpha = alphabet_iterator(alpha=string.ascii_lowercase)

In [None]:
# Advance the generator by one letter. Generator objects are not callable,
# so `alpha()` would raise TypeError.
next(alpha)

In [None]:
# Rebind `alpha` to a plain list of the 26 lowercase letters; index 25 is the last one.
alpha = [*string.ascii_lowercase]
len(alpha)
alpha[25]

In [None]:
# Chat model used by every LLM call in this notebook; the key comes from ENV.
llm = ChatAnthropic(anthropic_api_key=ENV["ANTHROPIC_API_KEY"])

# NOTE(review): persist_directory is relative to the current working directory,
# unlike data_folder below, which is anchored at the project root via here().
handler = LLMHandler(
    user_uuid="foo",
    llm=llm,
    vector_store=Chroma(
        persist_directory="../data/dev/db",
        embedding_function=SentenceTransformerEmbeddings(),
    ),
)

In [None]:
# Parsed sample documents live under <project root>/data/dev/file.
data_folder = os.path.join(here(), "data", "dev")
parsed_files_folder = os.path.join(here(), "data", "dev", "file")
# NOTE(review): os.listdir returns entries in arbitrary, filesystem-dependent
# order, so positional picks such as parsed_files[4] below are not portable.
parsed_files = os.listdir(parsed_files_folder)
parsed_files

In [None]:
# Raw text of four sample documents, chosen by position in `parsed_files`.
# These names are rebound to `File` objects in the "Generalising" section.
with open(os.path.join(parsed_files_folder, parsed_files[0])) as fh:
    email = json.load(fh)["text"]

with open(os.path.join(parsed_files_folder, parsed_files[1])) as fh:
    speech = json.load(fh)["text"]

with open(os.path.join(parsed_files_folder, parsed_files[4])) as fh:
    submission = json.load(fh)["text"]

with open(os.path.join(parsed_files_folder, parsed_files[2])) as fh:
    minutes = json.load(fh)["text"]

In [None]:
# Every character that is not an ASCII letter (used to strip punctuation/space).
alpha_regex = re.compile(r"[^a-zA-Z]")
# Text between <answer> and </answer>, non-greedy; `.` does not cross newlines.
answer_regex = re.compile(r"(?<=<answer>)(.*?)(?=<\/answer>)")

# Generalising

Going to try to build an object so that anyone can make custom layers.

In [None]:
# Same four sample documents as above, now parsed into full `File` models.
with open(os.path.join(parsed_files_folder, parsed_files[0])) as fh:
    email = File(**json.load(fh))

with open(os.path.join(parsed_files_folder, parsed_files[1])) as fh:
    speech = File(**json.load(fh))

with open(os.path.join(parsed_files_folder, parsed_files[4])) as fh:
    submission = File(**json.load(fh))

with open(os.path.join(parsed_files_folder, parsed_files[2])) as fh:
    minutes = File(**json.load(fh))

In [None]:
# Display the `File` loaded as `minutes` above for a quick visual check.
minutes

In [None]:
# Tag specifications for the document "format" layer. Each entry's `examples`
# are the few-shot documents for THAT category, so each must be the matching
# File (minutes -> minutes, submission -> submission).
doc_category_dict = {
    "email": {
        "letter": "a",
        "description": "Email, letters and correspondance",
        "examples": [email],
    },
    "speech": {"letter": "b", "description": "Speech", "examples": [speech]},
    "minutes": {"letter": "c", "description": "Meeting minutes", "examples": [minutes]},
    "submission": {
        "letter": "d",
        "description": "Submission or proposal",
        "examples": [submission],
    },
    "other": {"letter": "e", "description": "Other, including documents"},
}

# NOTE: `format` shadows the builtin; kept because later cells refer to it.
format = TagGroup(name="format", tags=[Tag(**spec) for spec in doc_category_dict.values()])

In [None]:
# Smoke test of the handler's tag classifier on a short email-like sentence.
# Here `format` is built from the Tag/TagGroup imported from redbox; the local
# prototype classes are defined in a later cell.
handler.classify_to_tag(
    raw_text="Dear Mr Sunak, please send the whole Civil Service LLMs for Christmas. Love, Will",
    group=format,
)

In [None]:
# Few-shot example layout. It is written flush-left, like the prefix and suffix
# below, so no stray code indentation or trailing spaces leak into the prompt.
# The braces are quadrupled because FewShotPromptTemplate formats each example
# and then formats the joined prompt again: `{{{{` -> `{{` -> literal `{`.
_doc_cat_example_template = PromptTemplate(
    input_variables=["document", "answer"],
    template=(
        "Document:\n"
        "<document>\n"
        "{document}\n"
        "</document>\n"
        "Assistant: My answer is {{{{'letter': '{answer}'}}}}"
    ),
)

# Prompt header: the model's role, the input format, and the lettered category
# list (filled from `layer_list_items`).
_doc_cat_prefix_template = """\
You are a customer service agent that is classifying documents. \
The document is wrapped in <document></document> XML tags.

Categories are:

{layer_list_items}\
"""

# Prompt footer: the document to classify, the parser's format instructions,
# and an opening `{'letter': '` (presumably to prime a bare-letter completion).
_get_doc_subject_suffix = """\

Here is the document, wrapped in <document></document> XML tags
<document>
{raw_text}
</document>

{format_instructions} \

A: My answer is {{'letter': '\
"""


class Tag(BaseModel):
    """One selectable option in a classification layer.

    Attributes:
        letter: Option letter shown to the LLM; normalised to upper case.
        description: Human-readable label shown next to the letter.
        examples: Optional few-shot example documents for this option.
    """

    letter: str
    description: str
    examples: Optional[List[File]] = None

    @field_validator("letter")
    @classmethod
    def letter_to_upper(cls, v: str) -> str:
        """Normalise the option letter to upper case."""
        return v.upper()

    # Snake_case identifier derived from `description`: spaces become
    # underscores, the text is lower-cased, and anything outside [a-zA-Z_] is
    # dropped. (A `#` comment rather than a docstring, because pydantic copies
    # a computed field's docstring into the JSON schema.)
    @computed_field
    def var(self) -> str:
        space_to_score = self.description.replace(" ", "_").lower()
        return re.sub(r"[^a-zA-Z_]", "", space_to_score)

    def get_examples(self):
        """Return few-shot example dicts with ``document`` and ``answer`` keys.

        Returns an empty list when the tag has no examples, rather than
        failing on iteration over ``None``.
        """
        return [
            {"document": example.text, "answer": self.letter}
            for example in (self.examples or [])
        ]

    def get_list_item(self):
        """Render the tag as a prompt list item, e.g. ``(B) Speech``."""
        return f"({self.letter}) {self.description}"


class TagGroup(BaseModel):
    """A named classification layer: a set of ``Tag`` options, one of which is chosen.

    Supplies the pieces needed to classify a document into one of its tags: a
    few-shot prompt template, a letter-validating model, and an output parser.
    """

    name: str
    tags: List[Tag]

    def get_examples(self):
        """Return the few-shot example dicts of every tag that has examples, in tag order."""
        examples = []
        for tag in self.tags:
            if tag.examples is not None:
                examples.extend(tag.get_examples())
        return examples

    def get_list_items(self):
        """Render every tag as a prompt list line; each line ends with a space and a newline."""
        return "".join(f"{tag.get_list_item()} \n" for tag in self.tags)

    def get_letters(self):
        """Return the option letters of all tags, in tag order."""
        return [tag.letter for tag in self.tags]

    def make_validator(self):
        """Build a pydantic model whose only field, ``letter``, must be one of this group's letters.

        The model is named after the group. Its schema drives the output
        parser's format instructions, and the validation error lists the
        valid options.
        """
        description = (
            "Must be a single uppercase letter of the alphabet corresponding "
            "to one of the following: \n\n"
            f"{self.get_list_items()}"
        )

        def letter_validator(cls, v):
            # Raise instead of `assert`: assertions are stripped under
            # `python -O`, which would silently disable this check. Pydantic
            # reports a ValueError as a ValidationError, just as it does an
            # AssertionError.
            if v not in self.get_letters():
                raise ValueError(description)
            return v

        validators = {"letter_validator": field_validator("letter")(letter_validator)}

        return create_model(self.name, letter=(str, ...), __validators__=validators)

    def get_tag(self, letter):
        """Return the tag whose letter equals ``letter``.

        Raises:
            pydantic.ValidationError: If ``letter`` is not one of this group's letters.
        """
        # Instantiating the validator model only validates; the instance is discarded.
        self.make_validator()(letter=letter)
        return next((tag for tag in self.tags if tag.letter == letter), None)

    def get_parser(self):
        """Return a PydanticOutputParser that validates replies against ``make_validator()``."""
        return PydanticOutputParser(pydantic_object=self.make_validator())

    def get_classification_prompt_template(self, parser=None):
        """Build the few-shot classification prompt for this group.

        Args:
            parser: Output parser whose format instructions are embedded in the
                prompt. Defaults to ``self.get_parser()``.

        Returns:
            A FewShotPromptTemplate that takes a single ``raw_text`` variable.
        """
        if parser is None:
            parser = self.get_parser()

        # FewShotPromptTemplate formats each example, joins the results with
        # the prefix and suffix, and then formats the joined string again.
        # Example documents are therefore parsed as templates a second time,
        # so literal braces in them must be doubled, or any "{" / "}" in a
        # document breaks formatting.
        examples = [
            {
                **example,
                "document": example["document"].replace("{", "{{").replace("}", "}}"),
            }
            for example in self.get_examples()
        ]

        return FewShotPromptTemplate(
            examples=examples,
            example_prompt=_doc_cat_example_template,
            prefix=_doc_cat_prefix_template,
            suffix=_get_doc_subject_suffix,
            input_variables=["raw_text"],
            partial_variables={
                "layer_list_items": self.get_list_items(),
                "format_instructions": parser.get_format_instructions(),
            },
        )

In [None]:
# Same tag specifications, rebuilt with the local Tag/TagGroup prototypes.
# Each entry's `examples` must be the File for THAT category
# (minutes -> minutes, submission -> submission).
doc_category_dict = {
    "email": {
        "letter": "a",
        "description": "Email, letters and correspondance",
        "examples": [email],
    },
    "speech": {"letter": "b", "description": "Speech", "examples": [speech]},
    "minutes": {"letter": "c", "description": "Meeting minutes", "examples": [minutes]},
    "submission": {
        "letter": "d",
        "description": "Submission or proposal",
        "examples": [submission],
    },
    "other": {"letter": "e", "description": "Other, including documents"},
}

# NOTE: `format` shadows the builtin; kept because later cells refer to it.
format = TagGroup(name="format", tags=[Tag(**spec) for spec in doc_category_dict.values()])

In [None]:
# Build the parser once and hand the same instance to the prompt, so the
# format instructions shown to the model and the parser applied to its reply
# come from a single validator model.
parser = format.get_parser()
prompt = format.get_classification_prompt_template(parser=parser)

In [None]:
input_prompt = prompt.format_prompt(
    raw_text="Dear Santa, this is my email. Love William"
)
attempt_count_max = 5  # maximum number of retries after the first parse

# Call the model outside the try block. Only parse failures should take the
# retry path; an error from the call itself should surface unchanged.
output = llm([HumanMessage(content=input_prompt.text)])

# Pre-bind so the checks below never hit an unbound name when every attempt fails.
detected_class = None
attempt_count = 1  # the first parse below is attempt 1

try:
    detected_class = parser.parse(output.content)
except ValueError as parse_error:
    print(
        f"Encountered error with first metadata extraction attempt: {str(parse_error)}"
    )
    retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=llm)

    while attempt_count <= attempt_count_max:
        attempt_count += 1
        try:
            detected_class = retry_parser.parse_with_prompt(
                completion=output.content, prompt_value=input_prompt
            )
            break
        except ValueError as parse_retry_error:
            print(f"Failed to rectify malformed data object: {str(parse_retry_error)}")

    if detected_class is not None:
        print(f"Successful extraction with {attempt_count} attempt(s)")
    else:
        print(f"Failed extraction with {attempt_count} attempt(s)")

if detected_class is not None:
    my_tag = format.get_tag(detected_class.letter)
    print(my_tag.description)

# Doc format experiments

## Email

In [None]:
# Ask the redbox LLMHandler for the email's "format" category.
# Note: `email` is the `File` object loaded in the "Generalising" section.
handler.get_doc_category(email, type="format")

In [None]:
# Ask the redbox LLMHandler for the email's "subject" category.
handler.get_doc_category(email, type="subject")

In [None]:
# NOTE(review): GET_DOCTYPE_PROMPT is not defined anywhere in this notebook, so
# this cell raises NameError as written. `email` is a `File` at this point
# (see "Generalising"); confirm whether the prompt expects `email.text`.
# No parser is rebound here, so the Pydantic `parser` used by the retry cell
# above stays intact.
to_send = HumanMessage(
    content=GET_DOCTYPE_PROMPT.format_prompt(raw_text=email).to_string()
)

result = llm([to_send])

# Text between <answer> tags, and a filter keeping only letters and spaces.
answer_regex = re.compile(r"(?<=<answer>)(.*?)(?=<\/answer>)")
alpha_regex = re.compile("[^a-zA-Z ]")

all_answers = answer_regex.findall(result.content)

if not all_answers:
    raise ValueError(
        f"""
        No answer detected in response:
        {result.content}
    """
    )

# Answers are presumably of the form "(a) Email"; the next cell cleans the fragments.
out = all_answers[0].split(")")

In [None]:
# Trim each fragment, then keep only the characters alpha_regex does not strip.
out = [alpha_regex.sub("", fragment.strip()) for fragment in out]
out

In [None]:
# NOTE(review): DocType is only defined in the "As a function" section further
# down, so on a fresh top-to-bottom run this line raises NameError.
DocType(category=out[0])

## Speech

In [None]:
# Ask the redbox LLMHandler for the speech's "format" category.
handler.get_doc_category(speech, type="format")

In [None]:
# Ask the redbox LLMHandler for the speech's "subject" category.
handler.get_doc_category(speech, type="subject")

In [None]:
# NOTE(review): GET_DOCTYPE_PROMPT is not defined anywhere in this notebook, so
# this cell raises NameError as written.
to_send = HumanMessage(
    content=GET_DOCTYPE_PROMPT.format_prompt(raw_text=speech).to_string()
)

result = llm([to_send])

# `.` in answer_regex does not match newlines, so flatten the reply first
# (the same preprocessing get_doctype uses).
res = answer_regex.findall(result.content.replace("\n", ""))
if not res:
    raise ValueError(f"No answer detected in response:\n{result.content}")

# partition() tolerates a missing or repeated ")", where unpacking split()
# into exactly two names would raise.
letter, _, description = res[0].partition(")")
letter = alpha_regex.sub("", letter)
description = description.strip()

(letter, description)

## Submission

In [None]:
# Ask the redbox LLMHandler for the submission's "format" category.
handler.get_doc_category(submission, type="format")

In [None]:
# Ask the redbox LLMHandler for the submission's "subject" category.
handler.get_doc_category(submission, type="subject")

In [None]:
# NOTE(review): GET_DOCTYPE_PROMPT is not defined anywhere in this notebook, so
# this cell raises NameError as written.
to_send = HumanMessage(
    content=GET_DOCTYPE_PROMPT.format_prompt(raw_text=submission).to_string()
)

result = llm([to_send])

# `.` in answer_regex does not match newlines, so flatten the reply first
# (the same preprocessing get_doctype uses).
res = answer_regex.findall(result.content.replace("\n", ""))
if not res:
    raise ValueError(f"No answer detected in response:\n{result.content}")

# partition() tolerates a missing or repeated ")", where unpacking split()
# into exactly two names would raise.
letter, _, description = res[0].partition(")")
letter = alpha_regex.sub("", letter)
description = description.strip()

(letter, description)

## Minutes

In [None]:
# Ask the redbox LLMHandler for the minutes' "format" category.
handler.get_doc_category(minutes, type="format")

In [None]:
# Ask the redbox LLMHandler for the minutes' "subject" category.
handler.get_doc_category(minutes, type="subject")

In [None]:
# NOTE(review): GET_DOCTYPE_PROMPT is not defined anywhere in this notebook, so
# this cell raises NameError as written.
to_send = HumanMessage(
    content=GET_DOCTYPE_PROMPT.format_prompt(raw_text=minutes).to_string()
)

result = llm([to_send])

# `.` in answer_regex does not match newlines, so flatten the reply first
# (the same preprocessing get_doctype uses).
res = answer_regex.findall(result.content.replace("\n", ""))
if not res:
    raise ValueError(f"No answer detected in response:\n{result.content}")

# partition() tolerates a missing or repeated ")", where unpacking split()
# into exactly two names would raise.
letter, _, description = res[0].partition(")")
letter = alpha_regex.sub("", letter)
description = description.strip()

(letter, description)

## As a function

In [None]:
class DocCategory(Enum):
    # Allowed document categories; each member's value equals its name.
    # (`#` comments rather than a docstring, because pydantic copies enum
    # docstrings into DocType's JSON schema and thus into the parser's
    # format instructions.)
    email = "email"
    speech = "speech"
    minutes = "minutes"
    submission = "submission"


class DocType(BaseModel):
    # A single document category. The "before" validator maps option letters
    # (a-d, matching the prompt's list) and common variants onto DocCategory
    # values. `#` comments are used rather than a docstring, because pydantic
    # copies model docstrings into the JSON schema used for format instructions.
    category: DocCategory

    @field_validator("category", mode="before")
    @classmethod
    def _flexible_cat(cls, v: object) -> object:
        # Non-strings (e.g. a DocCategory member) go straight to normal enum
        # validation instead of crashing on `.lower()`.
        if not isinstance(v, str):
            return v
        v = v.lower().strip()
        if v in ["a", "emails"]:
            return "email"
        elif v in ["b", "speeches"]:
            return "speech"
        elif v in ["c", "minute", "meetng", "meeting", "meeting minutes"]:
            return "minutes"
        # "d" is the submission option; a duplicated "c" here made it unreachable.
        elif v in ["d", "submissions"]:
            return "submission"
        else:
            return v


# Parser that validates an LLM reply into a DocType.
doctype_parser = PydanticOutputParser(pydantic_object=DocType)

# Retry wrapper around doctype_parser that uses `llm` to correct malformed
# completions. NOTE(review): neither parser is used by get_doctype below.
retry_parser = RetryWithErrorOutputParser.from_llm(parser=doctype_parser, llm=llm)

In [None]:
def try_retry(attempt_count_max: int = 5):
    """Decorator factory that retries the wrapped call on ``ValueError``.

    The wrapped function is called up to ``attempt_count_max`` times. Each
    ``ValueError`` is printed and triggers another attempt; any other
    exception propagates immediately.

    Args:
        attempt_count_max: Maximum number of calls to make.

    Returns:
        A decorator. The decorated function returns the first successful
        result, or ``None`` if every attempt raised ``ValueError``.
    """

    def inner(func):
        @functools.wraps(func)
        def func_with_retry(*args, **kwargs):
            # Pre-bind so that exhausting all attempts returns None instead of
            # raising UnboundLocalError.
            res = None
            attempt_count = 0

            while attempt_count < attempt_count_max:
                attempt_count += 1
                try:
                    res = func(*args, **kwargs)
                    break
                except ValueError as e:
                    print(f"Failed to rectify malformed data object: {str(e)}")

            # attempt_count is the number of calls actually made.
            if res is not None:
                print(f"Sucessful extraction with {attempt_count} attempt(s)")
            else:
                print(f"Failed extraction with {attempt_count} attempt(s)")

            return res

        return func_with_retry

    return inner

In [None]:
@try_retry(attempt_count_max=5)
def get_doctype(raw_text):
    """Ask the LLM for a document's category and parse it into a ``DocType``.

    Every failure raises ``ValueError``, which ``try_retry`` turns into another
    attempt (up to 5 calls in total).

    NOTE(review): GET_DOCTYPE_PROMPT is not defined anywhere in this notebook.
    """
    to_send = HumanMessage(
        content=GET_DOCTYPE_PROMPT.format_prompt(raw_text=raw_text).to_string()
    )
    result = llm([to_send])

    answer_regex = re.compile(r"(?<=<answer>)(.*?)(?=<\/answer>)")
    alpha_regex = re.compile("[^a-zA-Z ]")

    # `.` does not match newlines, so flatten the reply before searching.
    all_answers = answer_regex.findall(result.content.replace("\n", ""))

    if not all_answers:
        raise ValueError(
            f"""
            No answer detected in response:
            {result.content}
        """
        )

    # Split the first answer on ")" and keep only letters/spaces of each
    # fragment. Blank fragments can never validate, so drop them up front.
    candidates = [alpha_regex.sub("", part.strip()) for part in all_answers[0].split(")")]
    candidates = [candidate for candidate in candidates if candidate]

    for candidate in candidates:
        try:
            return DocType(category=candidate)
        except ValidationError:
            continue

    raise ValueError("No category detected")

In [None]:
# NOTE(review): `email` is the `File` object from the "Generalising" section, so
# the prompt receives the File itself rather than its `.text`; confirm which is intended.
get_doctype(email)

In [None]:
# Sanity check of get_doctype on a short, informal email-like sentence.
get_doctype(
    "Dear Caolm, here is my text I am sending you. Please forweard it to my mumn, love Will"
)

# Doc subject experiments