In [1]:
import uuid
import os
import sys
from typing import Optional, List

from pydantic import Field
from dotenv import load_dotenv
from rich.pretty import pprint
from datetime import datetime
from zoneinfo import ZoneInfo

from zep_cloud.client import AsyncZep
from zep_cloud.types import Message
from zep_cloud.extractor import (
    ZepModel,
    ZepText,
    ZepNumber,
    ZepFloat,
    ZepRegex,
    ZepZipCode,
    ZepDateTime,
    ZepEmail,
    ZepPhoneNumber,
)

In [2]:
# Make the shared example-data directory importable.
# The location is resolved against the current working directory once and
# stored as an absolute path, so the sys.path entry does not depend on the
# working directory staying the same. The membership check keeps re-runs of
# this cell from appending the same entry again.
data_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "data"))
if data_dir not in sys.path:
    sys.path.append(data_dir)

from chat_history_shoe_purchase import history

In [3]:
# Place your Zep API key in your environment or in a .env file
load_dotenv()

API_KEY = os.getenv("ZEP_API_KEY")

# Fail fast when the key is missing *or empty* (e.g. `ZEP_API_KEY=` in .env).
# An empty string would pass an `is not None` check and presumably only
# surface later as a failed API call.
assert API_KEY, "ZEP_API_KEY is not set"

In [4]:
client = AsyncZep(
    api_key=API_KEY,
)
user_id = uuid.uuid4().hex  # unique user id. can be any alphanum string

await client.user.add(
    user_id=user_id,
    email="user@example.com",
    first_name="Jane",
    last_name="Smith",
    metadata={"vip": "true"},
)

session_id = uuid.uuid4().hex  # unique session id. can be any alphanum string

# Create session associated with the above user
await client.memory.add_session(
    session_id=session_id, user_id=user_id, metadata={"foo": "bar"}
)

# Add Memory for session
for m in history[:-1]:
    await client.memory.add(session_id=session_id, messages=[Message(**m)])

In [5]:
# Check that the messages were added
msgs = await client.memory.get_session_messages(session_id)

pprint([(msg.role, msg.content) for msg in msgs.messages])

In [6]:
# Size of the rolling window: the number of most recent messages that each
# extraction call considers.
LAST_N = 10


# Schema of the data to extract from the conversation.
# Every field is optional and defaults to None; the `description` of each
# field states what it is meant to hold. The two ZepRegex fields additionally
# carry the `pattern` their value is expected to follow.
class SalesOrder(ZepModel):
    # --- order details ---
    order_id: Optional[ZepNumber] = Field(default=None, description="The order id")
    order_items: Optional[ZepText] = Field(default=None, description="The item ordered")
    delivery_date: Optional[ZepDateTime] = Field(
        default=None,
        description="The date and time the order will be delivered",
    )
    order_amount: Optional[ZepFloat] = Field(
        default=None,
        description="The ordered item's price",
    )

    # --- customer preferences ---
    brand_preferences: Optional[ZepRegex] = Field(
        default=None,
        description="The customer's preferred brands. Comma-separated list of brands",
        pattern=r"\w+(, \w+)+",
    )
    shoe_size: Optional[ZepNumber] = Field(
        default=None,
        description="The customer's shoe size",
    )

    # --- customer contact details ---
    customer_first_name: Optional[ZepText] = Field(
        default=None,
        description="The customer's first name. Capitalize appropriately.",
    )
    customer_last_name: Optional[ZepText] = Field(
        default=None,
        description="The customer's last name. Capitalize appropriately.",
    )
    customer_email: Optional[ZepEmail] = Field(
        default=None,
        description="The customer's email",
    )
    customer_phone: Optional[ZepPhoneNumber] = Field(
        default=None,
        description="The customer's phone number",
    )

    # --- delivery address ---
    street_number: Optional[ZepNumber] = Field(
        default=None,
        description="The delivery street number",
    )
    street_name: Optional[ZepText] = Field(
        default=None,
        description="The delivery street name",
    )
    zip_code: Optional[ZepZipCode] = Field(
        default=None,
        description="The delivery zip code",
    )

    # --- currency, restricted to the alternatives listed in the pattern ---
    order_currency: Optional[ZepRegex] = Field(
        default=None,
        description="The order currency: USD, GBP, or UNKNOWN",
        pattern=r"(UNKNOWN|USD|GBP)",
    )

