# predict

> Use chroma, labeled instances, an LLM to predict email classification

In [1]:
#| default_exp predict

In [107]:
#| export
from pathlib import Path
import json
from typing import List, Tuple
import time
from tqdm import tqdm
import pandas as pd

import chromadb
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.document_loaders import DataFrameLoader
from langchain.llms import VertexAI

from classifier.schema import predict, quota_handler, WRITE_PREFIX, PROJECT_BUCKET
from classifier.load import get_possible_labels, get_emails_from_frame, get_idx, LABEL_COLUMN, \
    get_raw_emails_tejas_case_numbers, get_batches
from classifier.process import BISON_MAXIMUM_INPUT_TOKENS
from classifier.chroma import get_or_make_chroma, get_embedder, \
    read_json_lines_from_gcs

In [6]:
data_dir = Path("../data")
assert data_dir.exists()
tejas_dir = data_dir / "tejas"
assert tejas_dir.exists() and len(list(tejas_dir.glob("*"))) > 0

## Load chroma with embedded summaries of labeled emails

In [7]:
chroma = get_or_make_chroma(tejas_dir)

In [8]:
# chroma_10k = Chroma(
#     collection_name="emails",
#     client=chromadb.PersistentClient(path=str((data_dir / "chroma_10k").resolve())),
#     embedding_function=get_embedder()
# )

## Load our test summaries

In [10]:
## Local
# summary_path = data_dir / "summaries.csv"
# assert summary_path.exists()
# summaries = pd.read_csv(summary_path)
# summaries.head(5)

In [9]:
summaries_path = f'gs://{PROJECT_BUCKET}/{WRITE_PREFIX}/tejas/summaries.csv'
summaries = pd.read_csv(summaries_path)
summaries.rename({"Unnamed: 0": "idx"}, axis=1, inplace=True)
summaries.head(5)

Unnamed: 0,idx,summary
0,31716,Pavlina Georgieva (Logistics Coordinator) sen...
1,35200,**Summary**\n\nA customer reached out to Card...
2,462,**Subject**: Invoice 7322207358 - State of Fl...
3,3705,**Subject**: ACTION REQUIRED | Additional Inf...
4,25300,**Subject**: Paid - Invoice 7328757492\n\n**C...


In [11]:
train_idx, test_idx = get_idx(prefix=f"{WRITE_PREFIX}/tejas")

In [13]:
raw_emails_tejas = get_raw_emails_tejas_case_numbers()
emails = list(get_emails_from_frame(
    raw_emails_tejas,
    index_prefix=f'{WRITE_PREFIX}/tejas'
))

In [15]:
emails_frame = pd.DataFrame([e.to_series() for e in emails])
chroma_document_frame = emails_frame.merge(summaries, on='idx', how='inner').set_index('idx')
chroma_document_frame.head(2)

Unnamed: 0_level_0,BU,case_number,ACCOUNT_BUSINESS_UNIT__C,received_at,sfdc_subcategory,predicted_category,predicted_subcategory,record_type,probability,Accuracy_upd,Bin,label,email_subject,email_body,summary
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
31716,PD,3598350,,2023-11-01T19:40:57,Drop Ship Order,Order Processing,Drop Ship Order,1,0.576672,Correct,5,Order Processing,Equashield latest - FW: EQ II Catalog 2023 - C...,External Email â€“ Please use caution before o...,Pavlina Georgieva (Logistics Coordinator) sen...
35200,PD,3613116,,2023-11-08T17:27:04,Account balance,Billing / Invoice,,1,0.496874,Correct,4,Billing / Invoice,Auto-Reply. We Have Received Your Request,"To whom it may concern, Your request has been...",**Summary**\n\nA customer reached out to Card...


In [18]:
test_documents_frame = chroma_document_frame.loc[test_idx, :]
test_documents_frame.head(2)

