Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 2 additions & 27 deletions prompting/rewards/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,17 @@
from typing import Optional

import numpy as np
import requests
from pydantic import ConfigDict
from scipy import spatial

from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput
from shared import constants, settings
from shared import settings
from shared.dendrite import DendriteResponseEvent
from shared.docker_utils import get_embeddings

shared_settings = settings.shared_settings


def get_embeddings(inputs):
"""
Sends a POST request to the local embeddings endpoint and returns the response.

Args:
inputs (str or list of str): A single input string or a list of input strings to embed.

Returns:
dict: JSON response from the embeddings server.
"""
if isinstance(inputs, str):
inputs = [inputs] # convert single string to list

url = f"{constants.DOCKER_BASE_URL}/v1/embeddings"
headers = {"Content-Type": "application/json"}
payload = {"input": inputs}

try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
return {"error": str(e)}


class RelevanceRewardModel(BaseRewardModel):
threshold: Optional[float] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
Expand Down
29 changes: 1 addition & 28 deletions prompting/rewards/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import pandas as pd
import requests
import whois
from loguru import logger
from pydantic import BaseModel
Expand All @@ -19,36 +18,10 @@
from prompting.rewards.relevance import RelevanceRewardModel
from prompting.rewards.reward import BatchRewardOutput
from prompting.tasks.base_task import BaseTextTask
from shared import constants
from shared.dendrite import DendriteResponseEvent
from shared.docker_utils import get_embeddings
from shared.misc import async_lru_cache


def get_embeddings(inputs):
"""
Sends a POST request to the local embeddings endpoint and returns the response.

Args:
inputs (str or list of str): A single input string or a list of input strings to embed.

Returns:
dict: JSON response from the embeddings server.
"""
if isinstance(inputs, str):
inputs = [inputs] # convert single string to list

url = f"{constants.DOCKER_BASE_URL}/v1/embeddings"
headers = {"Content-Type": "application/json"}
payload = {"input": inputs}

try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
return {"error": str(e)}


MIN_RELEVANT_CHARS = 300
MIN_MATCH_THRESHOLD = 98

Expand Down
29 changes: 29 additions & 0 deletions shared/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,32 @@ async def get_logits(
except requests.exceptions.JSONDecodeError:
logger.error(f"Error generating logits. Status: {response.status_code}, Body: {response.text}")
return ""


def get_embeddings(inputs):
"""
Sends a POST request to the local embeddings endpoint and returns the response.

Args:
inputs (str or list of str): A single input string or a list of input strings to embed.

Returns:
dict: JSON response from the embeddings server.
"""
if isinstance(inputs, str):
inputs = [inputs] # convert single string to list

url = f"{constants.DOCKER_BASE_URL}/v1/embeddings"
headers = {"Content-Type": "application/json"}
payload = {"input": inputs}

try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
return {"error": str(e)}


if __name__ == "__main__":
print(get_embeddings("Hello, world!"))
Loading