# Email Classification: Building Blocks

## Environment setup

In [0]:
# Print the version of the `python3` executable found on the shell PATH.
# NOTE(review): this may differ from the notebook kernel's interpreter; `import sys; print(sys.version)` reports the latter.
!python3 --version

In [0]:
%pip install -U -qqq langgraph langgraph-prebuilt langgraph-sdk langgraph-checkpoint-sqlite langsmith langchain-community langchain-core langchain-openai notebook langchain-tavily wikipedia trustcall langgraph-cli[inmem] transformers
# Above: LangGraph / LangChain / LangSmith stack (-U upgrades, -qqq silences pip output).
# NOTE(review): several of these (e.g. langchain-tavily, wikipedia, trustcall, transformers)
# are not imported by any cell visible here; confirm they are still needed.

# Below: Databricks-specific packages (databricks-langchain provides ChatDatabricks,
# mlflow-skinny provides the `mlflow` module used for tracing in the test cell).
%pip install -U -qqq databricks-agents mlflow-skinny[databricks] databricks-langchain
# Restart the Python process so the freshly installed versions are the ones imported.
dbutils.library.restartPython()

## LLM setup

In [0]:
## Variable containing the LLM serving endpoint to use.
# Only one assignment is active; swap the comment to change models.
# NOTE: the ReAct section reports that example 2 loops with the GPT endpoint.
# llm_endpoint = 'databricks-gpt-5-1'
llm_endpoint = 'databricks-claude-sonnet-4-5'

## Example data

In [0]:
## Build a small Spark DataFrame of sample customer emails, one row per test case.
columns = ["id", "email_body"]
examples_ = [
    (1, "Dear sir or madam, I would like to get my residence address updated to: v. bella 1, Milano (MI), Italia"),
    (2, "Hello, I'd like to get an updated delivery date for shipment ID-45892171, given that I didn't receive it yet. Thank you"),
    (3, "Understood, I will wait a couple of days thanks"),
    (4, "Hi, the product arrived broken. I want a refund."),
]
df_example = spark.createDataFrame(data=examples_, schema=columns)
display(df_example)

## Basic Architecture: single classifier node

Entry-level architecture consisting of a single node with prompt engineering and structured output.

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph State
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages, AnyMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

## Set up the agent state to concatenate messages in memory
class AgentState(TypedDict):
    """Graph state shared by the nodes of the basic email-classification graph."""

    # Conversation history; the add_messages reducer merges each node's returned
    # messages into this list (by message id) instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Classification outputs, written by the `classify_email` node.
    label: str
    rationale: str
    next_steps: str

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Output Schema
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

## Set up the output schema for the classification
from pydantic import BaseModel, Field
from typing import Literal

# Structured-output schema the classifier LLM must fill in.
# NOTE: intentionally no class docstring - LangChain derives the schema description
# sent to the model from it, so adding one would change the request.
class LabelOutput(BaseModel):
    # One of the categories described in the prompt; "Others" if the model omits it.
    label: Literal[
        "Claim",
        "Profile Update",
        "Communication Unsubscription",
        "General Enquiry",
        "Spam",
        "Others",
    ] = Field(default="Others", description="Label assigned to the email")
    # Short justification of the chosen label.
    rationale: str = Field(description="Reasoning behind the label choice")
    # Suggested follow-up actions for the support team.
    next_steps: str = Field(description="Recommended action items based on the customer's email")

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Base LLM
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from databricks_langchain import ChatDatabricks
# Chat model bound to the serving endpoint selected in `llm_endpoint`;
# temperature 0 keeps sampling randomness to a minimum for repeatable labels.
model = ChatDatabricks(
    temperature=0,
    endpoint=llm_endpoint,
)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Classification Tool
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langchain_core.tools import tool
from langgraph.graph import MessagesState

# Classification prompt template. The single `{email_body}` placeholder is filled
# with str.format(), so any literal braces added to this text must be doubled.
MODEL_SYSTEM_MESSAGE = """
    You are an email dispatcher assistant.
    Based on some contextual information, your tasks are:
        (1) assign a label to the customer's email.
        (2) provide a summarized explanation of why you assigned that label.
        (3) recommend next steps, based on the context.

    Here are the available labels you can choose from, together with a description of when to use them:
    - `Claim`: the customer is formally filing a complaint or requesting a refund.
    - `Profile Update`: the customer is requesting to update or modify their profile information (such as address, ...).
    - `Communication Unsubscription`: the customer is requesting to stop receiving communications or newsletters about products, offers, etc.
    - `General Enquiry`: the customer is asking for general information about their orders, the company, shipping policies, return policies, reimbursement terms, etc.
    - `Spam`: the email is suspected to be spam, does not refer to any plausible customer request, and is not a plausible follow-up to previous requests.
    - `Others`: the email cannot be related to any of the previous labels.

    Here is the contextual information.

    The body of the email sent by the customer:
    <email_body>
    {email_body}
    </email_body>

    Now classify the email, explain your reasoning and recommend next steps.
"""

## Classification definition
def classify_email(state: AgentState) -> AgentState:
    """Classify the latest customer email in ``state`` and record the result.

    The last message in ``state["messages"]`` is treated as the email body. It is
    rendered into ``MODEL_SYSTEM_MESSAGE`` and sent to ``model`` as a system
    message, followed by the conversation so far, with the output constrained to
    the ``LabelOutput`` schema.

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        A state update with one new ``AIMessage`` summarising the classification
        (appended to the history by the ``add_messages`` reducer) plus the
        ``label``, ``rationale`` and ``next_steps`` fields.
    """
    # The most recent message holds the email to classify.
    email_body = state["messages"][-1].content

    # Send the rendered instructions with the system role; a bare string in the
    # message list would be converted into an extra human message instead.
    system_prompt = SystemMessage(
        content=MODEL_SYSTEM_MESSAGE.format(email_body=email_body)
    )

    structured_model = model.with_structured_output(LabelOutput)
    # The conversation still follows the system prompt so the request ends with a
    # user turn (presumably required by some providers - keep it).
    result: LabelOutput = structured_model.invoke([system_prompt, *state["messages"]])

    summary = (
        f"Label: {result.label}\n"
        f"Rationale: {result.rationale}\n"
        f"Next steps: {result.next_steps}"
    )
    return {
        # add_messages appends new messages, so return only the new one rather
        # than re-sending the whole history.
        "messages": [AIMessage(content=summary)],
        "label": result.label,
        "rationale": result.rationale,
        "next_steps": result.next_steps,
    }

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition  # Check if the model's output is a tool call

# Build a one-node graph over AgentState: START -> "classification" -> END.
builder = StateGraph(AgentState)
builder.add_node("classification", classify_email)
builder.add_edge(START, "classification")
builder.add_edge("classification", END)

# Compile with an in-memory checkpointer so state is kept per `thread_id`
# (supplied in the invoke config).
memory = MemorySaver()
basic_email_classifier = builder.compile(checkpointer=memory)

# Render the compiled graph (xray=True expands nested subgraphs, if any) as a Mermaid PNG.
display(
    Image(basic_email_classifier.get_graph(xray=True).draw_mermaid_png())
)

In [0]:
## Testing
# Explicit import: `from pyspark.sql.functions import *` shadows builtins such as
# sum, max, min, round and filter for the rest of the notebook session.
from pyspark.sql.functions import col
from langchain_core.messages import HumanMessage

# Pick the test case from the example dataframe.
id_ = 4
# One checkpointer thread per example: MemorySaver keeps history per thread_id and
# classify_email sends that history to the model, so reusing a single id across
# examples would feed earlier emails back into the classification.
config_ = {"configurable": {"thread_id": str(id_)}}
example_row = df_example.filter(col("id") == id_).first()
if example_row is None:
    raise ValueError(f"No example with id={id_} in df_example")
message_ = example_row.email_body

# Invoke the agent with the email as a human message.
request = [HumanMessage(content=message_)]
messages = basic_email_classifier.invoke({"messages": request}, config_)

# `messages` is the final graph state; print the conversation it holds.
for m in messages["messages"]:
    m.pretty_print()

## ReAct Version - 2
This version evolves the previous architecture: the classification step becomes a tool, executed by the tool node of a ReAct agent.

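**_Note_**: with the OpenAI `databricks-gpt-5-1` endpoint, example 2 makes the agent call the classification tool over and over in an endless loop (in practice until LangGraph's recursion limit aborts the run), whereas it works fine with `databricks-claude-sonnet-4-5`.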

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Graph State
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages, AnyMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

## Set up the agent state to concatenate messages in memory
class AgentState(TypedDict):
    """Graph state shared by the assistant and tool nodes of the ReAct graph."""

    # Conversation history; the add_messages reducer merges each node's returned
    # messages into this list (by message id) instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # NOTE(review): in this ReAct section nothing visible writes the three keys
    # below: `assistant` returns only `messages`, and the classifier tool's dict
    # result reaches the graph through ToolNode as a message. They stay unset
    # unless populated elsewhere.
    label: str
    rationale: str
    next_steps: str

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Output Schema
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

## Set up the output schema for the classification
from pydantic import BaseModel, Field
from typing import Literal

# Structured-output schema the classifier LLM must fill in.
# NOTE: intentionally no class docstring - LangChain derives the schema description
# sent to the model from it, so adding one would change the request.
class LabelOutput(BaseModel):
    # One of the categories described in the prompt; "Others" if the model omits it.
    label: Literal[
        "Claim",
        "Profile Update",
        "Communication Unsubscription",
        "General Enquiry",
        "Spam",
        "Others",
    ] = Field(default="Others", description="Label assigned to the email")
    # Short justification of the chosen label.
    rationale: str = Field(description="Reasoning behind the label choice")
    # Suggested follow-up actions for the support team.
    next_steps: str = Field(description="Recommended action items based on the customer's email")

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Base LLM
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from databricks_langchain import ChatDatabricks
# Chat model bound to the serving endpoint selected in `llm_endpoint`;
# temperature 0 keeps sampling randomness to a minimum for repeatable labels.
model = ChatDatabricks(
    temperature=0,
    endpoint=llm_endpoint,
)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Classification Tool
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langchain_core.tools import tool
from langgraph.graph import MessagesState

# Classification prompt template. The single `{email_body}` placeholder is filled
# with str.format(), so any literal braces added to this text must be doubled.
MODEL_SYSTEM_MESSAGE = """
    You are an email dispatcher assistant.
    Based on some contextual information, your tasks are:
        (1) assign a label to the customer's email.
        (2) provide a summarized explanation of why you assigned that label.
        (3) recommend next steps, based on the context.

    Here are the available labels you can choose from, together with a description of when to use them:
    - `Claim`: the customer is formally filing a complaint or requesting a refund.
    - `Profile Update`: the customer is requesting to update or modify their profile information (such as address, ...).
    - `Communication Unsubscription`: the customer is requesting to stop receiving communications or newsletters about products, offers, etc.
    - `General Enquiry`: the customer is asking for general information about their orders, the company, shipping policies, return policies, reimbursement terms, etc.
    - `Spam`: the email is suspected to be spam, does not refer to any plausible customer request, and is not a plausible follow-up to previous requests.
    - `Others`: the email cannot be related to any of the previous labels.

    Here is the contextual information.

    The body of the email sent by the customer:
    <email_body>
    {email_body}
    </email_body>

    Now classify the email, explain your reasoning and recommend next steps.
"""

## Classification definition
# Exposed to the ReAct agent as a tool: LangChain builds the tool schema from this
# function's signature and docstring, so the docstring is the description the model
# sees and is intentionally left unchanged.
# Returns a plain dict (label, rationale, next_steps, tool_call_status), not a
# LabelOutput instance; ToolNode turns it into the ToolMessage content the
# assistant reads. "tool_call_status" signals that classification is done.
def classify_email(email_body: str) -> dict[str, str]:
    """  Classify a customer email into one of the predefined labels, provide reasoning and next steps.  """

    # Render the prompt template with the email text.
    full_prompt = MODEL_SYSTEM_MESSAGE.format(email_body=email_body)

    # The rendered prompt is the only message, sent explicitly as a user turn
    # (a bare string in the list was converted to this same HumanMessage).
    structured_model = model.with_structured_output(LabelOutput)
    result: LabelOutput = structured_model.invoke([HumanMessage(content=full_prompt)])

    return {
        "label": result.label,
        "rationale": result.rationale,
        "next_steps": result.next_steps,
        "tool_call_status": "Complete!",
    }

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Create Tool Binding
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# The same list is given to the model (so it can emit tool calls matching the
# function's signature/docstring) and to ToolNode in the graph cell (which executes them).
tools = [classify_email]
model_with_tools = model.bind_tools(tools)

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# Reasoning Node (Assistant)
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, SystemMessage

# System prompt for the ReAct assistant node: classify via the bound tool, and stop
# (or ask for more information) instead of calling the tool repeatedly.
# NOTE(review): each line ends with an explicit "\n" followed by a real newline, so
# the model receives blank lines and leading indentation; harmless but could be tidied.
sys_msg = SystemMessage(content="""
    You are a customer support assistant, tasked with classifying incoming emails and suggesting next actions to improve customer relationship.\n
    Use the available tools to perform the classification.\n
    If you feel unsure about the classification, DO NOT keep using the same tools but suggest asking for more information.\n
    Once you've obtained a valid label for the email STOP.""")

# Reasoning Node
def assistant(state: AgentState):
    """Run one ReAct reasoning step.

    Prepends the system prompt to the conversation and lets the tool-enabled model
    either request a tool call or answer directly; the reply is returned as a
    one-message update for the ``messages`` channel.
    """
    conversation = [sys_msg, *state["messages"]]
    ai_reply = model_with_tools.invoke(conversation)
    return {"messages": [ai_reply]}

In [0]:
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 
# ReAct Graph
#### #### #### #### #### #### #### #### #### #### #### #### #### #### #### #### 

from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition # Check if the model's output is a tool call
from langgraph.prebuilt import ToolNode
from IPython.display import Image, display

# ReAct loop over AgentState: START -> assistant -> (tools -> assistant)* -> END.
builder = StateGraph(AgentState)

# Nodes: the reasoning LLM and the executor for the tool calls it emits.
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Edges: always start at the assistant. After each assistant turn, tools_condition
# routes a tool-call reply to "tools" and anything else to END; tool results
# always flow back to the assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# Compiled without a checkpointer: each invocation starts from its input state only.
react_email_classifier = builder.compile()

# Display the Graph
display(
    Image(react_email_classifier.get_graph().draw_mermaid_png())
)

In [0]:
## Testing
# Explicit import: `from pyspark.sql.functions import *` shadows builtins such as
# sum, max, min, round and filter for the rest of the notebook session.
from pyspark.sql.functions import col
from langchain_core.messages import HumanMessage
import mlflow

# Extract the test case from the example dataframe.
id_ = 4
config_ = {"configurable": {"thread_id": str(id_)}}
example_row = df_example.filter(col("id") == id_).first()
if example_row is None:
    raise ValueError(f"No example with id={id_} in df_example")
message_ = example_row.email_body

# Enable MLflow LangChain autologging so the graph invocation is traced, then run
# the agent inside a named MLflow run.
mlflow.langchain.autolog()
with mlflow.start_run(run_name="React Email Classifier Test"):
    request = [HumanMessage(content=message_)]
    # react_email_classifier was compiled without a checkpointer, so the thread_id
    # in config_ does not persist any state between runs.
    messages = react_email_classifier.invoke({"messages": request}, config_)

    # `messages` is the final graph state; print the conversation it holds.
    for m in messages["messages"]:
        m.pretty_print()