diff --git a/prompting/weight_setting/weight_setter.py b/prompting/weight_setting/weight_setter.py index 2271c26ea..56df4c1e3 100644 --- a/prompting/weight_setting/weight_setter.py +++ b/prompting/weight_setting/weight_setter.py @@ -1,7 +1,7 @@ import asyncio -from collections import deque import datetime import json +from collections import deque from pathlib import Path from typing import Any @@ -42,9 +42,7 @@ async def set_weights( # If weights will not be set on chain, we should not synchronize. augmented_weights = weights else: - augmented_weights = await weight_syncer.get_augmented_weights( - weights=weights, uid=shared_settings.UID - ) + augmented_weights = await weight_syncer.get_augmented_weights(weights=weights, uid=shared_settings.UID) except BaseException as ex: logger.exception(f"Issue with setting weights: {ex}") augmented_weights = weights @@ -60,8 +58,7 @@ async def set_weights( # Convert to uint16 weights and uids. uint_uids, uint_weights = bt.utils.weight_utils.convert_weights_and_uids_for_emit( - uids=processed_weight_uids, - weights=processed_weights + uids=processed_weight_uids, weights=processed_weights ) except Exception as ex: logger.exception(f"Skipping weight setting: {ex}") @@ -167,9 +164,7 @@ async def _load_rewards(self): if payload is None: raise ValueError(f"Malformed weight history file: {data}") - self.reward_history.append( - {int(uid): {"reward": float(reward)} for uid, reward in payload.items()} - ) + self.reward_history.append({int(uid): {"reward": float(reward)} for uid, reward in payload.items()}) except BaseException as exc: self.reward_history: deque[dict[int, dict[str, Any]]] | None = deque(maxlen=self.reward_history_len) logger.error(f"Couldn't load rewards from file, resetting weight history: {exc}") @@ -217,8 +212,7 @@ async def merge_task_rewards(cls, reward_events: list[list[WeightedRewardEvent]] processed_rewards = task_rewards / max(1, (np.sum(task_rewards[task_rewards > 0]) + 1e-10)) else: processed_rewards = cls.apply_steepness( - raw_rewards=task_rewards, - steepness=shared_settings.REWARD_STEEPNESS + raw_rewards=task_rewards, steepness=shared_settings.REWARD_STEEPNESS ) processed_rewards *= task_config.probability @@ -238,11 +232,11 @@ def apply_steepness(cls, raw_rewards: npt.NDArray[np.float32], steepness: float p > 0.5 makes the function more exponential (winner takes all). """ # 6.64385619 = ln(100)/ln(2) -> this way if p = 0.5, the exponent is exactly 1. - exponent = (steepness ** 6.64385619) * 100 + exponent = (steepness**6.64385619) * 100 raw_rewards = np.array(raw_rewards) / max(1, (np.sum(raw_rewards[raw_rewards > 0]) + 1e-10)) positive_rewards = np.clip(raw_rewards, 1e-10, np.inf) normalised_rewards = positive_rewards / np.max(positive_rewards) - post_func_rewards = normalised_rewards ** exponent + post_func_rewards = normalised_rewards**exponent all_rewards = post_func_rewards / (np.sum(post_func_rewards) + 1e-10) all_rewards[raw_rewards <= 0] = raw_rewards[raw_rewards <= 0] return all_rewards @@ -251,13 +245,13 @@ async def run_step(self): await asyncio.sleep(0.01) try: if self.reward_events is None: - logger.error(f"No rewards events were found, skipping weight setting") + logger.error("No rewards events were found, skipping weight setting") return final_rewards = await self.merge_task_rewards(self.reward_events) if final_rewards is None: - logger.error(f"No rewards were found, skipping weight setting") + logger.error("No rewards were found, skipping weight setting") return await self._save_rewards(final_rewards)