Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7a06030
Adding draft for time penalty
Hollyqui Jan 20, 2025
5f314d1
add debug statement
richwardle Jan 21, 2025
6f0010b
WIP: not matching correct generations with reference answers
bkb2135 Jan 30, 2025
ddbbfe3
fix rewarding errors
bkb2135 Jan 30, 2025
fa8c3a9
Fixing formatting
Hollyqui Jan 30, 2025
62c87a6
Update api_keys.json
bkb2135 Jan 30, 2025
6d03791
Update validator.py
bkb2135 Jan 30, 2025
329d165
Update task_registry.py
bkb2135 Jan 30, 2025
d7d2045
Update settings.py
bkb2135 Jan 30, 2025
45b6a8e
fix 0 temperature issue
richwardle Jan 30, 2025
0fc8b30
fix random temp issues
richwardle Jan 30, 2025
cc0eac6
fix the reference mix-match
richwardle Jan 31, 2025
99aaf23
precommit fix
richwardle Jan 31, 2025
c7600e2
appeasing the precommit
richwardle Jan 31, 2025
110cefa
WIP: just needs np formatting
richwardle Feb 1, 2025
8c36655
Cast Final Score as a Float
bkb2135 Feb 3, 2025
864ce5f
Linting
richwardle Feb 3, 2025
54c6c0c
lesser penalty if the miners response is correct but truncated
richwardle Feb 4, 2025
1cbb336
add reduced penelty for miners whose response is truncated
richwardle Feb 4, 2025
6e693b9
Extracted Normalize Function and Cleaned Up ExactMatchRewardModel
richwardle Feb 4, 2025
efc7276
Fixing bug in reward function
richwardle Feb 4, 2025
c45f877
Combine ExactMatch Updates With Reference Generation Issues
richwardle Feb 4, 2025
c8247f6
Normalising chunk timings based on last miner's timing
richwardle Feb 4, 2025
358c93a
Remove Testing Logs
richwardle Feb 4, 2025
e233cab
Merge branch 'staging' into hackathon/add-time-penalty
richwardle Feb 4, 2025
3a22d78
Precommit Fixes
richwardle Feb 4, 2025
5afc328
Linting
richwardle Feb 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api_keys.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{}
{}
93 changes: 83 additions & 10 deletions prompting/rewards/exact_match.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,97 @@
import numpy as np
from loguru import logger

from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput
from shared.dendrite import DendriteResponseEvent
from shared.settings import shared_settings

PENALTY_FACTOR = 0
INCORRECT_PENALTY = 3
INCOMPLETE_PENALTY = 1


def normalize_timing(timing: float, timings: float) -> float:
"""
Normalize the timing so that a lower timing (i.e. faster response) is closer to 1.
Ensures the normalized value is between 0 and 1.
"""

flat_values = [
x
for sublist in timings
if sublist is not None
for x in (sublist if isinstance(sublist, list) else [sublist])
if x is not None
]
last_chunk = max(flat_values) if flat_values else shared_settings.INFERENCE_TIMEOUT
return min(1, max(0, (last_chunk - timing) / last_chunk))


class ExactMatchRewardModel(BaseRewardModel):
def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs) -> BatchRewardOutput:
"""Gives an exact reward of 1 if the response matches the reference, 0 otherwise"""
rewards = []
"""
Calculates rewards based on an exact match of the response with the reference string.

If the response exactly matches the reference, rewards are computed from the normalized timings.
If the response is only a prefix of the reference, a less severe penalty is applied.
Otherwise, a full penalty is given.

Rewards are in the range [-3, 1].

Parameters:
reference (str): The expected response string.
response_event (DendriteResponseEvent): Contains completions, chunked results, timings, etc.

Returns:
BatchRewardOutput: Contains the computed rewards and average timings.
"""
all_chunks: list[list[str]] = response_event.stream_results_all_chunks
all_timings: list[list[float]] = response_event.stream_results_all_chunks_timings
completions: list[str] = response_event.completions
timings = [0] * len(completions)
timeout: float = response_event.timeout

