Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,194 changes: 457 additions & 737 deletions poetry.lock

Large diffs are not rendered by default.

23 changes: 13 additions & 10 deletions prompting/api/scoring/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate
model = None
payload: dict[str, Any] = await request.json()
body = payload.get("body")
uid = int(payload.get("uid"))
chunks = payload.get("chunks")
timeout = payload.get("timeout", shared_settings.NEURON_TIMEOUT)
uids = payload.get("uid", [])
chunks = payload.get("chunks", {})
if not uids or not chunks:
logger.error(f"Either uids: {uids} or chunks: {chunks} is not valid, skipping scoring")
return
uids = [int(uid) for uid in uids]
model = body.get("model")
if model:
try:
Expand All @@ -50,17 +55,15 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate
llm_model_id=body.get("model"),
seed=int(body.get("seed", 0)),
sampling_params=body.get("sampling_parameters", shared_settings.SAMPLING_PARAMS),
query=body.get("messages")[0]["content"],
query=body.get("messages"),
)
logger.info(f"Task created: {organic_task}")
task_scorer.add_to_queue(
task=organic_task,
response=DendriteResponseEvent(
uids=[uid],
stream_results=[
SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])
],
timeout=shared_settings.NEURON_TIMEOUT,
uids=uids,
stream_results=[SynapseStreamResult(accumulated_chunks=chunks.get(str(uid), None)) for uid in uids],
timeout=timeout,
),
dataset_entry=DatasetEntry(),
block=shared_settings.METAGRAPH.block,
Expand All @@ -83,11 +86,11 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate
query=search_term,
),
response=DendriteResponseEvent(
uids=[uid],
uids=[uids],
stream_results=[
SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])
],
timeout=shared_settings.NEURON_TIMEOUT,
timeout=shared_settings.NEURON_TIMEOUT, # TODO: Change this to read from the body
),
dataset_entry=DDGDatasetEntry(search_term=search_term),
block=shared_settings.METAGRAPH.block,
Expand Down
4 changes: 2 additions & 2 deletions prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class InferenceRewardConfig(BaseRewardConfig):
class InferenceTask(BaseTextTask):
name: ClassVar[str] = "inference"
# TODO: Once we want to enable the 'actual' inference task with exact models
query: str | None = None
query: str | list = []
reference: str | None = None
system_prompt: str | None = None
llm_model: ModelConfig | None = None
Expand Down Expand Up @@ -77,7 +77,7 @@ def make_query(self, dataset_entry: ChatEntry) -> str:

def make_reference(self, dataset_entry: ChatEntry) -> str:
self.reference = model_manager.generate(
messages=dataset_entry.messages,
messages=self.messages,
model=self.llm_model,
seed=self.seed,
sampling_params=self.sampling_params,
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ trafilatura = { version = ">=1.12.1", optional = true }
datasets = { version = ">=3.1.0", optional = true }
primp = "^0.10.0"
nltk = { version = ">=3.8.1", optional = true }
wandb = { version = ">=0.19.4", optional = true }

[tool.poetry.extras]
validator = [
"torch",
"tiktoken",
"transformers",
"torchvision",
"accelerate",
Expand All @@ -182,6 +184,7 @@ validator = [
"trafilatura",
"datasets",
"nltk",
"wandb",
]

[build-system]
Expand Down
14 changes: 4 additions & 10 deletions validator_api/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ async def stream_from_first_response(

# Continue collecting remaining responses in background for scoring
remaining = asyncio.gather(*pending, return_exceptions=True)
asyncio.create_task(collect_remaining_responses(remaining, collected_chunks_list, body, uids))
remaining_tasks = asyncio.create_task(collect_remaining_responses(remaining, collected_chunks_list, body, uids))
await remaining_tasks
asyncio.create_task(forward_response(uids, body, collected_chunks_list))

except asyncio.CancelledError:
logger.info("Client disconnected, streaming cancelled")
Expand Down Expand Up @@ -171,10 +173,7 @@ async def collect_remaining_responses(
content = getattr(chunk.choices[0].delta, "content", None)
if content is None:
continue
collected_chunks_list[0].append(content)
for uid, chunks in zip(uids, collected_chunks_list):
# Forward for scoring
asyncio.create_task(forward_response(uid, body, chunks))
collected_chunks_list[i + 1].append(content)

except Exception as e:
logger.exception(f"Error collecting remaining responses: {e}")
Expand Down Expand Up @@ -262,9 +261,4 @@ async def chat_completion(
if first_valid_response is None:
raise HTTPException(status_code=502, detail="No valid response received")

# Forward all collected responses for scoring in the background
for i, response in enumerate(collected_responses):
if response and isinstance(response, tuple):
asyncio.create_task(forward_response(uid=selected_uids[i], body=body, chunks=response[1]))

return first_valid_response[0] # Return only the response object, not the chunks
11 changes: 6 additions & 5 deletions validator_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

# TODO: Modify this so that all the forwarded responses are sent in a single request. This is both more efficient and
# means that, on the validator side, all responses are scored at once, speeding up the scoring process.
async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
uid = int(uid)
logger.info(f"Forwarding response from uid {uid} to scoring with body: {body} and chunks: {chunks}")
async def forward_response(uids: int, body: dict[str, any], chunks: list[str]):
uids = [int(u) for u in uids]
chunk_dict = {u: c for u, c in zip(uids, chunks)}
logger.info(f"Forwarding response from uid {uids} to scoring with body: {body} and chunks: {chunks}")
if not shared_settings.SCORE_ORGANICS:
return

Expand All @@ -17,7 +18,7 @@ async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
return

url = f"http://{shared_settings.VALIDATOR_API}/scoring"
payload = {"body": body, "chunks": chunks, "uid": uid}
payload = {"body": body, "chunks": chunk_dict, "uid": uids}
try:
timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
Expand All @@ -28,7 +29,7 @@ async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
logger.info(f"Forwarding response completed with status {response.status_code}")
else:
logger.exception(
f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
f"Forwarding response uid {uids} failed with status {response.status_code} and payload {payload}"
)
except Exception as e:
logger.error(f"Tried to forward response to {url} with payload {payload}")
Expand Down
Loading