diff --git a/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py b/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py index c7f752a..3881503 100644 --- a/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py +++ b/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py @@ -367,16 +367,19 @@ async def generate( async def fit_posts( self, - posts: Union[AsyncIterable[Dict[str, Any]], Iterable[Dict[str, Any]]], + posts: Union[ + AsyncIterable[Tuple[Dict[str, Any], float]], + Iterable[Tuple[Dict[str, Any], float]], + ], *args, **kwargs, ) -> None: async with stream.iterate(posts).stream() as posts: async def to_requests(): - async for post in posts: + async for post, score in posts: yield FitPostsRequest( - post=RealPost(content=json.dumps(post)) + post=RealPost(content=json.dumps(post), score=score) ) await self._stub.fit_posts(to_requests(), *args, **kwargs) @@ -384,7 +387,8 @@ async def to_requests(): async def fit_scores( self, scores: Union[ - AsyncIterable[Tuple[UUID, float]], Iterable[Tuple[UUID, float]] + AsyncIterable[Tuple[UUID, float]], + Iterable[Tuple[UUID, float]], ], *args, **kwargs,