Merged
40 commits
9655ff4
Add Truncated Response Back in for Another Round
richwardle Apr 30, 2025
52d47ad
Split Out Function
richwardle Apr 30, 2025
5c85139
Precommits
richwardle Apr 30, 2025
e903fcf
Isort
richwardle May 1, 2025
2ff2e95
PC
richwardle May 1, 2025
08d0def
Simplify Calls
richwardle May 1, 2025
149ac00
Remove Retry in function
bkb2135 May 2, 2025
3f4165b
More Tweaks
bkb2135 May 6, 2025
1f8d951
Comment out Unecessary Logs
bkb2135 May 6, 2025
7740f4e
Default to Default
bkb2135 May 6, 2025
4108263
Linting
bkb2135 May 6, 2025
6cf8263
Merge pull request #701 from macrocosm-os/SN1-464-deep-research-trunc…
bkb2135 May 6, 2025
32e3e0b
Use stream for deep researcher
bkb2135 May 6, 2025
9f91ee9
Isort
bkb2135 May 6, 2025
5b44dcd
Add Tools to API
bkb2135 May 6, 2025
245c5ac
Linting
bkb2135 May 6, 2025
00416a6
Reduce wandb size
dbobrenko May 6, 2025
afd31d6
Merge pull request #706 from macrocosm-os/features/support-tool-calling
bkb2135 May 6, 2025
c978e86
Merge pull request #705 from macrocosm-os/fix/use-stream-in-deep
bkb2135 May 6, 2025
a293925
Merge pull request #707 from macrocosm-os/fix/reduce-wandb-size
dbobrenko May 6, 2025
a0e3e1b
Update task_registry.py
bkb2135 May 6, 2025
639082d
Precommits
richwardle May 7, 2025
46a2ed0
If client disconnects score whatever has been sent
bkb2135 May 7, 2025
e6a25c1
Add submit and poll endpoints
gevagorou-mai May 7, 2025
b56fc21
Fix black failure
gevagorou-mai May 7, 2025
70665bc
tidy up
gevagorou-mai May 7, 2025
8b98781
tidy up
gevagorou-mai May 7, 2025
4f062b9
rename endpoint to submit_chain_of_thought_job
gevagorou-mai May 7, 2025
962978c
Use sqlite for job_store.py
gevagorou-mai May 8, 2025
4dde1e3
GEN-1220 add new line for formatting
May 8, 2025
0a5f813
Lint
bkb2135 May 8, 2025
2d85ac7
Lint more
bkb2135 May 8, 2025
d61e981
GEN-1220 pre-commit run to correct formatting
May 8, 2025
d388b7c
Merge pull request #709 from macrocosm-os/SN1-478-re-weight-task-dist…
bkb2135 May 8, 2025
26dd608
Merge pull request #714 from macrocosm-os/gen-1220-use-sqlite-job-store
gevagorou-mai May 8, 2025
601c8b1
Merge pull request #712 from macrocosm-os/gen-1120-submid-poll-endpoints
dbobrenko May 8, 2025
bebbf91
Fix and optimize API, add strict checks to web, etc (#715)
dbobrenko May 8, 2025
cd33ed5
Restore task registry (#716)
dbobrenko May 8, 2025
d4dee30
Merge branch 'staging' into fix/close-func
dbobrenko May 8, 2025
3243619
Merge pull request #710 from macrocosm-os/fix/close-func
bkb2135 May 8, 2025
12 changes: 3 additions & 9 deletions prompting/api/scoring/api.py
@@ -59,29 +59,22 @@ async def score_response(
):
model = None
payload: dict[str, Any] = await request.json()
logger.debug(f"Awaited body: {payload}")
body = payload.get("body")
timeout = payload.get("timeout", shared_settings.NEURON_TIMEOUT)
uids = payload.get("uids", [])
chunks = payload.get("chunks", {})
chunk_dicts_raw = payload.get("chunk_dicts_raw", {})
timings = payload.get("timings", {})
logger.debug("About to check chunks and uids")
if not uids or not chunks:
logger.error(f"Either uids: {uids} or chunks: {chunks} is not valid, skipping scoring")
return
uids = [int(uid) for uid in uids]
model = body.get("model")
logger.debug("About to check model")
if model and model not in shared_settings.LLM_MODEL:
logger.error(f"Model {model} not available for scoring on this validator.")
return
logger.debug("Model has been checked")
llm_model = ModelZoo.get_model_by_id(model)
logger.debug("Got LLM Model from ModelZoo")
task_name = body.get("task")
logger.debug(f"Task name set: {task_name}")
logger.debug(f"Length pre-insertion: {len(task_scorer.scoring_queue)}")
if task_name == "InferenceTask":
organic_task = InferenceTask(
messages=body.get("messages"),
@@ -120,6 +113,8 @@ async def score_response(
seed=int(body.get("seed", 0)),
sampling_params=body.get("sampling_params", {}),
query=search_term,
target_results=body.get("target_results", 1),
timeout=body.get("timeout", 10),
),
response=DendriteResponseEvent(
uids=uids,
@@ -132,5 +127,4 @@
step=-1,
task_id=str(uuid.uuid4()),
)
logger.debug(f"Current Queue: {len(task_scorer.scoring_queue)}")
logger.info("Organic task appended to scoring queue")
logger.debug(f"Organic queue size: {len(task_scorer.scoring_queue)}")
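Note: a minimal sketch (not from the PR) of a request body that exercises the endpoint above. Field names follow the payload.get(...) calls in the diff; the values, and the "WebRetrievalTask" task string for the branch that consumes target_results and timeout, are assumptions.

# Hypothetical scoring payload; all values are illustrative placeholders.
payload = {
    "uids": [101, 102],                      # miner uids that streamed responses
    "chunks": {"101": ["Hello", " world"]},  # per-uid streamed chunks
    "chunk_dicts_raw": {},                   # optional raw chunk dumps with logprobs
    "timings": {"101": [0.05, 0.12]},        # per-chunk timings
    "timeout": 15,
    "body": {
        "task": "WebRetrievalTask",          # assumed task string for the web-retrieval branch
        "model": "default",
        "messages": [{"role": "user", "content": "latest research on topic X"}],
        "seed": 42,
        "sampling_params": {"max_tokens": 512},
        "target_results": 1,                 # new field, defaults to 1 when absent
        "timeout": 10,                       # new field, defaults to 10 when absent
    },
}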
6 changes: 3 additions & 3 deletions prompting/datasets/random_website.py
@@ -15,8 +15,8 @@

class DDGDatasetEntry(DatasetEntry):
search_term: str
website_url: str = None
website_content: str = None
website_url: str | None = None
website_content: str | None = None
query: str | None = None
source: str | None = None

@@ -40,7 +40,7 @@ def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[

@staticmethod
@lru_cache(maxsize=1000)
def extract_website_content(url: str, retries: int = 3) -> Optional[str]:
def extract_website_content(url: str, retries: int = 3) -> str | None:
exception: Exception | None = None
for _ in range(retries):
try:
4 changes: 3 additions & 1 deletion prompting/llms/model_manager.py
@@ -190,13 +190,15 @@ async def generate_logits(
sampling_params: dict[str, float] = None,
seed: int = None,
continue_last_message: bool = False,
top_logprobs: int = 10,
):
model_instance: ReproducibleVLLM = await self.get_model(model)
return await model_instance.generate_logits(
messages=messages,
sampling_params=sampling_params,
seed=seed,
continue_last_message=continue_last_message,
top_logprobs=top_logprobs,
)

async def cleanup(self):
@@ -262,5 +264,5 @@ async def run_step(self):
await self.llm_model_manager.load_model(selected_model)
except MemoryError as e:
self.memory_error = e
logger.debug(f"Active models: {self.llm_model_manager.active_models.keys()}")
logger.debug(f"Active models: {list(self.llm_model_manager.active_models.keys())}")
await asyncio.sleep(0.01)
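Usage note: the manager change simply threads top_logprobs through to the underlying ReproducibleVLLM instance. A hedged sketch of a call site follows; the model id and messages are placeholders, and generate_logits is assumed to return a 2-tuple, of which only the logits mapping is used in exact_match.py.

async def verify_one_token(model_manager) -> dict:
    # model_manager is an initialized ModelManager instance (assumed available).
    logits, _ = await model_manager.generate_logits(
        model="some-model-id",  # placeholder; must be a model known to the manager
        messages=[{"role": "user", "content": "2 + 2 ="}],
        sampling_params={"temperature": 0.0},
        seed=0,
        continue_last_message=False,
        top_logprobs=10,  # now forwarded to ReproducibleVLLM.generate_logits
    )
    return logits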
59 changes: 38 additions & 21 deletions prompting/llms/vllm_llm.py
@@ -34,6 +34,38 @@ def __init__(
# Store tokenizer from VLLM for consistency
self.tokenizer = self.model.get_tokenizer()

@classmethod
async def get_max_tokens(
cls,
sampling_params: dict[str, str | float | int | bool],
default_value: int = 512,
) -> int:
# Process max tokens with backward compatibility.
max_tokens = sampling_params.get("max_tokens")
if max_tokens is None:
max_tokens = sampling_params.get("max_new_tokens")
if max_tokens is None:
max_tokens = sampling_params.get("max_completion_tokens", default_value)
return max_tokens

@classmethod
async def prepare_sampling_params(
cls, sampling_params: dict[str, str | float | int | bool] | None = None
) -> SamplingParams:
sampling_params = sampling_params or {}
max_tokens = await cls.get_max_tokens(sampling_params)

params = SamplingParams(
temperature=float(sampling_params.get("temperature", 1.0)),
top_p=float(sampling_params.get("top_p", 1.0)),
max_tokens=int(max_tokens),
presence_penalty=float(sampling_params.get("presence_penalty", 0.0)),
frequency_penalty=float(sampling_params.get("frequency_penalty", 0.0)),
top_k=int(sampling_params.get("top_k", -1)),
logprobs=sampling_params.get("logprobs", None),
)
return params

async def generate(
self,
messages: list[str] | list[dict[str, str]],
@@ -71,24 +103,9 @@ async def generate(
else:
prompt = messages[0] if isinstance(messages, list) else messages

# Convert sampling parameters to VLLM format
# Convert sampling parameters to vLLM format.
params = sampling_params if sampling_params else self.sampling_params

max_tokens = params.get("max_new_tokens")
if max_tokens is None:
max_tokens = params.get("max_tokens", 512)

vllm_params = SamplingParams(
temperature=params.get("temperature", 1.0),
top_p=params.get("top_p", 1.0),
max_tokens=int(max_tokens),
presence_penalty=params.get("presence_penalty", 0.0),
frequency_penalty=params.get("frequency_penalty", 0.0),
top_k=int(params.get("top_k", -1)),
logprobs=params.get("logprobs", None),
)

# Generate using VLLM
vllm_params = await self.prepare_sampling_params(params)
outputs = self.model.generate(prompt, vllm_params)

if not outputs:
@@ -101,7 +118,7 @@
async def generate_logits(
self,
messages: list[str] | list[dict[str, str]],
top_n: int = 10,
top_logprobs: int = 10,
sampling_params: dict[str, str | float | int | bool] | None = None,
seed: int | None = None,
continue_last_message: bool = False,
@@ -110,7 +127,7 @@

Args:
messages: Input messages or text.
top_n: Number of top logits to return (default: 10).
top_logprobs: Number of top logits to return (default: 10).
sampling_params: Generation parameters.
seed: Random seed for reproducibility.
continue_last_message: Whether to continue the last message in chat format.
@@ -122,8 +139,8 @@
params = sampling_params if sampling_params else self.sampling_params
params = params.copy()
params["max_tokens"] = 1
params["logprobs"] = top_n
vllm_params = SamplingParams(**params)
params["logprobs"] = top_logprobs
vllm_params = await self.prepare_sampling_params(params)

prompt = self.tokenizer.apply_chat_template(
conversation=messages,
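To illustrate the new helpers, here is a short sketch of the max-token fallback chain and of building vLLM SamplingParams from a plain dict. It assumes the two classmethods live on the ReproducibleVLLM class referenced by the model manager and that the import path mirrors the file location; values are arbitrary.

import asyncio

from prompting.llms.vllm_llm import ReproducibleVLLM  # assumed import path

async def demo() -> None:
    # Fallback chain: max_tokens -> max_new_tokens -> max_completion_tokens -> default_value.
    assert await ReproducibleVLLM.get_max_tokens({"max_tokens": 256}) == 256
    assert await ReproducibleVLLM.get_max_tokens({"max_new_tokens": 128}) == 128
    assert await ReproducibleVLLM.get_max_tokens({"max_completion_tokens": 64}) == 64
    assert await ReproducibleVLLM.get_max_tokens({}) == 512  # default_value

    # generate() and generate_logits() now share this single conversion path.
    params = await ReproducibleVLLM.prepare_sampling_params(
        {"temperature": 0.7, "top_k": 50, "max_new_tokens": 128, "logprobs": 10}
    )
    print(params)  # SamplingParams(temperature=0.7, top_p=1.0, max_tokens=128, ...)

asyncio.run(demo())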
82 changes: 54 additions & 28 deletions prompting/rewards/exact_match.py
@@ -1,4 +1,5 @@
import random
from typing import Any

import numpy as np
import torch
@@ -14,15 +15,17 @@

shared_settings = settings.shared_settings

TOP_LOGPROBS = 10
MIN_VERIFY_TOKENS = 10
MAX_VERIFY_TOKENS = 30
NO_EOS_PENALTY = -0.1
INCORRECT_PENALTY = -1.5
MAX_VERIFY_TOKENS = 51
PARTIAL_PENALTY = -1.0
INCORRECT_PENALTY = -2.0
NOT_ENOUGH_TOKENS_PENALTY_SCALE = 0.1
MIN_SMOOTH_PENALTY_SCALE = 0.6
MIN_SMOOTH_PENALTY_SCALE = 0.3
MIN_TIME_PENALTY_SCALE = 0.3
VERIFICATION_THRESH_CONTAINS = 0.92
VERIFICATION_THRESH_SIM = 0.83
VERIFICATION_SIM_EXP_SCALE = 2.0


class LogitsRewardModel(BaseRewardModel):
@@ -39,7 +42,7 @@ async def reward( # noqa: C901
raise ValueError("Model manager must be set")

all_chunks: list[list[str]] = response_event.stream_results_all_chunks
all_chunk_dicts_raw: list[list[ChatCompletionChunk]] = response_event.stream_results_all_chunk_dicts_raw
all_chunk_dicts_raw: list[list[ChatCompletionChunk | dict]] = response_event.stream_results_all_chunk_dicts_raw
uids: np.ndarray | list[float] = response_event.uids
all_timings: list[list[float]] = response_event.stream_results_all_chunks_timings
completions: list[str] = response_event.completions
@@ -59,9 +62,11 @@
raise ValueError("Timeout must be greater than 0.")

# If max_tokens are not provided, always check for eos.
max_tokens = sampling_parameters.get("max_tokens", 8192)
model = await model_manager.get_model(task.llm_model_id)
max_tokens = await model.get_max_tokens(sampling_parameters, default_value=2048)
eos_token = model.tokenizer.eos_token
bos_token = model.tokenizer.bos_token
special_tokens = set([bos_token, eos_token])
timing_verified: list[list[float]] = []
rewards: list[float] = []
logger.info(f"Verifying logits with model {task.llm_model_id}")
@@ -70,8 +75,8 @@
penalty = INCORRECT_PENALTY
reward_scale = 1.0
try:
# If no response is provided, apply full penalty.
if not chunks:
if not chunks or not chunk_dicts_raw:
# If no response is provided, apply full penalty.
rewards.append(INCORRECT_PENALTY)
timing_verified.append([-1.0])
continue
@@ -83,6 +88,12 @@
timing_verified.append([-1.0])
continue

if completion_length > max_tokens:
# Completion exceeds max_tokens, i.e. the miner ignored the sampling params.
rewards.append(PARTIAL_PENALTY)
timing_verified.append([-1.0])
continue

if completion_length < MIN_VERIFY_TOKENS:
# Not enough tokens to verify; still proceed to verification with a scaled reward if the checks pass.
reward_scale = NOT_ENOUGH_TOKENS_PENALTY_SCALE
@@ -100,32 +111,43 @@
verification_logits, _ = await model_manager.generate_logits(
model=task.llm_model_id,
messages=messages,
top_logprobs=TOP_LOGPROBS,
sampling_params=sampling_parameters,
continue_last_message=len(to_complete) > 0,
)
if check_idx < eos_idx:
if not chunk_dicts_raw[check_idx].choices[0].logprobs:
raise ValueError("No logprobs provided")
if chunks[check_idx] in special_tokens:
raise ValueError("Special tokens mid-completion")

if chunk_dicts_raw[check_idx].choices[0].logprobs.content is None:
chunk_dict: dict[str, Any] | ChatCompletionChunk = chunk_dicts_raw[check_idx]
if isinstance(chunk_dict, ChatCompletionChunk):
# Convert chunks to unified dict format.
chunk_dict = chunk_dict.model_dump(mode="python")

if chunk_dict.get("choices", [{}])[0].get("logprobs", {}).get("content") is None:
raise ValueError("Logprobs content is empty")

original_logits = {
info.token: info.logprob
for info in chunk_dicts_raw[check_idx].choices[0].logprobs.content[0].top_logprobs
info["token"]: info["logprob"]
for info in chunk_dict["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
}

if len(verification_logits) == TOP_LOGPROBS + 1:
# Sampled logprobs can contain one extra entry; remove the lowest value.
del verification_logits[min(verification_logits, key=verification_logits.get)]

logit_sim = self.verify_logit_similarity(original_logits, verification_logits)
scores_sim.append(logit_sim)

logit_contains = self.verify_logit_contains(
chunks[check_idx], original_logits, verification_logits
)

scores_contains.append(logit_contains)

elif check_idx == eos_idx and completion_length < max_tokens:
if eos_token and eos_token not in verification_logits:
penalty = NO_EOS_PENALTY
penalty = PARTIAL_PENALTY
raise ValueError("Partial completion")

score_sim_mean = float(np.mean(scores_sim))
@@ -138,9 +160,11 @@
raise ValueError(f"Logits contains mean score is below threshold: {score_contains_mean:.2f}")

timing_verified.append(timings)
smooth_reward = self.smooth_timings_reward(timings)
timingsdt = np.abs(np.diff(timings))
smooth_reward = self.smooth_timings_reward(timingsdt)
# Min-max scale the logits reward, e.g. from [0.95, 1.0] to [0.0, 1.0].
score_sim_mean = self.rescale(score_sim_mean, min_value=VERIFICATION_THRESH_SIM)
score_sim_mean = score_sim_mean**VERIFICATION_SIM_EXP_SCALE
score_contains_mean = self.rescale(score_contains_mean, min_value=VERIFICATION_THRESH_CONTAINS)
rewards.append(score_sim_mean * score_contains_mean * smooth_reward * reward_scale)
except BaseException as e:
@@ -205,26 +229,28 @@ def fastest_timing(values: list[list[float]]) -> float:
"""Return the smallest mean over the inner lists, considering only lists that contain no negative numbers."""
best = float("+inf")
for subset in values:
if subset and min(subset) >= 0.0:
if len(subset) and min(subset) >= 0.0:
subset_sum = sum(subset) / len(subset)
if subset_sum < best:
best = subset_sum
return best if best < float("+inf") else 1e-6

@staticmethod
def smooth_timings_reward(timings_uid: list[float], min_reward: float = MIN_SMOOTH_PENALTY_SCALE) -> float:
"""Return smooth stream ratio based on the deviation between chunk timings.

Args:
timings_uid: List of timings for a specific miner.

Returns:
float: Smoothed penalty value.
"""
if not timings_uid:
def smooth_timings_reward(
timings_uid: list[float] | np.ndarray,
tolerance_sec: float = 1,
min_reward: float = MIN_SMOOTH_PENALTY_SCALE,
penalty_strength: float = 5,
) -> float:
"""If delay between chunks is longer than tolerance, apply non-smooth stream penalty."""
if not len(timings_uid):
return 0.0

smooth_penalty = np.std(timings_uid)
max_timing = max(timings_uid)
if max_timing < tolerance_sec:
return 1.0

smooth_penalty = np.std(timings_uid) * penalty_strength
return max(min_reward, 1.0 - smooth_penalty)

@staticmethod
@@ -252,7 +278,7 @@ def verify_logit_similarity(
if not gt_logits:
return 0.0

if len(candidate_logits) != len(gt_logits):
if len(candidate_logits) != TOP_LOGPROBS:
return 0.0

# Tokens common to both distributions.
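As a sanity check on the revised smoothness penalty: the reward loop now passes the absolute inter-chunk deltas (np.abs(np.diff(timings))) instead of raw timings. Below is a standalone sketch of the same arithmetic using the constants from this diff (tolerance_sec=1, penalty_strength=5, floor 0.3); the timings are made up.

import numpy as np

def smooth_timings_reward_sketch(deltas, tolerance_sec=1.0, min_reward=0.3, penalty_strength=5.0):
    # Mirrors the new method above: streams whose deltas all stay under the tolerance
    # keep the full reward; bursty streams are penalised by the std of the deltas.
    if not len(deltas):
        return 0.0
    if max(deltas) < tolerance_sec:
        return 1.0
    return max(min_reward, 1.0 - np.std(deltas) * penalty_strength)

chunk_timings = [0.10, 0.15, 0.12, 1.80, 0.11]  # hypothetical per-chunk timings
deltas = np.abs(np.diff(chunk_timings))          # as computed in the reward loop
print(smooth_timings_reward_sketch(deltas))      # one large gap -> clamped to the 0.3 floor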
13 changes: 8 additions & 5 deletions prompting/rewards/inference_reward_model.py
@@ -1,3 +1,5 @@
from loguru import logger

from prompting.rewards.exact_match import LogitsRewardModel
from prompting.rewards.relevance import RelevanceRewardModel
from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput
@@ -19,9 +21,10 @@ async def reward(
if model_manager is None:
raise ValueError("Model manager must be set")

if not model_id or task.organic:
relevance_reward_model = RelevanceRewardModel()
return await relevance_reward_model.reward(reference, response_event, model_manager=model_manager)
if model_id or task.organic:
logger.info("Using logits reward model")
logits_reward_model = LogitsRewardModel()
return await logits_reward_model.reward(reference, response_event, task, model_manager=model_manager)

logits_reward_model = LogitsRewardModel()
return await logits_reward_model.reward(reference, response_event, task, model_manager=model_manager)
relevance_reward_model = RelevanceRewardModel()
return await relevance_reward_model.reward(reference, response_event, model_manager=model_manager)
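Net effect of the swapped branch, restated as a tiny helper for clarity (not part of the PR; class names mirror the imports above): logit verification now runs whenever a model id is available or the task is organic, and relevance scoring is the fallback.

from prompting.rewards.exact_match import LogitsRewardModel
from prompting.rewards.relevance import RelevanceRewardModel

def select_reward_model(model_id: str | None, organic: bool):
    # Sketch of the routing implemented by the branch above.
    if model_id or organic:
        return LogitsRewardModel()
    return RelevanceRewardModel()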