# Email Classifier Agent - v1

This notebook builds a basic ReAct agent with message memory that classifies the email passed in the prompt.

## Environment setup

In [0]:
# Sanity check: print the version of the `python3` executable found on the shell PATH.
# NOTE(review): this is presumably the same interpreter the notebook kernel uses - confirm.
!python3 --version

In [0]:
%pip install -U -qqqq langgraph langgraph-prebuilt langgraph-sdk langgraph-checkpoint-sqlite langsmith langchain-community langchain-core langchain-openai notebook langchain-tavily wikipedia trustcall langgraph-cli[inmem] transformers
# ^ LangGraph / LangChain stack. `-U` upgrades to the latest release and no versions are pinned,
#   so results may differ between runs. NOTE(review): consider pinning versions for reproducibility.

%pip install -U -qqqq databricks-agents mlflow-skinny[databricks] databricks-langchain
# ^ Databricks integrations (provides `databricks_langchain.ChatDatabricks` used for the base LLM below).

# Restart the Python process so the freshly installed packages are the ones that get imported.
dbutils.library.restartPython()

## Helper function and example data setup

In [0]:
from pyspark.sql.functions import *
def format_email_classification_request(email_sender: str, email_subject: str, email_body: str) -> str:
  """Render an email's sender, subject and body as a single classification request string.

  Optional convenience helper: the returned text can be sent to the agent as the
  content of a human message.

  Args:
    email_sender: Address the email was received from.
    email_subject: Subject line of the email.
    email_body: Body text of the email.

  Returns:
    One sentence-style string embedding the three values between backticks.
  """
  # The values are substituted as format arguments, so braces inside the email
  # text are copied through literally rather than interpreted as placeholders.
  fields = {"sender": email_sender, "subject": email_subject, "body": email_body}
  return "We received from: `{sender}` the following email. Here's the email subject: `{subject}` and here's the email body: `{body}`".format_map(fields)

In [0]:
## Sample emails used to exercise the classifier: one tuple per email, in `columns` order
examples_ = [
  (1, "mario.rossi@gmail.com", "Address update", "Dear sir or madam, I would like to get my residence address updated to: v. bella 1, Milano (MI), Italia"),
  (2, "alberto.rossi@gmail.com", "Shipment status", "Hello, I'd like to get an updated delivery date for shipment ID-45892171, given that I didn't receive it yet. Thank you"),
  (3, "alberto.rossi@gmail.com", "RE:Shipment status", "Understood, I will wait a couple of days thanks"),
]
columns = ["ID", "sender", "subject", "body"]

# Only column names are supplied as the schema; the rows are plain Python tuples.
df_example = spark.createDataFrame(data=examples_, schema=columns)
display(df_example)

In [0]:
# Pick the example with ID 2 and preview the request text built from its fields.
example_row = df_example.where(col("ID") == 2).first()
example_request = format_email_classification_request(
    example_row.sender,
    example_row.subject,
    example_row.body,
)
print(example_request)

## Classifier Agent: Basic version

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph State
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages, AnyMessage

## Graph state shared by every node: the running list of conversation messages
class AgentState(TypedDict):
    """State passed between graph nodes; it holds only the message history."""

    # `add_messages` is attached (via `Annotated`) as the reducer for this key, so messages
    # returned by a node are merged into the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]


In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Output Schema
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

## Set up the output schema for the classification
from pydantic import BaseModel, Field
from typing import Literal

class LabelOutput(BaseModel):
  # Structured result of one classification: the chosen label plus the reasoning behind it.
  # (Documented with comments rather than a class docstring on purpose: a docstring on a
  # pydantic model ends up in its JSON schema description, which would change what the LLM sees.)

  # Restricting the value to a Literal keeps the model from inventing labels;
  # "Others" is used when no label is provided.
  label: Literal[
      "Claim",
      "Profile Update",
      "Communication Unsubscription",
      "General Enquiry",
      "Spam",
      "Others"] = Field(
    description="Label assigned to the email",
    default="Others")

  # Field name fixed from the misspelled `retionale`: the name is part of the schema
  # handed to the LLM and of the tool output it reads back, so it should be a real word.
  rationale: str = Field(description="Reasoning behind the label choice")

  @property
  def retionale(self) -> str:
    """Deprecated misspelling of `rationale`, kept so existing attribute reads keep working."""
    return self.rationale

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Base LLM
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from databricks_langchain import ChatDatabricks
# Name of the Databricks endpoint backing the agent; change it here to swap the underlying model.
LLM_ENDPOINT_NAME = "databricks-gpt-5-1"
# temperature=0 asks for the least random sampling, to keep classifications as repeatable as possible.
model = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0)


In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Classification Tool
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langchain_core.tools import tool

