From 6bf700959027daeac33cf996ca2616e6e43fea28 Mon Sep 17 00:00:00 2001 From: Sebastian Pietras <01133337@pw.edu.pl> Date: Thu, 25 Aug 2022 08:14:28 +0200 Subject: [PATCH] Added score to `fit_posts` (#13) --- .../src/kilroy_module_client_py_sdk/client.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py b/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py index c7f752a..3881503 100644 --- a/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py +++ b/kilroy_module_client_py_sdk/src/kilroy_module_client_py_sdk/client.py @@ -367,16 +367,19 @@ async def generate( async def fit_posts( self, - posts: Union[AsyncIterable[Dict[str, Any]], Iterable[Dict[str, Any]]], + posts: Union[ + AsyncIterable[Tuple[Dict[str, Any], float]], + Iterable[Tuple[Dict[str, Any], float]], + ], *args, **kwargs, ) -> None: async with stream.iterate(posts).stream() as posts: async def to_requests(): - async for post in posts: + async for post, score in posts: yield FitPostsRequest( - post=RealPost(content=json.dumps(post)) + post=RealPost(content=json.dumps(post), score=score) ) await self._stub.fit_posts(to_requests(), *args, **kwargs) @@ -384,7 +387,8 @@ async def to_requests(): async def fit_scores( self, scores: Union[ - AsyncIterable[Tuple[UUID, float]], Iterable[Tuple[UUID, float]] + AsyncIterable[Tuple[UUID, float]], + Iterable[Tuple[UUID, float]], ], *args, **kwargs,