diff --git a/chatkit/_compat.py b/chatkit/_compat.py new file mode 100644 index 0000000..5123c9c --- /dev/null +++ b/chatkit/_compat.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +try: + from enum import StrEnum +except ImportError: # Python < 3.11 + from enum import Enum + + class StrEnum(str, Enum): + """Minimal StrEnum compatibility for Python < 3.11.""" + + +try: + from typing import assert_never +except ImportError: # Python < 3.11 + from typing import NoReturn + + def assert_never(arg: NoReturn) -> NoReturn: + raise AssertionError(f"Expected code path to be unreachable, got: {arg!r}") + + +__all__ = ("StrEnum", "assert_never") diff --git a/chatkit/agents.py b/chatkit/agents.py index 349904a..23af5ae 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -10,7 +10,6 @@ Generic, Sequence, TypeVar, - assert_never, cast, ) @@ -39,6 +38,7 @@ ) from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter +from ._compat import assert_never from .server import stream_widget from .store import Store, StoreItemType from .types import ( diff --git a/chatkit/errors.py b/chatkit/errors.py index f5f1d23..76886b5 100644 --- a/chatkit/errors.py +++ b/chatkit/errors.py @@ -1,5 +1,6 @@ from abc import ABC -from enum import StrEnum + +from ._compat import StrEnum # Not a closed enum, new error codes can and will be added as needed diff --git a/chatkit/server.py b/chatkit/server.py index 5963ff8..616d7e3 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -3,14 +3,7 @@ from collections.abc import AsyncIterator from contextlib import contextmanager from datetime import datetime -from typing import ( - Any, - AsyncGenerator, - AsyncIterable, - Callable, - Generic, - assert_never, -) +from typing import Any, AsyncGenerator, AsyncIterable, Callable, Generic import agents from agents.models.chatcmpl_helpers import ( @@ -24,6 +17,7 @@ from chatkit.errors import CustomStreamError, StreamError +from ._compat import assert_never from .logger import logger from .store import AttachmentStore, Store, StoreItemType, default_generate_id from .types import ( diff --git a/tests/helpers/mock_widget.py b/tests/helpers/mock_widget.py index c6ef892..41a071a 100644 --- a/tests/helpers/mock_widget.py +++ b/tests/helpers/mock_widget.py @@ -2,12 +2,13 @@ import re import uuid from datetime import datetime, timedelta -from typing import Annotated, Any, AsyncIterator, Callable, Literal, assert_never +from typing import Annotated, Any, AsyncIterator, Callable, Literal from agents import Agent, Runner from anyio import sleep from pydantic import BaseModel, Field, TypeAdapter +from chatkit._compat import assert_never from chatkit.actions import Action, ActionConfig from chatkit.types import ThreadStreamEvent from chatkit.widgets import (