Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions prompting/weight_setting/weight_setter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from collections import deque
import datetime
import json
from collections import deque
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -42,9 +42,7 @@ async def set_weights(
# If weights will not be set on chain, we should not synchronize.
augmented_weights = weights
else:
augmented_weights = await weight_syncer.get_augmented_weights(
weights=weights, uid=shared_settings.UID
)
augmented_weights = await weight_syncer.get_augmented_weights(weights=weights, uid=shared_settings.UID)
except BaseException as ex:
logger.exception(f"Issue with setting weights: {ex}")
augmented_weights = weights
Expand All @@ -60,8 +58,7 @@ async def set_weights(

# Convert to uint16 weights and uids.
uint_uids, uint_weights = bt.utils.weight_utils.convert_weights_and_uids_for_emit(
uids=processed_weight_uids,
weights=processed_weights
uids=processed_weight_uids, weights=processed_weights
)
except Exception as ex:
logger.exception(f"Skipping weight setting: {ex}")
Expand Down Expand Up @@ -167,9 +164,7 @@ async def _load_rewards(self):
if payload is None:
raise ValueError(f"Malformed weight history file: {data}")

self.reward_history.append(
{int(uid): {"reward": float(reward)} for uid, reward in payload.items()}
)
self.reward_history.append({int(uid): {"reward": float(reward)} for uid, reward in payload.items()})
except BaseException as exc:
self.reward_history: deque[dict[int, dict[str, Any]]] | None = deque(maxlen=self.reward_history_len)
logger.error(f"Couldn't load rewards from file, resetting weight history: {exc}")
Expand Down Expand Up @@ -217,8 +212,7 @@ async def merge_task_rewards(cls, reward_events: list[list[WeightedRewardEvent]]
processed_rewards = task_rewards / max(1, (np.sum(task_rewards[task_rewards > 0]) + 1e-10))
else:
processed_rewards = cls.apply_steepness(
raw_rewards=task_rewards,
steepness=shared_settings.REWARD_STEEPNESS
raw_rewards=task_rewards, steepness=shared_settings.REWARD_STEEPNESS
)
processed_rewards *= task_config.probability

Expand All @@ -238,11 +232,11 @@ def apply_steepness(cls, raw_rewards: npt.NDArray[np.float32], steepness: float
p > 0.5 makes the function more exponential (winner takes all).
"""
# 6.64385619 = ln(100)/ln(2) -> this way if p = 0.5, the exponent is exactly 1.
exponent = (steepness ** 6.64385619) * 100
exponent = (steepness**6.64385619) * 100
raw_rewards = np.array(raw_rewards) / max(1, (np.sum(raw_rewards[raw_rewards > 0]) + 1e-10))
positive_rewards = np.clip(raw_rewards, 1e-10, np.inf)
normalised_rewards = positive_rewards / np.max(positive_rewards)
post_func_rewards = normalised_rewards ** exponent
post_func_rewards = normalised_rewards**exponent
all_rewards = post_func_rewards / (np.sum(post_func_rewards) + 1e-10)
all_rewards[raw_rewards <= 0] = raw_rewards[raw_rewards <= 0]
return all_rewards
Expand All @@ -251,13 +245,13 @@ async def run_step(self):
await asyncio.sleep(0.01)
try:
if self.reward_events is None:
logger.error(f"No rewards events were found, skipping weight setting")
logger.error("No rewards events were found, skipping weight setting")
return

final_rewards = await self.merge_task_rewards(self.reward_events)

if final_rewards is None:
logger.error(f"No rewards were found, skipping weight setting")
logger.error("No rewards were found, skipping weight setting")
return

await self._save_rewards(final_rewards)
Expand Down