In [7]:
# NOTE: the annotations are quoted forward references, so defining this helper
# does not require SalesOrder to exist yet at definition time.
def merge_models(
    existing_order: "SalesOrder",
    new_order: "SalesOrder",
    mutable_fields: Optional[List[str]] = None,
) -> "SalesOrder":
    """Merge two extracted models into a new model instance.

    Only values that are set, not None and not equal to the field default take
    part in the merge (both models are dumped with ``exclude_unset``,
    ``exclude_none`` and ``exclude_defaults``).

    Merge rules:
      * A field that already has a value in ``existing_order`` keeps that
        value, i.e. existing fields are immutable by default.
      * A field listed in ``mutable_fields`` is overwritten by the value from
        ``new_order`` when ``new_order`` provides one.
      * A field that has no value in ``existing_order`` is filled in from
        ``new_order``.

    Args:
        existing_order: The model holding the facts collected so far.
        new_order: The model holding the most recently extracted facts.
        mutable_fields: Names of fields whose existing value may be replaced.
            ``None`` (the default) means no field may be replaced.

    Returns:
        A new instance of the same class as ``existing_order``; neither input
        is modified.
    """
    overwritable = set(mutable_fields or ())

    existing_data = existing_order.model_dump(
        exclude_unset=True, exclude_none=True, exclude_defaults=True
    )
    new_data = new_order.model_dump(
        exclude_unset=True, exclude_none=True, exclude_defaults=True
    )

    # Start from the existing values and only let new values in where the
    # field is still empty or has been explicitly declared mutable.
    merged = dict(existing_data)
    for field, value in new_data.items():
        if field not in existing_data or field in overwritable:
            merged[field] = value

    # Build the result from the class of the existing model instead of a
    # hard-coded one, so the helper works for any model class.
    return type(existing_order)(**merged)

In [8]:
# Extract the sales order information from the session memory
sales_order = await client.memory.extract(
    session_id,  # The unique identifier for the user's session
    SalesOrder,  # The model class to use for extracting the data
    last_n=LAST_N,  # The number of most recent messages to consider for extraction
    current_date_time=datetime.now(
        ZoneInfo("America/New_York")
    ),  # The current date and time in the specified timezone, will be used to resolve relative dates.
)
pprint(sales_order)

In [9]:
# Extract AND validate the sales order information. Doing so adds a validation step evaluating whether the extracted data matches provided dialog context.
# The validation step adds additional latency and may result in false negatives i.e. data present in the dialog context may not be extracted.
sales_order_validated = await client.memory.extract(
    session_id,  # The unique identifier for the user's session
    SalesOrder,  # The model class to use for extracting the data
    last_n=LAST_N,  # The number of most recent messages to consider for extraction
    current_date_time=datetime.now(
        ZoneInfo("America/New_York")
    ),  # The current date and time in the specified timezone, will be used to resolve relative dates.
    validate=True,
)
pprint(sales_order_validated)

In [10]:
# Add an additional message with the user's email address. We'll also infer the first name.
await client.memory.add(
    session_id=session_id,
    messages=[
        Message(role_type="user", content="My email address is mary@example.com")
    ],
)

sales_order_new = await client.memory.extract(
    session_id,
    SalesOrder,
    last_n=LAST_N,
    current_date_time=datetime.now(ZoneInfo("America/New_York")),
)

# Since we're using a rolling message window to extract facts, we need to merge the new facts with the existing ones.
sales_order = merge_models(sales_order, sales_order_new)

pprint(sales_order)

In [11]:
# Add an additional message with the user's street address and zip code.
await client.memory.add(
    session_id=session_id,
    messages=[
        Message(
            role_type="assistant",
            content="Great, what is your street address and zip code?",
        ),
        Message(role_type="user", content="987 Main St, zip code is 12345"),
    ],
)

sales_order_new = await client.memory.extract(
    session_id,
    SalesOrder,
    last_n=LAST_N,
    current_date_time=datetime.now(ZoneInfo("America/New_York")),
)

sales_order = merge_models(sales_order, sales_order_new)

pprint(sales_order)

In [12]:
# Add an additional message with a relative date. The relative date will be converted to an absolute date.
await client.memory.add(
    session_id=session_id,
    messages=[
        Message(
            role_type="assistant",
            content="Fantastic! We'll deliver your order tomorrow at 1pm!",
        ),
    ],
)

sales_order_new = await client.memory.extract(
    session_id,
    SalesOrder,
    last_n=LAST_N,
    current_date_time=datetime.now(ZoneInfo("America/New_York")),
)

sales_order = merge_models(sales_order, sales_order_new)

pprint(sales_order)