Unnamed: 0_level_0,BU,case_number,ACCOUNT_BUSINESS_UNIT__C,received_at,sfdc_subcategory,predicted_category,predicted_subcategory,record_type,probability,Accuracy_upd,Bin,label,email_subject,email_body,summary
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
13614,PD,3524302,a1G4z00000H4xcvEAB,2023-10-02T17:53:01,Overstock,Order Discrepancy,,1,0.529041,Incorrect,5,Returns,Need signature AC account,"Good afternoon, We have received an order fro...",**Customer:** 2057194105\n\n**Issue:** Unsign...
36424,PD,3618678,,2023-11-10T13:32:42,Account updates,Billing / Invoice,,1,0.496874,Incorrect,4,Account/Inquiry,Auto-Reply. We Have Received Your Request,"To whom it may concern, Your request has been...",**Summary**\n\nA customer reached out to Card...


In [19]:
test_documents = DataFrameLoader(test_documents_frame.reset_index(drop=False), page_content_column='summary').load()

## 10k test frame

In [20]:
# batch_result_file_uri = "JDB_experiments/summarization/prediction-model-2023-12-18T15:10:57.834767Z/000000000000.jsonl"

# batch_result = list(read_json_lines_from_gcs(batch_result_file_uri))

In [21]:
# batch_result_summaries = [
#     r.get('predictions', [{}])[0].get("content", "").strip() for r in batch_result]

In [22]:
# batch_result_metadata = list(read_json_lines_from_gcs(
#     "JDB_experiments/summarization_metadata.jsonl"
# ))

In [23]:
# batch_result_dataframe = pd.DataFrame.from_records(batch_result_metadata)
# batch_result_dataframe.loc[:, 'summary'] = pd.Series(batch_result_summaries)

In [24]:
# # Add label
# training_data = pd.read_excel(
#     f"gs://{PROJECT_BUCKET}/Last50KCases_withSubjectAndBody.xlsx")
# training_data.loc[:, 'email_subject'] = training_data.email_subject.fillna("N/A")

In [25]:
# batch_result_labels = training_data.loc[batch_result_dataframe.idx, LABEL_COLUMN]
# batch_result_dataframe.loc[:, 'label'] = batch_result_labels.tolist()

In [26]:
# train_10k_idx, test_2k_idx = get_idx(prefix=WRITE_PREFIX + "/summarization_idx")
# test_2k_idx.shape

In [27]:
# test_2k_documents_df = batch_result_dataframe.set_index('idx').loc[test_2k_idx, :].reset_index(drop=False)
# test_2k_documents_df.head(2)

In [28]:
# test_2k_documents = DataFrameLoader(test_2k_documents_df, 'summary').load()
# test_2k_documents[0]

## Prediction prompt

In [29]:
labels = get_possible_labels()
labels

['Order Processing',
 'Product Inquiry',
 'Account/Inquiry',
 'General Inquiry',
 'Returns',
 'Billing / Invoice',
 'Delivery',
 'Credits',
 'Order Discrepancy',
 'Pricing',
 'Program / Promotions']

In [54]:
#| export
EMAIL_LABEL_SEP = "|||"

LABEL_STR = """- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions"""

PREDICTION_PROMPT_TEMPLATE = """\
Our customer service team wants to classify emails so they can be sent to the right support team.
Here are the labels they use;

--LABELS--
""" + LABEL_STR + """

Below are a series of emails that have already been labeled, use their example to identify what label the final email should get.
Your answer must be one of the options in the --LABELS-- list.
Return only the label from the above list that you chose.

--EMAILS--
{examples}
EMAIL: {email} """ + f"{EMAIL_LABEL_SEP} LABEL: "

PREDICTION_PROMPT = PromptTemplate.from_template(PREDICTION_PROMPT_TEMPLATE)

In [55]:
PREDICTION_PROMPT