# Prompt template used by `classify_email_tool` to classify a single email.
# It lists the allowed labels (the same six values as `LabelOutput.label`) and has three
# `str.format` placeholders: {email_sender}, {email_subject}, {email_body}.
# Any literal brace added to this text must be doubled ({{ }}) or `.format` will fail.
MODEL_SYSTEM_MESSAGE = """You are designed to assign a label to each customer's email, based on some contextual information.

Here are the available labels you can choose from, together with a description of when to use them:
- `Claim`: the customer is formally requesting a complaint or a refund.
- `Profile Update`: the customer is requesting to update or modify their profile information (such as address, ...).
- `Communication Unsubscription`: the customer is requesting to stop receiving communications or newsletter about products, offers, etc.
- `General Enquiry`: the customer is asking for general information about their orders, the company, shipping policies, return policies, reimbursement terms, etc.
- `Spam`: the email is suspected being a spam email, not referring to any plausible customer's request and is not a plausible follow up to previous requests.
- `Others`: in case the email cannot be related to any of the previous labels.

Here is the contextual information.
The customer's email sender:
<email_sender>
{email_sender}
</email_sender>

The subject of the email sent by the customer:
<email_subject>
{email_subject}
</email_subject>

The body email sent by the customer:
<email_body>
{email_body}
</email_body>

Now classify the email and explain your reasoning.
"""

## Tool definition
# NOTE: `@tool` uses the function docstring as the tool description shown to the LLM,
# so the docstring wording is part of the agent's behavior - edit it deliberately.
@tool
def classify_email_tool(email_sender: str, email_subject: str, email_body: str) -> LabelOutput:
    
    """  Classify a customer email into one of the predefined labels and provide reasoning.  """

    # Fill the three placeholders of the classification prompt with this email's data.
    prompt_fields = {
        "email_sender": email_sender,
        "email_subject": email_subject,
        "email_body": email_body,
    }
    classification_prompt = MODEL_SYSTEM_MESSAGE.format(**prompt_fields)

    # Constrain the base LLM to answer with a `LabelOutput` and return that object directly.
    label_classifier = model.with_structured_output(LabelOutput)
    return label_classifier.invoke(classification_prompt)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Combine Tools together
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

# Tools the agent may call; the same list is given to the graph's ToolNode so that
# every tool the LLM can request can actually be executed.
tools = [classify_email_tool]
# Bind the tool schemas to the base LLM. `parallel_tool_calls=False` asks the model
# to request at most one tool call per turn.
llm_with_tools = model.bind_tools(tools, parallel_tool_calls=False)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Assistant Node
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from langchain_core.messages import SystemMessage

# Instructions prepended to every model call made by the assistant node.
# The third rule is what stops the assistant <-> tools loop after one classification.
ASSISTANT_INSTRUCTIONS = (
    "You are a helpful assistant tasked with handling emails and classify them.\n"
    "- Use tools when needed.\n"
    "- Once you have successfully called `classify_email_tool` and received a result, you MUST NOT call any more tools.\n"
    "- Respond in natural language summarizing the label and reasoning, then stop."
)
sys_msg = SystemMessage(content=ASSISTANT_INSTRUCTIONS)

## Reasoning node
def assistant(state: AgentState) -> AgentState:
    """Run one LLM turn over the conversation and return the reply as a state update.

    The system message is placed in front of the stored messages on every call; it is
    not written into the state itself. The returned dict contains only the new AI
    message under the "messages" key.
    """
    conversation = [sys_msg, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition # Check if the model's output is a tool call
from langgraph.prebuilt import ToolNode
from IPython.display import Image, display

# Graph: a two-node ReAct loop (assistant <-> tools) over `AgentState`
builder = StateGraph(AgentState)

# Nodes: "assistant" runs the LLM, "tools" executes whatever tool call the LLM requested
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message from the assistant contains a tool call -> tools_condition routes to "tools"
    # If the latest message from the assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
# After a tool has run, hand its result back to the assistant for the next reasoning step
builder.add_edge("tools", "assistant")
# compile() is called without arguments, so no checkpointer is configured here.
# NOTE(review): presumably this means messages are kept only within a single invoke() - confirm.
react_email_classifier = builder.compile()

# Render the compiled graph as a PNG.
# NOTE(review): draw_mermaid_png() may need outbound network access to render - confirm on this cluster.
display(Image(react_email_classifier.get_graph(xray=True).draw_mermaid_png()))

In [0]:
## Testing
from langchain_core.messages import HumanMessage

# Test request sent to the agent. To test with a row of `df_example` instead, build the
# text with `format_email_classification_request(row.sender, row.subject, row.body)`.
# (Removed a dead `message_ = "Hello"` assignment that was overwritten immediately,
# along with the commented-out code.)
message_ = (
    "email_sender: alice@example.com\n"
    "email_subject: Refund for damaged item\n"
    "email_body: Hi, the product arrived broken. I want a refund."
)

# Invoke the agent with a single human turn. The input list gets its own name so that
# `messages` refers to exactly one thing: the final graph state returned by invoke().
input_messages = [HumanMessage(content=message_)]
messages = react_email_classifier.invoke({"messages": input_messages})

# Print the whole trace: human request, tool call, tool result and final answer
for message in messages["messages"]:
    message.pretty_print()