if timeout <= 0:
logger.error("Timeout must be greater than 0. Received timeout: {}", timeout)
raise ValueError("Timeout must be greater than 0.")

for completion in completions:
rewards.append(1 if reference == completion else -PENALTY_FACTOR)
timing_outputs, rewards = [], []

output = BatchRewardOutput(
# Iterate over each response event.
for chunks, timings, completion in zip(all_chunks, all_timings, completions):
# If no response is provided, apply full penalty.
if chunks == []:
rewards.append(-INCORRECT_PENALTY)
timing_outputs.append(0.0)
continue

# If the completion is a prefix of the reference, give a less severe penalty
if len(completion) < len(reference) and reference.startswith(completion):
rewards.append(-INCOMPLETE_PENALTY)
timing_outputs.append(0.0)
continue

# If the completion does not exactly match the reference, apply full penalty.
if reference != completion:
rewards.append(-INCORRECT_PENALTY)
timing_outputs.append(0.0)
continue

# Compute normalized timings for non-empty chunks.
valid_chunks = []
for chunk, timing in zip(chunks, timings):
if chunk:
valid_chunks.append(normalize_timing(timing, all_timings))

# Compute average timings for normalized chunk timings.
if valid_chunks:
# If there are valid chunks, compute the average timing.
final_score = np.mean(valid_chunks)
else:
final_score = -INCORRECT_PENALTY

rewards.append(float(final_score))
timing_outputs.append(np.array(valid_chunks).mean())

return BatchRewardOutput(
rewards=np.array(rewards),
timings=np.array(timings),
timings=np.array(timing_outputs),
)

return output
2 changes: 1 addition & 1 deletion prompting/rewards/penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def name(self) -> str:
return "penalty"

def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs) -> BatchRewardOutput:
"""Compute difference scores given a completion and reference pair."""
"""Penalises miner if they do not respond."""
rewards = []
timings = []
completions: list[str] = response_event.completions
Expand Down
10 changes: 7 additions & 3 deletions prompting/rewards/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import ClassVar, Literal

import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator

from prompting.tasks.base_task import BaseTextTask
from shared.dendrite import DendriteResponseEvent
Expand Down Expand Up @@ -50,12 +50,16 @@ class BatchRewardOutput(BaseModel):
extra_info: dict = {}
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def validate_rewards_and_timings(cls, v):
if v.rewards.shape != v.timings.shape:
raise ValueError(f"rewards.shape {v.rewards.shape} != timings.shape {v.timings.shape}")
return v

@property
def rewards_normalized(self) -> np.ndarray:
if self.rewards.size == 0:
return np.array([])
if self.rewards.shape != self.timings.shape:
raise ValueError(f"rewards.shape {self.rewards.shape} != timings.shape {self.timings.shape}")
if self.rewards.min() == self.rewards.max():
return np.array([1 / len(self.rewards)] * len(self.rewards))
return (self.rewards - self.rewards.min()) / (self.rewards.max() - self.rewards.min())
Expand Down
6 changes: 3 additions & 3 deletions prompting/rewards/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ async def run_step(self) -> RewardLoggingEvent:
task=scoring_config.task,
)
self.reward_events.append(reward_events)
logger.debug(
f"REFERENCE: {scoring_config.task.reference}\n\n||||RESPONSES: {scoring_config.response.completions}"
)
# logger.debug(
# f"REFERENCE: {scoring_config.task.reference}\n\n||||RESPONSES: {scoring_config.response.completions}"
# )
logger.debug(
f"SCORING: Scored {scoring_config.task.__class__.__name__} {scoring_config.task.task_id} with model {scoring_config.task.llm_model_id} with reward"
)
Expand Down
1 change: 0 additions & 1 deletion prompting/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class BaseTextTask(BaseTask):
reference_system_prompt: ClassVar[str | None] = None
augmentation_system_prompt: ClassVar[str | None] = None
dataset_entry: DatasetEntry | None = None
task_id: str = str(uuid4())
sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS

@model_validator(mode="after")
Expand Down
19 changes: 11 additions & 8 deletions prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import ClassVar