PromptTemplate(input_variables=['email', 'examples'], template='Our customer service team wants to classify emails so they can be sent to the right support team.\nHere are the labels they use;\n\n--LABELS--\n- Order Processing\n- Product Inquiry\n- Account/Inquiry\n- General Inquiry\n- Returns\n- Billing / Invoice\n- Delivery\n- Credits\n- Order Discrepancy\n- Pricing\n- Program / Promotions\n\nBelow are a series of emails that have already been labeled, use their example to identify what label the final email should get.\nYour answer must be one of the options in the --LABELS-- list.\nReturn only the label from the above list that you chose.\n\n--EMAILS--\n{examples}\nEMAIL: {email} ||| LABEL: ')

In [56]:
#| export
def filter_examples(examples: List[Document], idx: int) -> List[Document]:
    return [e for e in examples if int(e.metadata.get('idx')) != int(idx)]

In [57]:
example_retrieved_documents = chroma.similarity_search(
    test_documents[0].page_content, 
    k=3)
len(example_retrieved_documents)

3

In [58]:
len(filter_examples(example_retrieved_documents, test_documents[0].metadata['idx']))

3

In [59]:
len(chroma.get()['ids'])

2400

In [60]:
#| export
def format_example(example: Document) -> str:
    return f"EMAIL: {example.page_content.strip()} {EMAIL_LABEL_SEP} LABEL: {example.metadata.get('label')}"


def make_prediction_prompt(
        email_summary: Document,
        chroma: Chroma,
        limit: int = None
) -> str:
    idx = email_summary.metadata.get('idx')
    k = 5
    prompt = None
    keep_stuffing = True
    max_k = len(chroma.get()['ids']) if limit is None else limit
    while keep_stuffing:
        examples = chroma.similarity_search(email_summary.page_content, k=k)
        examples = filter_examples(examples, idx)
        if len(examples) == 0:
            k += 5
            continue
        else:
            example_str = ""
            for e in examples:
                e_formatted = format_example(e)
                if len(example_str) == 0:
                    example_str = e_formatted
                else:
                    example_str = example_str + "\n" + e_formatted
                prompt = PREDICTION_PROMPT.format(
                    email=email_summary.page_content,
                    examples=example_str
                )
                if len(prompt) >= BISON_MAXIMUM_INPUT_TOKENS:
                    keep_stuffing = False
                    break
        k += 5
        if k >= max_k:
            keep_stuffing = False
    return prompt

In [62]:
example_prompt = make_prediction_prompt(
    test_documents[0], 
    chroma=chroma,
    limit=3
)
print(len(example_prompt))
print(example_prompt)

3256
Our customer service team wants to classify emails so they can be sent to the right support team.
Here are the labels they use;

--LABELS--
- Order Processing
- Product Inquiry
- Account/Inquiry
- General Inquiry
- Returns
- Billing / Invoice
- Delivery
- Credits
- Order Discrepancy
- Pricing
- Program / Promotions

Below are a series of emails that have already been labeled, use their example to identify what label the final email should get.
Your answer must be one of the options in the --LABELS-- list.
Return only the label from the above list that you chose.

--EMAILS--
EMAIL: **Customer:** 2057201348

**Issue:** Unsigned MRA (3901263910) received for two orders.

**Action Required:** Customer needs to provide a signed MRA within 48 hours to avoid the return of the case without credit.

**Business Function:**

- Tom Coppedge: Returns Lead, Warehouse Operations

**Additional Information:**

- The customer has already been contacted and informed about the situation. ||| LABEL: R

In [65]:
test_documents[0].metadata

