# predict

> Use chroma, labeled instances, an LLM to predict email classification

In [1]:
#| default_exp predict

In [71]:
#| export
from pathlib import Path
import json
from typing import List, Tuple

from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma

from classifier.schema import predict
from classifier.load import get_possible_labels
from classifier.process import BISON_MAXIMUM_INPUT_TOKENS
from classifier.chroma import get_or_make_chroma, concat_email_summaries

In [113]:
from tqdm import tqdm
import pandas as pd

In [4]:
data_dir = Path("../data")
assert data_dir.exists()

Load chroma with embedded summaries of labeled emails

In [5]:
chroma = get_or_make_chroma(data_dir)

Load our summaries

In [11]:
summary_path = data_dir / "summaries.json"
assert summary_path.exists()
with summary_path.open('r') as f:
    summary_json = json.load(f)

In [13]:
summaries = concat_email_summaries(summary_json)
summaries[0]

Document(page_content='The email is requesting a drop ship order for Ohio State University.\nThe PO number is 7004014842, the account number is 2150126632, and the store number is 16422.\nThe drug name is EPIDIOLEX 100MG/ML SOL 100ML, the order quantity is 5, and the prescriber names are LUCRETIA LONG and PHILIP CLAYTON JONAS.\nThe prescriber NPIs or DEAs are ML0822634 and FJ142\n', metadata={'idx': 0, 'label': 'Order Processing'})

## Prediction prompt

In [18]:
labels = get_possible_labels()
labels

['Order Processing',
 'Product Inquiry',
 'Account/Inquiry',
 'General Inquiry',
 'Returns',
 'Billing / Invoice',
 'Delivery',
 'Credits',
 'Order Discrepancy',
 'Pricing',
 'Program / Promotions']

In [100]:
#| export
LABEL_STR = """- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions
"""

PREDICTION_PROMPT_TEMPLATE = """\
Classify the input email into one of the given categories using the examples to help you
--CATEGORIES--
""" + LABEL_STR + """\
\n--INPUT EMAIL--
{email}
--EXAMPLES--
{examples}
--Classification--"""

PREDICTION_PROMPT = PromptTemplate.from_template(PREDICTION_PROMPT_TEMPLATE)

In [101]:
PREDICTION_PROMPT

PromptTemplate(input_variables=['email', 'examples'], template='Classify the input email into one of the given categories using the examples to help you\n--CATEGORIES--\n- Order Processing\n- Product Inquiry\n- Account/Inquiry\n- General Inquiry\n- Returns\n- Billing / Invoice\n- Delivery\n- Credits\n- Order Discrepancy\n- Pricing\n- Program / Promotions\n\n--INPUT EMAIL--\n{email}\n--EXAMPLES--\n{examples}\n--Classification--')

In [116]:
#| export
def filter_examples(examples: List[Tuple[Document, float]], idx: int) -> List[Document]:
    return [(e, score) for e, score in examples if int(e.metadata.get('idx')) != int(idx)]

In [118]:
example_retrieved_documents = chroma.similarity_search_with_score(summaries[0].page_content, k=10)
len(example_retrieved_documents)

10

In [117]:
len(filter_examples(example_retrieved_documents, 0))

9

In [103]:
len(chroma.get()['ids'])

20

In [104]:
#| export
def format_example(example: Document, score: float) -> str:
    return f"""--START EXAMPLE--
EXAMPLE EMAIL: {example.page_content.strip()}
EXAMPLE LABEL: {example.metadata.get('label')}
EXAMPLE COSINE DISTANCE TO INPUT EMAIL: {score}
--END EXAMPLE--"""

def make_prediction_prompt(
        email_summary: Document,
        chroma: Chroma
) -> str:
    idx = email_summary.metadata.get('idx')
    k = 5
    prompt = None
    keep_stuffing = True
    max_k = len(chroma.get()['ids'])
    while keep_stuffing:
        examples = chroma.similarity_search_with_score(email_summary.page_content, k=k)
        examples = filter_examples(examples, idx)
        if len(examples) == 0:
            k += 5
            continue
        else:
            example_str = ""
            for e, score in examples:
                e_formatted = format_example(e, score)
                if len(example_str) == 0:
                    example_str = e_formatted
                else:
                    example_str = example_str + "\n" + e_formatted
                prompt = PREDICTION_PROMPT.format(
                    email=email_summary.page_content,
                    examples=example_str
                )
                if len(prompt) >= BISON_MAXIMUM_INPUT_TOKENS:
                    keep_stuffing = False
                    break
        k += 5
        if k >= max_k:
            keep_stuffing = False
    return prompt

In [105]:
example_prompt = make_prediction_prompt(
    summaries[0],
    chroma=chroma
)
print(len(example_prompt))
print(example_prompt)

7225
Classify the input email into one of the given categories using the examples to help you
--CATEGORIES--
- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions

--INPUT EMAIL--
The email is requesting a drop ship order for Ohio State University.
The PO number is 7004014842, the account number is 2150126632, and the store number is 16422.
The drug name is EPIDIOLEX 100MG/ML SOL 100ML, the order quantity is 5, and the prescriber names are LUCRETIA LONG and PHILIP CLAYTON JONAS.
The prescriber NPIs or DEAs are ML0822634 and FJ142

--EXAMPLES--
--START EXAMPLE--
EXAMPLE EMAIL: The email is requesting a drop ship order for the following:

- Client Name: Rosedale Infectious Diseases, PLLC
- PO ID: 7004000449
- Account #: 2150129609
- Store #: 16405
- NDC: 49702024015
- Drug Name: CABENUVA 600-900MG INJ SUSP KIT
- Order Quantity: 1
- Prescriber Name: ASHLEY DAY SCOTT
-

In [106]:
example_prompts = [make_prediction_prompt(s, chroma) for s in summaries]

In [121]:
example_predictions = []

for p in tqdm(example_prompts, ncols=80, leave=False):
    p_prediction = predict(p)
    example_predictions.append(p_prediction.text)

                                                                                

In [122]:
len(example_predictions)

20

In [123]:
summaries[0]

Document(page_content='The email is requesting a drop ship order for Ohio State University.\nThe PO number is 7004014842, the account number is 2150126632, and the store number is 16422.\nThe drug name is EPIDIOLEX 100MG/ML SOL 100ML, the order quantity is 5, and the prescriber names are LUCRETIA LONG and PHILIP CLAYTON JONAS.\nThe prescriber NPIs or DEAs are ML0822634 and FJ142\n', metadata={'idx': 0, 'label': 'Order Processing'})

In [124]:
example_labels = [d.metadata.get('label') for d in summaries]

In [125]:
pd.DataFrame(
    list(zip(
        example_predictions, 
        [d.metadata.get('label') for d in summaries],
        [d.metadata.get('idx') for d in summaries])),
    columns=['pred', 'label', 'idx']
).to_csv(data_dir / 'sample_predictions.csv', index=False)

In [126]:
#| hide
import nbdev; nbdev.nbdev_export()