# predict

> Use chroma, labeled instances, an LLM to predict email classification

In [1]:
#| default_exp predict

In [22]:
#| export
from pathlib import Path
import json
from typing import List, Tuple
import time
from google.api_core.exceptions import ResourceExhausted
from ratelimit import sleep_and_retry

import chromadb
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.document_loaders import DataFrameLoader

from classifier.schema import predict, WRITE_PREFIX, PROJECT_BUCKET
from classifier.load import get_possible_labels, get_training_instances, get_idx, LABEL_COLUMN
from classifier.process import BISON_MAXIMUM_INPUT_TOKENS
from classifier.chroma import get_or_make_chroma, merge_summaries_with_instances, get_embedder, \
    read_json_lines_from_gcs

In [2]:
from tqdm import tqdm
import pandas as pd

In [3]:
data_dir = Path("../data")
assert data_dir.exists()

Load chroma with embedded summaries of labeled emails

In [4]:
chroma = get_or_make_chroma(data_dir)

In [5]:
chroma_10k = Chroma(
    collection_name="emails",
    client=chromadb.PersistentClient(path=str((data_dir / "chroma_10k").resolve())),
    embedding_function=get_embedder()
)

Load our test summaries

In [6]:
summary_path = data_dir / "summaries.csv"
assert summary_path.exists()
summaries = pd.read_csv(summary_path)
summaries.head(5)

Unnamed: 0,summary
0,The customer received an invoice from Cardina...
1,The customer received an email from the State...
2,The customer would like to place a new order ...
3,"The customer, State of Florida Next Gen, upda..."
4,The customer is inquiring about an order plac...


In [7]:
training_instances = list(get_training_instances())

In [8]:
chroma_document_frame = merge_summaries_with_instances(
    summaries,
    training_instances
)
chroma_document_frame.head(2)

Unnamed: 0_level_0,summary,BU,case_number,ACCOUNT_BUSINESS_UNIT__C,received_at,sfdc_subcategory,predicted_category,predicted_subcategory,record_type,probability,Accuracy_upd,Bin,label,email_subject,email_body
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
20775,The customer received an invoice from Cardina...,PD,3553288,a1G4z00000H6C4aEAF,2023-10-13T12:37:20,Billing Statements,Billing / Invoice,,1,0.474032,Correct,4,Billing / Invoice,"Invoices 1 of 1 for 2057199110 , TEXAS INSTITU...","Dear Valued Customer, Your Cardinal Health in..."
46774,The customer received an email from the State...,PD,3658829,,2023-11-29T20:25:47,Account updates,Billing / Invoice,,1,0.566661,Incorrect,5,Account/Inquiry,Invoice status from State of Florida Next Gen,External Email â€“ Please use caution before o...


In [9]:
train_idx, test_idx = get_idx()

In [10]:
test_frame = chroma_document_frame.loc[test_idx, :]
test_frame.head(2)

Unnamed: 0_level_0,summary,BU,case_number,ACCOUNT_BUSINESS_UNIT__C,received_at,sfdc_subcategory,predicted_category,predicted_subcategory,record_type,probability,Accuracy_upd,Bin,label,email_subject,email_body
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
50493,- Customer inquiry about a possible overage o...,PD,3671369,,2023-12-05T16:02:05,Account balance,Billing / Invoice,,1,0.349415,Correct,3,Billing / Invoice,Customer Inquiry Possible overage CINI 5813233...,Hello Team\n\nCan you please reach out to the ...
13780,"The customer, Lauren Walsh from Sanofi, sent ...",SPD,3524801,,2023-10-02T19:22:23,Order Placement,General Inquiry,,2,0.73258,Correct,7,General Inquiry,new po,"Good afternoon, please reach out to acct #215..."


In [11]:
test_documents = DataFrameLoader(test_frame.reset_index(drop=False), page_content_column='summary').load()

10k test frame

In [12]:
batch_result_file_uri = "JDB_experiments/summarization/prediction-model-2023-12-18T15:10:57.834767Z/000000000000.jsonl"

batch_result = list(read_json_lines_from_gcs(batch_result_file_uri))

In [14]:
batch_result_summaries = [
    r.get('predictions', [{}])[0].get("content", "").strip() for r in batch_result]

In [13]:
batch_result_metadata = list(read_json_lines_from_gcs(
    "JDB_experiments/summarization_metadata.jsonl"
))

In [15]:
batch_result_dataframe = pd.DataFrame.from_records(batch_result_metadata)
batch_result_dataframe.loc[:, 'summary'] = pd.Series(batch_result_summaries)

In [18]:
# Add label
training_data = pd.read_excel(
    f"gs://{PROJECT_BUCKET}/Last50KCases_withSubjectAndBody.xlsx")
training_data.loc[:, 'email_subject'] = training_data.email_subject.fillna("N/A")

In [23]:
batch_result_labels = training_data.loc[batch_result_dataframe.idx, LABEL_COLUMN]
batch_result_dataframe.loc[:, 'label'] = batch_result_labels.tolist()

In [28]:
train_10k_idx, test_2k_idx = get_idx(prefix=WRITE_PREFIX + "/summarization_idx")
test_2k_idx.shape

