Merged
2 changes: 1 addition & 1 deletion prompting/llms/model_manager.py
@@ -250,7 +250,7 @@ async def _vram_cleanup(self):
                 torch.cuda.reset_accumulated_memory_stats()
                 await asyncio.sleep(1.0)
             except BaseException as e:
-                logger.error(f"Error during CUDA empty cache: {e}")
+                logger.warning(f"Error during CUDA empty cache: {e}")
         else:
             logger.warning("CUDA is not available")

13 changes: 8 additions & 5 deletions prompting/rewards/exact_match.py
@@ -86,6 +86,7 @@ async def reward(
 
         all_chunks: list[list[str]] = response_event.stream_results_all_chunks
         all_chunk_dicts_raw: list[list[ChatCompletionChunk]] = response_event.stream_results_all_chunk_dicts_raw
+        uids: np.ndarray | list[float] = response_event.uids
         all_timings: list[list[float]] = response_event.stream_results_all_chunks_timings
         completions: list[str] = response_event.completions
         timeout: float = response_event.timeout
@@ -118,7 +119,7 @@ async def reward(
 
         # Iterate over each response event
 
-        for chunks, timings, chunk_dicts_raw in zip(all_chunks, all_timings, all_chunk_dicts_raw):
+        for chunks, timings, chunk_dicts_raw, uid in zip(all_chunks, all_timings, all_chunk_dicts_raw, uids):
             try:
                 # If no response is provided, apply full penalty
                 if not chunks:
@@ -133,19 +134,21 @@ async def reward(
                 for idx in verify_indices:
                     check_idx = min(idx, completion_length - 1)
                     if not chunk_dicts_raw[check_idx].choices[0].logprobs:
+                        logger.debug(f"Miner {uid} failed to provide logprobs: {chunk_dicts_raw[check_idx]}")
                         verification_scores.append(0.0)
                         continue
+
                     if chunk_dicts_raw[check_idx].choices[0].logprobs.content is None:
-                        logger.debug(f"Miner failed to provide logits: {chunk_dicts_raw[check_idx]}")
+                        logger.debug(f"Miner {uid} failed to provide logprobs content: {chunk_dicts_raw[check_idx]}")
                         verification_scores.append(0.0)
                         continue
 
                     original_logits = {
                         info.token: info.logprob
                         for info in chunk_dicts_raw[check_idx].choices[0].logprobs.content[0].top_logprobs
                     }
 
-                    verification_output, prompt = await self.model_manager.generate_logits(
+                    verification_output, prompt = await model_manager.generate_logits(
                         model=task.llm_model_id,
                         messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}],
                         sampling_params=sampling_parameters,
@@ -169,8 +172,8 @@ async def reward(
                 rewards.append(float(final_score > VERIFICATION_THRESHOLD) * timing_reward)
                 timing_outputs.append(np.array(valid_chunks).mean())
             except Exception as e:
-                logger.warning(f"Error in reward calculation: {e}")
-                rewards.append(-INCORRECT_PENALTY)
+                logger.debug(f"Miner {uid} failed to provide logits chunk, setting reward to 0: {e}")
+                rewards.append(0.0)
                 timing_outputs.append(0.0)
 
         reward_output = BatchRewardOutput(
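Note: the per-miner debug messages above fire when a streamed chunk arrives without logprobs attached. For an OpenAI-compatible endpoint, logprobs are only present if the caller asks for them; the sketch below is illustrative only, and the base URL, API key and model name are placeholders rather than values from this repository.

from openai import AsyncOpenAI

async def stream_with_logprobs(prompt: str) -> None:
    # Ask the server to attach logprobs (and top-k alternatives) to every streamed chunk.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        logprobs=True,
        top_logprobs=10,
    )
    async for chunk in stream:
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        if choice.logprobs and choice.logprobs.content:
            # Same shape the reward code reads: token -> logprob over the top alternatives.
            top = {info.token: info.logprob for info in choice.logprobs.content[0].top_logprobs}
            print(choice.delta.content, top)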
4 changes: 0 additions & 4 deletions prompting/rewards/inference_reward_model.py
@@ -16,16 +16,12 @@ async def reward(
         **kwargs,
     ) -> BatchRewardOutput:
         """Gives an exact reward of 1 if the response matches the reference, 0 otherwise"""
-        # Use self.model_manager if model_manager is None
-        model_manager = model_manager or self.model_manager
         if model_manager is None:
             raise ValueError("Model manager must be set")
 
         if model_id:
             logits_reward_model = LogitsRewardModel()
-            logits_reward_model.model_manager = model_manager
             return await logits_reward_model.reward(reference, response_event, task, model_manager=model_manager)
 
         relevance_reward_model = RelevanceRewardModel()
-        relevance_reward_model.model_manager = model_manager
         return await relevance_reward_model.reward(reference, response_event, model_manager=model_manager)
16 changes: 8 additions & 8 deletions prompting/rewards/reward.py
@@ -67,11 +67,12 @@ def rewards_normalized(self) -> np.ndarray:
 
 
 class BaseRewardModel(ABC, BaseModel):
-    model_manager: ModelManager = None
     weight: float = 1.0
 
     @abstractmethod
-    async def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs) -> BatchRewardOutput:
+    async def reward(
+        self, reference: str, response_event: DendriteResponseEvent, model_manager: ModelManager = None, **kwargs
+    ) -> BatchRewardOutput:
         raise NotImplementedError("You must implement the reward method")
 
     async def apply(
@@ -81,11 +82,14 @@ async def apply(
         challenge: str | None = None,
         reward_type: Literal["reward", "penalty"] = "reward",
         task: BaseTextTask | None = None,
+        model_manager: ModelManager | None = None,
         **kwargs,
     ) -> WeightedRewardEvent:
         t0 = time.time()
         comparator = reference if reward_type == "reward" else challenge
-        batch_rewards_output: BatchRewardOutput = await self.reward(comparator, response_event, task=task, **kwargs)
+        batch_rewards_output: BatchRewardOutput = await self.reward(
+            comparator, response_event, task=task, model_manager=model_manager, **kwargs
+        )
         batch_rewards_time = time.time() - t0
 
         return WeightedRewardEvent(
@@ -123,7 +127,6 @@ class BaseRewardConfig(ABC, BaseModel):
     and weight it with <1.
     """
 
-    model_manager: ModelManager = None
     reward_definitions: ClassVar[list[BaseRewardModel]]
     penalty_definitions: ClassVar[list[BaseRewardModel]] = []

@@ -150,10 +153,6 @@ async def apply(
     ) -> list[WeightedRewardEvent]:
         reward_events = []
         for weighted_reward in cls.reward_definitions:
-            # Set the model_manager on the weighted_reward if it's None
-            if weighted_reward.model_manager is None and model_manager is not None:
-                weighted_reward.model_manager = model_manager
-
             reward_events.append(
                 await weighted_reward.apply(
                     reference=reference,
@@ -162,6 +161,7 @@ async def apply(
                     reward_type="reward",
                     model_id=model_id,
                     task=task,
+                    model_manager=model_manager,
                 ),
             )
         return reward_events
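For downstream reward models, the practical effect of the reward.py change is that the ModelManager now arrives as a call argument instead of being read from (or assigned to) an attribute. Below is a minimal sketch of a subclass written against the new signature; the subclass itself is hypothetical, the import paths are assumed from the file layout in this PR, the DendriteResponseEvent annotation is left as a comment because its module is not shown here, and the BatchRewardOutput(rewards=..., timings=...) field names are an assumption based on how exact_match.py fills its output, not something this diff confirms.

import numpy as np

from prompting.llms.model_manager import ModelManager
from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput


class ReferenceMatchReward(BaseRewardModel):  # hypothetical example, not part of this PR
    async def reward(
        self,
        reference: str,
        response_event,  # DendriteResponseEvent; import path not shown in this diff
        model_manager: ModelManager = None,
        **kwargs,
    ) -> BatchRewardOutput:
        # The manager is an explicit per-call dependency; nothing is cached on the instance,
        # so reward models stay stateless and a single manager can be shared across them.
        rewards = np.array([float(completion == reference) for completion in response_event.completions])
        return BatchRewardOutput(rewards=rewards, timings=np.zeros_like(rewards))  # assumed field names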