Skip to content

Commit

Permalink
Merge pull request #851 from dstl/printing_reward
Browse files Browse the repository at this point in the history
Bool to return reward for BruteForce and Optimise sensor managers
  • Loading branch information
sdhiscocks committed Oct 3, 2023
2 parents 17e96b6 + 6a11d38 commit 6000941
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
20 changes: 13 additions & 7 deletions stonesoup/sensormanager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class BruteForceSensorManager(SensorManager):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs):
def choose_actions(self, tracks, timestamp, nchoose=1, return_reward=False, **kwargs):
"""Returns a chosen [list of] action(s) from the action set for each sensor.
Chosen action(s) is selected by finding the configuration of sensors: actions which returns
the maximum reward, as calculated by a reward function.
Expand All @@ -111,11 +111,14 @@ def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs):
Time at which the actions are carried out until
nchoose : int
Number of actions from the set to choose (default is 1)
return_reward : bool
Whether to return the reward for chosen actions (default is False).
When True, returns a tuple: (list of dictionaries of chosen actions, array of rewards)
Returns
-------
: dict
The pairs of :class:`~.Sensor`: [:class:`~.Action`] selected
: list(dict) or (list(dict), :class:`numpy.ndarray`)
The pairs of :class:`~.Sensor`: [:class:`~.Action`] selected and, when `return_reward`
is True, the array containing the corresponding rewards.
"""

all_action_choices = dict()
Expand All @@ -141,6 +144,9 @@ def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs):
if reward > min(best_rewards):
selected_configs[np.argmin(best_rewards)] = config
best_rewards[np.argmin(best_rewards)] = reward

# Return mapping of sensors and chosen actions for sensors
return selected_configs
if return_reward:
# Return mapping of sensors and chosen actions for sensors
# Also returns rewards
return selected_configs, best_rewards
else:
return selected_configs
9 changes: 6 additions & 3 deletions stonesoup/sensormanager/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class _OptimizeSensorManager(BruteForceSensorManager):
def _optimiser(self, optimise_func, all_action_generators):
raise NotImplementedError

def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs):
def choose_actions(self, tracks, timestamp, nchoose=1, return_reward=False, **kwargs):
if nchoose > 1:
raise ValueError("Can only return best result (nchoose=1)")
all_action_generators = dict()
Expand All @@ -38,8 +38,11 @@ def optimise_func(x):

best_x = self._optimiser(optimise_func, all_action_generators)
config = config_from_x(best_x)

return [config]
if return_reward:
reward = self.reward_function(config, tracks, timestamp)
return [config], reward
else:
return [config]


class OptimizeBruteSensorManager(_OptimizeSensorManager):
Expand Down

0 comments on commit 6000941

Please sign in to comment.