(1857,)

In [47]:
test_2k_documents_df = batch_result_dataframe.set_index('idx').loc[test_2k_idx, :].reset_index(drop=False)
test_2k_documents_df.head(2)

Unnamed: 0,idx,BU,case_number,ACCOUNT_BUSINESS_UNIT__C,received_at,sfdc_subcategory,predicted_category,predicted_subcategory,record_type,probability,Accuracy_upd,Bin,summary,label
0,37508,PD,3623167,,2023-11-13T18:59:03,Drop Ship Order,Returns,,1,0.636091,Incorrect,6,"The customer, State of Florida Next Gen, recei...",Order Processing
1,1092,PD,3473751,a1G4z00000H4yOtEAJ,2023-09-12T14:18:30,Verify status,Delivery,,1,0.174234,Incorrect,1,"The customer, Sachin Sharma, from Metro Medica...",Order Processing


In [48]:
test_2k_documents = DataFrameLoader(test_2k_documents_df, 'summary').load()
test_2k_documents[0]

Document(page_content='The customer, State of Florida Next Gen, received invoice 7323598948 from the sender.\nThe invoice has been updated by the customer in the SAP Business Network.\nThere is an accounting verification exception in the invoice.\nThe customer is advised to accept or edit the accounting information.\nFor any questions, the customer should contact the sender.', metadata={'idx': 37508, 'BU': 'PD', 'case_number': 3623167, 'ACCOUNT_BUSINESS_UNIT__C': nan, 'received_at': '2023-11-13T18:59:03', 'sfdc_subcategory': 'Drop Ship Order', 'predicted_category': 'Returns', 'predicted_subcategory': nan, 'record_type': 1, 'probability': 0.63609123, 'Accuracy_upd': 'Incorrect', 'Bin': 6, 'label': 'Order Processing'})

## Prediction prompt

In [33]:
labels = get_possible_labels()
labels

['Order Processing',
 'Product Inquiry',
 'Account/Inquiry',
 'General Inquiry',
 'Returns',
 'Billing / Invoice',
 'Delivery',
 'Credits',
 'Order Discrepancy',
 'Pricing',
 'Program / Promotions']

In [34]:
#| export
EMAIL_LABEL_SEP = "|||"

LABEL_STR = """- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions

"""

PREDICTION_PROMPT_TEMPLATE = """\
Our customer service team wants to classify emails so they can be sent to the right support team.
Here are the labels they use.

--LABELS--
""" + LABEL_STR + """

Below are a series of emails that have already been labeled, use their example to identify what label the final email should get.
Your answer must be one of the options in the --LABELS-- list.
Return only the label from the above list that you chose.

--EMAILS--
{examples}
EMAIL: {email} """ + f"{EMAIL_LABEL_SEP} LABEL: "

PREDICTION_PROMPT = PromptTemplate.from_template(PREDICTION_PROMPT_TEMPLATE)

In [35]:
PREDICTION_PROMPT

PromptTemplate(input_variables=['email', 'examples'], template='Our customer service team wants to classify emails so they can be sent to the right support team.\nHere are the labels they use.\n\n--LABELS--\n- Order Processing\n- Product Inquiry\n- Account/Inquiry\n- General Inquiry\n- Returns\n- Billing / Invoice\n- Delivery\n- Credits\n- Order Discrepancy\n- Pricing\n- Program / Promotions\n\n\n\nBelow are a series of emails that have already been labeled, use their example to identify what label the final email should get.\nYour answer must be one of the options in the --LABELS-- list.\nReturn only the label from the above list that you chose.\n\n--EMAILS--\n{examples}\nEMAIL: {email} ||| LABEL: ')

In [36]:
#| export
def filter_examples(examples: List[Document], idx: int) -> List[Document]:
    return [e for e in examples if int(e.metadata.get('idx')) != int(idx)]

In [37]:
example_retrieved_documents = chroma.similarity_search(
    test_documents[0].page_content, 
    k=3)
len(example_retrieved_documents)

3

In [38]:
len(filter_examples(example_retrieved_documents, test_documents[0].metadata['idx']))

3

In [39]:
len(chroma.get()['ids'])

900

In [40]:
#| export
def format_example(example: Document) -> str:
    return f"EMAIL: {example.page_content.strip()} {EMAIL_LABEL_SEP} LABEL: {example.metadata.get('label')}"


def make_prediction_prompt(
        email_summary: Document,
        chroma: Chroma,
        limit: int = None
) -> str:
    idx = email_summary.metadata.get('idx')
    k = 5
    prompt = None
    keep_stuffing = True
    max_k = len(chroma.get()['ids']) if limit is None else limit
    while keep_stuffing:
        examples = chroma.similarity_search(email_summary.page_content, k=k)
        examples = filter_examples(examples, idx)
        if len(examples) == 0:
            k += 5
            continue
        else:
            example_str = ""
            for e in examples:
                e_formatted = format_example(e)
                if len(example_str) == 0:
                    example_str = e_formatted
                else:
                    example_str = example_str + "\n" + e_formatted
                prompt = PREDICTION_PROMPT.format(
                    email=email_summary.page_content,
                    examples=example_str
                )
                if len(prompt) >= BISON_MAXIMUM_INPUT_TOKENS:
                    keep_stuffing = False
                    break
        k += 5
        if k >= max_k:
            keep_stuffing = False
    return prompt