{'idx': 13614,
 'BU': 'PD',
 'case_number': 3524302,
 'ACCOUNT_BUSINESS_UNIT__C': 'a1G4z00000H4xcvEAB',
 'received_at': '2023-10-02T17:53:01',
 'sfdc_subcategory': 'Overstock',
 'predicted_category': 'Order Discrepancy',
 'predicted_subcategory': nan,
 'record_type': 1,
 'probability': 0.5290409,
 'Accuracy_upd': 'Incorrect',
 'Bin': 5,
 'label': 'Returns',
 'email_subject': 'Need signature AC account',
 'email_body': 'Good afternoon,  We have received an order from customer 2057194105. They sent unsigned MRA 3901356789. Can you please reach out to the customer and let them know they have 48 hours to send a signed MRA or we will send back the case for no credit.  Thanks,   [cid:image001.png@01D9F52F.5AD273F0]  Tom Coppedge Returns Lead | Warehouse Operations 2840 Elm Pont Industrial Drive St. Charles, MO. 63301    _________________________________________________  This message is for the designated recipient only and may contain privileged, proprietary or otherwise private information.

In [66]:
predict(example_prompt)

 Returns

In [67]:
# example_10k_prompt = make_prediction_prompt(
#     test_2k_documents[0], 
#     chroma=chroma_10k,
#     limit=5
# )
# print(len(example_10k_prompt))
# print(example_10k_prompt)

In [68]:
# predict(example_10k_prompt)

Batch construct prompts

In [69]:
test_prompts = [
    make_prediction_prompt(s, chroma, 3) for s in tqdm(
        test_documents, 
        leave=False, 
        ncols=80)]

                                                                                

In [70]:
# test_2k_prompts = [
#     make_prediction_prompt(s, chroma_10k, 3) for s in tqdm(
#         test_2k_documents, 
#         leave=False, 
#         ncols=80)]

In [72]:
llm = VertexAI()

In [75]:
llm.batch(test_prompts[:2])

[' Returns', ' Billing / Invoice']

In [81]:
#| export
@quota_handler
def predict_batch(llm: VertexAI, prompts: List[str]) -> List[str]:
    return llm.batch(prompts)


def get_predictions(llm: VertexAI, prompts: List[str]) -> List[str]:
    pbar = tqdm(total=len(prompts), ncols=80, leave=False)
    predictions = []
    for batch in get_batches(iter(prompts), 5):
        batch_predictions = predict_batch(llm, batch)
        predictions.extend(batch_predictions)
        pbar.update(len(batch))
    pbar.close()
    return predictions

In [82]:
test_predictions = get_predictions(llm, test_prompts)

                                                                                

In [83]:
# test_predictions_10k = get_predictions(test_2k_prompts)

In [87]:
len(test_predictions)# , len(test_predictions_10k)

600

In [88]:
test_predictions[0]# , test_predictions_10k[0]

' Returns'

In [89]:
test_labels = [d.metadata.get('label') for d in test_documents]

In [93]:
pd.Series(test_predictions).str.strip().value_counts()

Order Processing        194
Account/Inquiry         120
Billing / Invoice        69
Order Discrepancy        68
Returns                  44
Product Inquiry          39
Delivery                 33
Credits                  14
General Inquiry           7
Pricing                   5
Program / Promotions      5
Contracts                 1
Fax Transmission          1
Name: count, dtype: int64

In [92]:
# pd.Series(test_predictions_10k).value_counts()

In [104]:
#| export
def write_predictions(
        predictions: List[str],
        labels: List[str],
        idx: List[str],
        prompts: List[str],
        emails: List[str],
        directory: Path,
        file_name: str = "predictions.csv"):
    pd.DataFrame(
        list(zip(
            predictions, 
            labels,
            idx,
            prompts,
            emails)),
        columns=['prediction', 'label', 'idx', 'prompt', 'email']
    ).to_csv(directory / file_name, index=False)

In [106]:
write_predictions(
    [s.strip() for s in test_predictions], 
    [d.metadata.get('label') for d in test_documents],
    [d.metadata.get('idx') for d in test_documents],
    test_prompts,
    test_documents_frame.apply(lambda row: f"-- SUBJECT --\n{row.email_subject}\n-- BODY --\n{row.email_body}", axis=1).values.tolist(),
    tejas_dir,
    'sample_predictions.csv')

In [96]:
# write_predictions(
#     test_predictions_10k, 
#     [d.metadata.get('label') for d in test_2k_documents],
#     [d.metadata.get('idx') for d in test_2k_documents],
#     data_dir,
#     'predictions_2k.csv')

In [108]:
#| hide
import nbdev; nbdev.nbdev_export()