In [10]:
from abc import ABC, abstractmethod
from types import TracebackType
from typing import AsyncIterator, Generic, List, Optional, Type, TypedDict, TypeVar

TSend = TypeVar('TSend', contravariant=True)
TYield = TypeVar('TYield', covariant=True)

class AsyncGenerator(ABC, AsyncIterator[TYield], Generic[TYield, TSend]):
    @abstractmethod
    def __aiter__(self) -> AsyncIterator[TYield]:
        return self
    
    @abstractmethod
    async def __anext__(self) -> TYield:  # throws: StopAsyncIteration, ...
        return await self.asend(None)

    @abstractmethod
    async def asend(
        self,
        input: Optional[TSend]
    ) -> TYield:  # throws: StopAsyncIteration, ...
        ...

    @abstractmethod
    async def athrow(
        self,
        exc_type: Type[BaseException],
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> Optional[TYield]:  # throws: exc_type, StopAsyncIteration, ...
        ...
    
    @abstractmethod
    async def aclose(
        self
    ) -> None:  # throws RuntimeError, ...
        try:
            await self.athrow(GeneratorExit)
        except (GeneratorExit, StopAsyncIteration):
            pass
        else:
            raise RuntimeError("...")

In [17]:
from datetime import datetime
from enum import Enum

from docarray import BaseDoc
from docarray.typing import NdArray
from pydantic import Field

class ChatMLRole(str, Enum):
    System = "system"
    User = "user"
    Assistant = "assistant"

class ChatMLDict(TypedDict):
    role: ChatMLRole
    content: str

class ChatMLMessage(BaseDoc):
    role: ChatMLRole
    content: str

class MessageMetadata(BaseDoc):
    timestamp: datetime = Field(default_factory=datetime.now)

class Message(ChatMLMessage):
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)

class User(Message):
    role: ChatMLRole = ChatMLRole.User

class Assistant(Message):
    role: ChatMLRole = ChatMLRole.Assistant

class System(Message):
    role: ChatMLRole = ChatMLRole.System


TSignal = TypeVar('TSignal')
class Signal(Generic[TSignal], MessageMetadata):
    content: TSignal
    needs_input: bool = False
    done: bool = False

class Start(Signal[TSignal]):
    pass

class GetInput(Signal[TSignal]):
    needs_input: bool = True

class Result(Signal[TSignal]):
    done: bool = True


OPENAI_EMBEDDING_DIMS: int = 1536
class EmbeddingMessage(Message):
    embedding: NdArray[OPENAI_EMBEDDING_DIMS] = Field(
        dims=OPENAI_EMBEDDING_DIMS,
        is_embedding=True,
    )

TypeError: Cannot create a consistent method resolution
order (MRO) for bases Generic, MessageMetadata

In [12]:
import multiprocessing

from docarray import DocList
from docarray.index.abstract import BaseDocIndex
from docarray.index.backends.weaviate import WeaviateDocumentIndex
from docarray.index.backends.weaviate import EmbeddedOptions

cpu_count: int = multiprocessing.cpu_count()

dbconfig = WeaviateDocumentIndex.DBConfig(
    embedded_options=EmbeddedOptions(
        persistence_data_path="./.turbo_chat/weaviate-embedded"
    )
)

batch_config = {
    "batch_size": 20,
    "dynamic": True,
    "timeout_retries": 3,
    "num_workers": cpu_count // 2,
}

runtime_config = WeaviateDocumentIndex.RuntimeConfig(batch_config=batch_config)

class BaseMemory(DocList[Message]):
    index: Optional[BaseDocIndex[Message]] = None

    def sorted(self) -> "BaseMemory":
        return sorted(
            self,
            key=lambda doc: doc.metadata.timestamp,
            reverse=True,
        )

    async def process(self, **kwargs) -> None:
        ...

    @abstractmethod
    def to_prompt(self, **kwargs) -> List[ChatMLDict]:
        raise NotImplementedError


class Memory(BaseMemory):
    def to_prompt(self, **kwargs) -> List[ChatMLDict]:
        return [
            ChatMLDict(role=message.role, content=message.content)
            for message in self.sorted()
        ]


class WeaviateMemory(BaseMemory):
    index: WeaviateDocumentIndex[Message] = Field(
        default_factory=lambda: WeaviateDocumentIndex[Message](db_config=dbconfig)
    )

    def __init__(self, **data):
        super().__init__(**data)
        self.index.configure(runtime_config)

    def to_prompt(self, **kwargs) -> List[ChatMLDict]:
        raise NotImplementedError

In [13]:
from functools import wraps

class TurboGenerator(AsyncGenerator):
    def __init__(self, func, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.func = func
        self._gen = func(*args, **kwargs)

    def __aiter__(self) -> AsyncIterator[TYield]:
        return self._gen
    
    async def __anext__(self) -> TYield:  # throws: StopAsyncIteration, ...
        return await self.asend(None)

    async def asend(
        self,
        input: Optional[TSend] = None,
    ) -> TYield:  # throws: StopAsyncIteration, ...
        return await self._gen.asend(input)

    async def athrow(
        self,
        exc_type: Type[BaseException],
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> Optional[TYield]:  # throws: exc_type, StopAsyncIteration, ...
        print(exc_type, exc_value, traceback)
    
    async def aclose(
        self
    ) -> None:  # throws RuntimeError, ...
        try:
            await self.athrow(GeneratorExit)
        except (GeneratorExit, StopAsyncIteration):
            pass
        else:
            raise RuntimeError("...")

class turbo:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            turbo_gen = TurboGenerator(func, *self.args, **self.kwargs)
            return turbo_gen

        return wrapper

In [20]:
@turbo()
async def turbo_generator():
    for i in range(10):
        print(i)
        yield i

gen = turbo_generator()
await gen.asend(None)

0


0