import numpy as np
from loguru import logger
from pydantic import Field, model_validator

from prompting.datasets.sn13 import ChatEntry
Expand Down Expand Up @@ -48,18 +49,14 @@ class InferenceTask(BaseTextTask):
llm_model_id: ModelConfig | None = random.choice(ModelZoo.models_configs).llm_model_id
seed: int = Field(default_factory=lambda: random.randint(0, 1_000_000), allow_mutation=False)
sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS.copy()
messages: list[dict] | None = None

@model_validator(mode="after")
def random_llm_model_id(self):
if self.query: # If we are already defining query, as in the case of organics, we also specify model.
return self
# Choose system prompt and randomize inference settings
self.system_prompt = random.choice(SYSTEM_PROMPTS)
self.messages = []
if self.system_prompt:
self.messages.append({"role": "system", "content": self.system_prompt})
self.sampling_params["temperature"] = random.randint(0, 10) / 10
self.sampling_params["max_new_tokens"] = random.choice([256, 512, 1024, 2048])
# self.sampling_params["temperature"] = random.randint(1, 10) / 10
# self.sampling_params["max_new_tokens"] = random.choice([256, 512, 1024, 2048])

if np.random.rand() < 0.2:
self.llm_model_id = None
Expand All @@ -70,12 +67,18 @@ def random_llm_model_id(self):
def make_query(self, dataset_entry: ChatEntry) -> str:
if self.query:
return self.query
self.messages.extend(dataset_entry.messages)
system_prompt = random.choice(SYSTEM_PROMPTS)
system_prompt = [{"role": "system", "content": system_prompt}] if system_prompt else []
self.messages = system_prompt + dataset_entry.messages
self.query = self.messages

return self.query

def make_reference(self, dataset_entry: ChatEntry) -> str:
logger.info(f"GENERATING REFERENCE FOR TASK {self.task_id}")
logger.info(f"MODEL: {self.llm_model}")
logger.info(f"SAMPLING PARAMS: {self.sampling_params}")
logger.info(f"MESSAGES: {dataset_entry.messages}")
self.reference = model_manager.generate(
messages=self.messages,
model=self.llm_model,
Expand Down
2 changes: 2 additions & 0 deletions prompting/tasks/task_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ async def collect_responses(task: BaseTextTask) -> DendriteResponseEvent | None:
],
}

logger.info(f"🔍 SENDING TASK {task.task_id} WITH BODY: {body}")
stream_results = await query_miners(uids, body)
logger.debug(f"🔍 Collected responses from {len(stream_results)} miners")

Expand All @@ -79,6 +80,7 @@ async def collect_responses(task: BaseTextTask) -> DendriteResponseEvent | None:
axons=[
shared_settings.METAGRAPH.axons[x].ip + ":" + str(shared_settings.METAGRAPH.axons[x].port) for x in uids
],
# TODO: I think we calculate the timeout dynamically, so this is likely wrong
timeout=(
shared_settings.INFERENCE_TIMEOUT if isinstance(task, InferenceTask) else shared_settings.NEURON_TIMEOUT
),
Expand Down
4 changes: 2 additions & 2 deletions shared/epistula.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ async def make_openai_query(
),
)
extra_body = {k: v for k, v in body.items() if k not in ["messages", "model"]}
start_time = time.perf_counter()
chat = await miner.chat.completions.create(
model=None,
messages=body["messages"],
Expand All @@ -229,7 +230,6 @@ async def make_openai_query(
choices = []
chunks = []
chunk_timings = []
start_time = time.time()
async for chunk in chat:
if not chunk.choices:
continue
Expand All @@ -240,7 +240,7 @@ async def make_openai_query(
choices[i] += choice.delta.content
if chunk.choices[0].delta.content:
chunks.append(chunk.choices[0].delta.content)
chunk_timings.append(time.time() - start_time)
chunk_timings.append(time.perf_counter() - start_time)
choices = [
Choice(index=i, message=ChatCompletionMessage(content=choice, role="assistant"), finish_reason="stop")
for i, choice in enumerate(choices)
Expand Down