In [None]:
import requests

from typing import List
from caikit.core.data_model.json_dict import JsonDict

In [None]:
# Identifier of the model deployed on the caikit inference server; sent as
# "model_id" in every request body.
model_id: str = "bge-en-large-caikit"
# Base URL of the inference server; the task paths (e.g. "/api/v1/task/embedding")
# are appended to it.
# NOTE(review): replace the `<cluster>` placeholder with the real cluster domain,
# and confirm the host — it names a flan-t5 predictor while model_id is a bge model.
infer_endpoint: str = (
    "https://flan-t5-small-caikit-predictor-userx-workshop.apps.<cluster>.com"
)

In [None]:
def embeddingTask(text: str, *, verify=False, timeout=None):
    """
    Request the embedding of a single text from the embedding endpoint.

    Args:
        text: The input text, sent as ``inputs``.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior; pass ``True`` or a CA-bundle path to enable it.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely; pass a number of seconds to avoid hanging forever.

    Returns:
        ``result.data.values`` from the JSON response on HTTP 200, otherwise
        ``None`` (after printing the response body).
    """
    response = requests.post(
        # rstrip avoids a "//api" path if the endpoint was configured with a trailing slash.
        infer_endpoint.rstrip("/") + "/api/v1/task/embedding",
        json={
            "model_id": model_id,
            "inputs": text,
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["result"]["data"]["values"]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None


def embeddingTasks(texts: List[str], *, first_only=True, verify=False, timeout=None):
    """
    Request embeddings for a list of texts from the batch embedding endpoint.

    Args:
        texts: The input texts, sent as ``inputs``.
        first_only: When ``True`` (default, the historical behavior) only the
            first vector's values are returned. When ``False``, a list with
            the values of every returned vector is returned, in response order.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely.

    Returns:
        On HTTP 200, ``results.vectors[0].data.values`` (``first_only=True``)
        or ``[v.data.values for v in results.vectors]`` (``first_only=False``).
        Otherwise ``None`` (after printing the response body).

    Raises:
        IndexError: if ``first_only`` is ``True`` and the response contains
            no vectors (e.g. ``texts`` was empty).
    """
    response = requests.post(
        infer_endpoint.rstrip("/") + "/api/v1/task/embedding-tasks",
        json={
            "model_id": model_id,
            "inputs": texts,
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        vectors = response.json()["results"]["vectors"]
        if first_only:
            # Backward-compatible default: previous versions dropped all but
            # the first embedding.
            return vectors[0]["data"]["values"]
        return [vector["data"]["values"] for vector in vectors]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None

def sentenceSimilarityTask(source_sentence: str, sentences: List[str], *, verify=False, timeout=None):
    """
    Score one source sentence against a list of reference sentences.

    Args:
        source_sentence: Sent as ``inputs.source_sentence``.
        sentences: Sent as ``inputs.sentences``.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely.

    Returns:
        ``result.scores`` from the JSON response on HTTP 200, otherwise
        ``None`` (after printing the response body).
    """
    response = requests.post(
        infer_endpoint.rstrip("/") + "/api/v1/task/sentence-similarity",
        json={
            "model_id": model_id,
            "inputs": {
                "source_sentence": source_sentence,
                "sentences": sentences,
            },
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["result"]["scores"]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None

def sentenceSimilarityTasks(source_sentences: List[str], sentences: List[str], *, verify=False, timeout=None):
    """
    Score several source sentences against a list of reference sentences.

    Args:
        source_sentences: Sent as ``inputs.source_sentences``.
        sentences: Sent as ``inputs.sentences``.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely.

    Returns:
        The ``results`` field of the JSON response on HTTP 200, otherwise
        ``None`` (after printing the response body).
    """
    response = requests.post(
        infer_endpoint.rstrip("/") + "/api/v1/task/sentence-similarity-tasks",
        json={
            "model_id": model_id,
            "inputs": {
                "source_sentences": source_sentences,
                "sentences": sentences,
            },
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["results"]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None

def rerankTask(query: str, documents: List[JsonDict], top_n=None, *, verify=False, timeout=None):
    """
    Rerank reference documents against a single query.

    Args:
        query: Sent as ``inputs.query``.
        documents: Sent unchanged as ``inputs.documents``.
        top_n: Sent as ``parameters.top_n``. When ``None``, ``len(documents)``
            is sent instead.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely.

    Returns:
        The ``result`` field of the JSON response on HTTP 200, otherwise
        ``None`` (after printing the response body).
    """
    if top_n is None:
        top_n = len(documents)
    response = requests.post(
        infer_endpoint.rstrip("/") + "/api/v1/task/rerank",
        json={
            "model_id": model_id,
            "inputs": {
                "documents": documents,
                "query": query,
            },
            "parameters": {
                "top_n": top_n,
            },
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["result"]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None

def rerankTasks(queries: List[str], documents: List[JsonDict], top_n=None, *, verify=False, timeout=None):
    """
    Rerank reference documents against several queries.

    Args:
        queries: Sent as ``inputs.queries``.
        documents: Sent unchanged as ``inputs.documents``.
        top_n: Sent as ``parameters.top_n``. When ``None``, ``len(documents)``
            is sent instead.
        verify: Forwarded to ``requests.post``. Defaults to ``False`` (TLS
            certificate verification disabled), matching the previous
            hard-coded behavior.
        timeout: Forwarded to ``requests.post``. ``None`` (default) waits
            indefinitely.

    Returns:
        The ``results`` field of the JSON response on HTTP 200, otherwise
        ``None`` (after printing the response body).
    """
    if top_n is None:
        top_n = len(documents)
    response = requests.post(
        infer_endpoint.rstrip("/") + "/api/v1/task/rerank-tasks",
        json={
            "model_id": model_id,
            "inputs": {
                "documents": documents,
                "queries": queries,
            },
            "parameters": {
                "top_n": top_n,
            },
        },
        verify=verify,
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["results"]
    # Non-200: surface the server's error body and signal failure explicitly.
    print(response.text)
    return None