In [60]:
from zep_cloud import ZepDataType
import typing
from dotenv import load_dotenv
from pydantic import BaseModel, Field, WithJsonSchema, BeforeValidator, AfterValidator, PlainSerializer
from typing_extensions import Annotated
from typing import Optional
from zep_cloud.client import AsyncZep
import asyncio
import os

load_dotenv()

# Schema fragment for a free-text field: an optional description plus the
# Zep data-type tag. Kept as comments (not a class docstring) because pydantic
# copies a model docstring into the generated JSON schema.
class ZepBaseText(BaseModel):
    description: Optional[str] = None
    type: ZepDataType = "ZepText"


# Optional string field, defaulting to None, whose serialization-mode JSON
# schema is replaced by the ZepBaseText schema.
ZepText = Annotated[
    str | None,
    Field(default=None, zeptype="ZepText"),
    WithJsonSchema(
        json_schema=ZepBaseText.model_json_schema(),
        mode="serialization",
    ),
]

# Schema fragment for an integer field. The type tag matches the
# zeptype="ZepNumber" marker used by the ZepNumber alias below (it previously
# said "ZepFloat", a copy-paste slip from the float variant).
class ZepBaseNumber(BaseModel):
    description: typing.Optional[str] = None
    type: ZepDataType = "ZepNumber"

# Optional integer field, defaulting to None, whose serialization-mode JSON
# schema is replaced by the ZepBaseNumber schema.
ZepNumber = Annotated[
    int | None,
    Field(default=None, zeptype="ZepNumber"),
    WithJsonSchema(ZepBaseNumber.model_json_schema(), mode='serialization'),
]

# Schema fragment for a floating-point field.
class ZepBaseFloat(BaseModel):
    description: typing.Optional[str] = None
    type: ZepDataType = "ZepFloat"

# Optional float field, defaulting to None. Uses the ZepBaseFloat schema
# (previously it borrowed ZepBaseNumber's, leaving ZepBaseFloat unused).
ZepFloat = Annotated[
    float | None,
    Field(default=None, zeptype="ZepFloat"),
    WithJsonSchema(ZepBaseFloat.model_json_schema(), mode='serialization'),
]

# Base class for models whose fields are filled in from a Zep session.
# Subclasses declare fields with the ZepText / ZepNumber / ZepFloat aliases and
# call extract() to populate them. Documented with comments rather than a class
# docstring because pydantic copies a model docstring into its JSON schema.
class ZepModel(BaseModel):
    # Private (non-field) reference to the Zep client used by extract().
    _client: Optional[AsyncZep] = None

    def __init__(self, **data):
        """Create the model; an optional ``_client`` keyword supplies the Zep client."""
        # Remove the client before pydantic sees the data so it is never
        # treated as (extra) field input.
        client = data.pop('_client', None)
        super().__init__(**data)
        self._client = client

    @property
    def client(self):
        """Return the client passed at construction, or None if none was given."""
        return self._client

    def convert_value(self, key, value):
        """Coerce ``value`` according to the annotation declared for ``key``.

        ZepText fields are converted with ``str``, ZepNumber with ``int`` and
        ZepFloat with ``float``. ``None`` is returned unchanged so that a
        missing value stays ``None`` instead of becoming ``"None"`` or raising
        ``TypeError``. Keys with any other (or no) annotation are returned
        as-is.

        Raises:
            ValueError: if ``int``/``float`` cannot parse a non-None value.
        """
        if value is None:
            return None

        # Search the whole class hierarchy so fields declared on a parent
        # model are found too, and unknown keys do not raise KeyError.
        field_type = None
        for klass in type(self).__mro__:
            annotations = getattr(klass, '__annotations__', {})
            if key in annotations:
                field_type = annotations[key]
                break

        if field_type == ZepText:
            return str(value)
        elif field_type == ZepNumber:
            return int(value)
        elif field_type == ZepFloat:
            return float(value)
        else:
            return value

    async def extract(self, session_id: str, last_n_messages: int = 100):
        """Populate this model's fields from a Zep session and return ``self``.

        Sends the model's JSON schema to ``memory.extract_session_data`` on the
        stored client and assigns each returned value, after conversion, to
        the field of the same name. Keys that are not declared fields of this
        model are ignored.

        Args:
            session_id: identifier of the session to extract from.
            last_n_messages: forwarded to ``extract_session_data``.

        Raises:
            AttributeError: if the model was created without a ``_client``.
        """
        model_schema = self.schema_json()
        # NOTE(review): the result is assumed to be a mapping of field name to
        # extracted value (it is iterated with .items()) - confirm against the SDK.
        extractor_result = await self._client.memory.extract_session_data(session_id, last_n_messages=last_n_messages, model_schema=model_schema)
        declared_fields = type(self).model_fields
        for key, value in extractor_result.items():
            # Only declared fields are assigned; a bare hasattr() check would
            # also match properties and methods such as ``client``/``extract``.
            if key not in declared_fields:
                continue
            converted_value = self.convert_value(key, value)
            print(f'Type of {key}: {type(converted_value)}')
            setattr(self, key, converted_value)

        return self

# API key read from the environment; None when ZEP_API_KEY is not set.
API_KEY = os.getenv("ZEP_API_KEY")

# Async Zep client pointed at a locally running API.
client = AsyncZep(base_url="http://localhost:8000/api/v2", api_key=API_KEY)

# Example extraction model: three shoe-related fields, each described by the
# question passed as the field description. (Comment rather than docstring so
# the generated JSON schema is unchanged.)
class ShoeInfoModel(ZepModel):
    shoes: ZepText = Field(
        description="what activity are the shoes used for?",
    )
    shoe_size: ZepNumber = Field(
        description="what is the user's shoe size?",
    )
    price: ZepFloat = Field(
        description="How much was spent on shoes?",
    )
    
t = ShoeInfoModel(_client=client)

task = asyncio.create_task(t.extract("a2eb841bc24245fdb0901b0a87a865cd", 100))

result = await asyncio.gather(task)


print("Result:", result[0].shoe_size, type(result[0].shoe_size))
print("Result:", result[0].price,type(result[0].price))
print("Result:", result[0].shoes)


Type of price: <class 'float'>
Type of shoe_size: <class 'int'>
Type of shoes: <class 'str'>
Result: 10 <class 'int'>
Result: 129.99 <class 'float'>
Result: running
