Skip to content

Commit

Permalink
Merge pull request #20 from jina-ai/feat-rerank
Browse files Browse the repository at this point in the history
  • Loading branch information
zac-li authored Feb 26, 2024
2 parents 3d09eff + 644d640 commit 33d6f30
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions jina_sagemaker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_transform_job(
job_name = transformer.latest_transform_job.name
return job_name

def _invoke_endpoint(self, texts: Union[str, List[str]]):
def embed(self, texts: Union[str, List[str]]):
if self._endpoint_name is None:
raise Exception(
"No endpoint connected. " "Run connect_to_endpoint() first."
Expand All @@ -219,14 +219,34 @@ def _invoke_endpoint(self, texts: Union[str, List[str]]):
else:
data = json.dumps({"data": [{"text": text} for text in texts]})

return self._sm_runtime_client.invoke_endpoint(
response = self._sm_runtime_client.invoke_endpoint(
EndpointName=self._endpoint_name,
ContentType="application/json",
Body=data,
)

resp = json.loads(response["Body"].read().decode())
return resp["data"]

def rerank(self, documents: List[str], query: str):
if self._endpoint_name is None:
raise Exception(
"No endpoint connected. " "Run connect_to_endpoint() first."
)

data = json.dumps(
{
"data": {"documents": [{"text": document} for document in documents]},
"query": query,
}
)

response = self._sm_runtime_client.invoke_endpoint(
EndpointName=self._endpoint_name,
ContentType="application/json",
Body=data,
)

def embed(self, texts: Union[str, List[str]]):
response = self._invoke_endpoint(texts)
resp = json.loads(response["Body"].read().decode())
return resp["data"]

Expand Down

0 comments on commit 33d6f30

Please sign in to comment.