Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions prompting/api/weight_syncing/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import time

from fastapi import APIRouter, Depends, HTTPException, Request
Expand All @@ -14,25 +15,34 @@ def get_weight_dict(request: Request):
return request.app.state.weight_dict


def get_uid_from_hotkey(hotkey: str) -> int:
    """Return the index of ``hotkey`` within ``shared_settings.METAGRAPH.hotkeys``.

    Args:
        hotkey: The hotkey address to look up.

    Returns:
        The position of ``hotkey`` in the metagraph's ``hotkeys`` sequence.

    Raises:
        ValueError: Presumably raised when ``hotkey`` is absent, assuming
            ``hotkeys`` is a plain ``list`` -- TODO confirm the container type.
    """
    # NOTE(review): .index() is a linear scan and returns the first match only.
    return shared_settings.METAGRAPH.hotkeys.index(hotkey)


async def verify_weight_signature(request: Request):
signed_by = request.headers.get("Epistula-Signed-By")
signed_for = request.headers.get("Epistula-Signed-For")
if not signed_by or not signed_for:
raise HTTPException(400, "Missing Epistula-Signed-* headers")

if signed_for != shared_settings.WALLET.hotkey.ss58_address:
logger.error("Bad Request, message is not intended for self")
raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self")
validator_hotkeys = [shared_settings.METAGRAPH.hotkeys[uid] for uid in WHITELISTED_VALIDATORS_UIDS]
if signed_by not in validator_hotkeys:
logger.error(f"Signer not the expected ss58 address: {signed_by}")
raise HTTPException(status_code=401, detail="Signer not the expected ss58 address")

now = time.time()
body = await request.body()
if body["uid"] != get_uid_from_hotkey(signed_by):
logger.error("Invalid uid")
raise HTTPException(status_code=400, detail="Invalid uid in body")
body: bytes = await request.body()
try:
payload = json.loads(body)
except json.JSONDecodeError:
raise HTTPException(400, "Invalid JSON body")

if payload.get("uid") != get_uid_from_hotkey(signed_by):
raise HTTPException(400, "Invalid uid in body")

err = verify_signature(
request.headers.get("Epistula-Request-Signature"),
body,
Expand Down
6 changes: 0 additions & 6 deletions prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,12 @@ def __init__(
self.model_id = model_id
self.sampling_params = {} if sampling_params else sampling_params

# VLLM specific initialization
# gpu_memory_utilization = 0.9 # Default high utilization since VLLM is memory efficient
self.model = LLM(
model=model_id,
# tensor_parallel_size=1, # Single GPU by default
# dtype="float16",
trust_remote_code=True,
gpu_memory_utilization=0.9,
max_model_len=8192,
)

# Store tokenizer from VLLM for consistency
self.tokenizer = self.model.get_tokenizer()

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prompting"
version = "2.19.6"
version = "2.19.7"
description = "Subnetwork 1 runs on Bittensor and is maintained by Macrocosmos. It's an effort to create decentralised AI"
authors = ["Kalei Brady, Dmytro Bobrenko, Felix Quinque, Steffen Cruz, Richard Wardle"]
readme = "README.md"
Expand Down
25 changes: 12 additions & 13 deletions shared/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ async def get_generation(
return ""


# @async_lru_cache(maxsize=1000)
async def get_logits(
messages: list[str],
model: None = None,
Expand All @@ -41,21 +40,21 @@ async def get_logits(
continue_last_message: bool = False,
top_logprobs: int = 10,
) -> dict[str, Any] | None:
url = f"{constants.DOCKER_BASE_URL}/v1/chat/generate_logits"
headers = {"Content-Type": "application/json"}
payload = {
"messages": messages,
"seed": seed,
"sampling_params": sampling_params,
"top_logprobs": top_logprobs,
"continue_last_message": continue_last_message,
}
response = requests.post(url, headers=headers, json=payload)
try:
url = f"{constants.DOCKER_BASE_URL}/v1/chat/generate_logits"
headers = {"Content-Type": "application/json"}
payload = {
"messages": messages,
"seed": seed,
"sampling_params": sampling_params,
"top_logprobs": top_logprobs,
"continue_last_message": continue_last_message,
}
response = requests.post(url, headers=headers, json=payload)
json_response = response.json()
return json_response
except requests.exceptions.JSONDecodeError:
logger.error(f"Error generating logits. Status: {response.status_code}, Body: {response.text}")
except BaseException as exc:
logger.error(f"Error generating logits: {exc}")
return None


Expand Down