Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 68 additions & 49 deletions src/zep_cloud/external_clients/memory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import datetime
import json
import typing
from packaging import version

import pydantic

from zep_cloud.core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from zep_cloud.extractor.models import ZepModel
from zep_cloud.memory.client import (
AsyncMemoryClient as AsyncBaseMemoryClient,
)
from zep_cloud.memory.client import (
MemoryClient as BaseMemoryClient,
)

if typing.TYPE_CHECKING:
from zep_cloud.extractor.models import ZepModel

MIN_PYDANTIC_VERSION = "2.0"


class MemoryClient(BaseMemoryClient):
def __init__(self, *, client_wrapper: SyncClientWrapper):
Expand All @@ -19,7 +26,7 @@ def __init__(self, *, client_wrapper: SyncClientWrapper):
def extract(
self,
session_id: str,
model: ZepModel,
model: "ZepModel",
current_date_time: typing.Optional[datetime.datetime] = None,
last_n: int = 4,
validate: bool = False,
Expand Down Expand Up @@ -65,16 +72,22 @@ class CustomerInfo(ZepModel):

print(customer_data.name) # Access extracted and validated customer name
"""

if version.parse(pydantic.VERSION) < version.parse(MIN_PYDANTIC_VERSION):
raise RuntimeError(
f"Pydantic version {MIN_PYDANTIC_VERSION} or greater is required."
)

model_schema = json.dumps(model.model_json_schema())

result = self.extract_data(
session_id=session_id,
model_schema=model_schema,
validate=validate,
last_n=last_n,
current_date_time=current_date_time.isoformat()
if current_date_time
else None,
current_date_time=(
current_date_time.isoformat() if current_date_time else None
),
)

return model.model_validate(result)
Expand All @@ -87,62 +100,68 @@ def __init__(self, *, client_wrapper: AsyncClientWrapper):
async def extract(
self,
session_id: str,
model: ZepModel,
model: "ZepModel",
current_date_time: typing.Optional[datetime.datetime] = None,
last_n: int = 4,
validate: bool = False,
):
"""Extracts structured data from a session based on a ZepModel schema.
This method retrieves data based on a given model and session details.
It then returns the extracted and validated data as an instance of the given ZepModel.

Parameters
----------
session_id: str
Session ID.
model: ZepModel
An instance of a ZepModel subclass defining the expected data structure and field types.
current_date_time: typing.Optional[datetime.datetime]
Your current date and time in ISO 8601 format including timezone.
This is used for determining relative dates.
last_n: typing.Optional[int]
The number of messages in the chat history from which to extract data.
validate: typing.Optional[bool]
Validate that the extracted data is present in the dialog and correct per the field description.
Mitigates hallucination, but is slower and may result in false negatives.

Returns
-------
ZepModel: An instance of the provided ZepModel subclass populated with the
extracted and validated data.

Examples
--------
class CustomerInfo(ZepModel):
name: Optional[ZepText] = Field(description="Customer name", default=None)
name: Optional[ZepEmail] = Field(description="Customer email", default=None)
signup_date: Optional[ZepDate] = Field(description="Customer Sign up date", default=None)

client = AsyncMemoryClient(...)

customer_data = await client.memory.extract(
session_id="session123",
model=CustomerInfo(),
current_date_time=datetime.datetime.now(), # Filter data up to now
)

print(customer_data.name) # Access extracted and validated customer name
"""
This method retrieves data based on a given model and session details.
It then returns the extracted and validated data as an instance of the given ZepModel.

Parameters
----------
session_id: str
Session ID.
model: ZepModel
An instance of a ZepModel subclass defining the expected data structure and field types.
current_date_time: typing.Optional[datetime.datetime]
Your current date and time in ISO 8601 format including timezone.
This is used for determining relative dates.
last_n: typing.Optional[int]
The number of messages in the chat history from which to extract data.
validate: typing.Optional[bool]
Validate that the extracted data is present in the dialog and correct per the field description.
Mitigates hallucination, but is slower and may result in false negatives.

Returns
-------
ZepModel: An instance of the provided ZepModel subclass populated with the
extracted and validated data.

Examples
--------
class CustomerInfo(ZepModel):
name: Optional[ZepText] = Field(description="Customer name", default=None)
name: Optional[ZepEmail] = Field(description="Customer email", default=None)
signup_date: Optional[ZepDate] = Field(description="Customer Sign up date", default=None)

client = AsyncMemoryClient(...)

customer_data = await client.memory.extract(
session_id="session123",
model=CustomerInfo(),
current_date_time=datetime.datetime.now(), # Filter data up to now
)

print(customer_data.name) # Access extracted and validated customer name
"""

if version.parse(pydantic.VERSION) < version.parse(MIN_PYDANTIC_VERSION):
raise RuntimeError(
f"Pydantic version {MIN_PYDANTIC_VERSION} or greater is required."
)

model_schema = json.dumps(model.model_json_schema())

result = await self.extract_data(
session_id=session_id,
model_schema=model_schema,
validate=validate,
last_n=last_n,
current_date_time=current_date_time.isoformat()
if current_date_time
else None,
current_date_time=(
current_date_time.isoformat() if current_date_time else None
),
)

return model.model_validate(result)
49 changes: 37 additions & 12 deletions src/zep_cloud/extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
from zep_cloud.extractor.models import (
ZepModel,
ZepText,
ZepNumber,
ZepFloat,
ZepRegex,
ZepZipCode,
ZepDate,
ZepDateTime,
ZepEmail,
ZepPhoneNumber,
)
# mypy: disable-error-code=no-redef

from typing import Type

# Zep Extraction requires Pydantic v2. If v2 is not installed, catch the error
# and set the variables to PydanticV2Required


class PydanticV2Required:
def __init__(self, *args, **kwargs):
raise RuntimeError("Pydantic v2 is required to use this class.")


try:
from zep_cloud.extractor.models import (
ZepModel,
ZepText,
ZepNumber,
ZepFloat,
ZepRegex,
ZepZipCode,
ZepDate,
ZepDateTime,
ZepEmail,
ZepPhoneNumber,
)
except ImportError:
ZepModel: Type = PydanticV2Required
ZepText: Type = PydanticV2Required
ZepNumber: Type = PydanticV2Required
ZepFloat: Type = PydanticV2Required
ZepRegex: Type = PydanticV2Required
ZepZipCode: Type = PydanticV2Required
ZepDate: Type = PydanticV2Required
ZepDateTime: Type = PydanticV2Required
ZepEmail: Type = PydanticV2Required
ZepPhoneNumber: Type = PydanticV2Required

__all__ = [
"ZepModel",
Expand Down