In [41]:
example_prompt = make_prediction_prompt(
    test_documents[0], 
    chroma=chroma,
    limit=3
)
print(len(example_prompt))
print(example_prompt)

2305
Our customer service team wants to classify emails so they can be sent to the right support team.
Here are the labels they use.

--LABELS--
- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions



Below are a series of emails that have already been labeled, use their example to identify what label the final email should get.
Your answer must be one of the options in the --LABELS-- list.
Return only the label from the above list that you chose.

--EMAILS--
EMAIL: - Item 5681788 (NABI-HB SF 312U/ML 5ML) is short by 8 units, resulting in a total shortage value of $5,846.48.
- The last time this item was counted was on 9/21/2023.
- The customer wants to check with the accounts listed in the attached file to inquire about a possible overage of this item.
- The item hasn't had many sales recently. ||| LABEL: Order Discrepancy
EMAIL: The customer is reporting a short

In [42]:
predict(example_prompt)

 Order Discrepancy

In [69]:
example_10k_prompt = make_prediction_prompt(
    test_2k_documents[0], 
    chroma=chroma_10k,
    limit=5
)
print(len(example_10k_prompt))
print(example_10k_prompt)

2908
Our customer service team wants to classify emails so they can be sent to the right support team.
Here are the labels they use.

--LABELS--
- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions



Below are a series of emails that have already been labeled, use their example to identify what label the final email should get.
Your answer must be one of the options in the --LABELS-- list.
Return only the label from the above list that you chose.

--EMAILS--
EMAIL: The customer, State of Florida Next Gen, received invoice 7334256889 from the sender.
The invoice has been updated by the customer in the SAP Business Network.
There is an accounting verification exception in the invoice.
The customer is advised to accept or edit the accounting information.
For any questions, the customer is advised to contact the sender. ||| LABEL: Account/Inquiry
EMAIL: The customer,

In [70]:
predict(example_10k_prompt)

 Billing / Invoice

Batch construct prompts

In [43]:
test_prompts = [
    make_prediction_prompt(s, chroma, 3) for s in tqdm(
        test_documents, 
        leave=False, 
        ncols=80)]

                                                                                

In [50]:
test_2k_prompts = [
    make_prediction_prompt(s, chroma_10k, 3) for s in tqdm(
        test_2k_documents, 
        leave=False, 
        ncols=80)]

                                                                                

In [51]:
#| export
def get_predictions(prompts: List[str]) -> List[str]:
    predictions = []
    for p in tqdm(prompts, ncols=80, leave=False):
        sleep_time = 1
        try:
            p_prediction = predict(p)
        except ResourceExhausted:
            while True:
                try:
                    p_prediction = predict(p)
                    break
                except ResourceExhausted:
                    time.sleep(sleep_time)
                    sleep_time = sleep_time * 2
        predictions.append(p_prediction.text.strip())
    return predictions

In [52]:
test_predictions = get_predictions(test_prompts)

                                                                                

In [53]:
test_predictions_10k = get_predictions(test_2k_prompts)

                                                                                

In [60]:
len(test_predictions), len(test_predictions_10k)

(100, 1857)

In [61]:
test_predictions[0], test_predictions_10k[0]

('Order Discrepancy', 'Billing / Invoice')

In [62]:
test_labels = [d.metadata.get('label') for d in test_documents]

In [63]:
pd.Series(test_predictions).value_counts()

Order Processing        39
Account/Inquiry         17
Order Discrepancy       14
Returns                  7
Billing / Invoice        6
Credits                  4
Delivery                 3
Pricing                  3
Product Inquiry          3
General Inquiry          3
Program / Promotions     1
Name: count, dtype: int64

In [64]:
pd.Series(test_predictions_10k).value_counts()

Order Processing        1005
Billing / Invoice        235
Order Discrepancy        150
Delivery                  98
Account/Inquiry           92
Returns                   79
Product Inquiry           69
Credits                   54
General Inquiry           39
Pricing                   33
Program / Promotions       3
Name: count, dtype: int64

In [65]:
#| export
def write_predictions(
        predictions: List[str],
        labels: List[str],
        idx: List[str],
        directory: Path,
        file_name: str = "predictions.csv"):
    pd.DataFrame(
        list(zip(
            predictions, 
            labels,
            idx)),
        columns=['pred', 'label', 'idx']
    ).to_csv(directory / file_name, index=False)

In [66]:
write_predictions(
    test_predictions, 
    [d.metadata.get('label') for d in test_documents],
    [d.metadata.get('idx') for d in test_documents],
    data_dir,
    'sample_predictions.csv')

In [67]:
write_predictions(
    test_predictions_10k, 
    [d.metadata.get('label') for d in test_2k_documents],
    [d.metadata.get('idx') for d in test_2k_documents],
    data_dir,
    'predictions_2k.csv')

In [68]:
#| hide
import nbdev; nbdev.nbdev_export()