Skip to content

Commit 3385916

Browse files
committed
Close #1: Fix off by one mistake and add two more options to gradient bandit: grad_bandit_init_random and grad_bandit_reward_power.
1 parent c3416e3 commit 3385916

3 files changed

Lines changed: 40 additions & 15 deletions

File tree

src/cfg.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ json get_default(std::string base) {
5151
{"n_desired_eval_len", 10},
5252
{"bandit_type", "mcts"}, // mcts, grad
5353
{"grad_bandit_alpha", 0.01},
54-
{"use_eps_greedy_learning", false},
54+
{"use_eps_greedy_learning", true},
55+
{"eps_greedy_epsilon_decay_factor_train", 0.995},
56+
{"eps_greedy_epsilon_decay_factor_actor", 0.995},
57+
{"grad_bandit_init_random", true},
58+
{"grad_bandit_reward_power", 1},
5559

5660
// Other
5761
{"reward_exponent", 1},
@@ -101,7 +105,11 @@ json get_default(std::string base) {
101105
{"n_desired_eval_len", 10},
102106
{"bandit_type", "mcts"}, // mcts, grad
103107
{"grad_bandit_alpha", 0.01},
104-
{"use_eps_greedy_learning", false},
108+
{"use_eps_greedy_learning", true},
109+
{"eps_greedy_epsilon_decay_factor_train", 0.995},
110+
{"eps_greedy_epsilon_decay_factor_actor", 0.995},
111+
{"grad_bandit_init_random", true},
112+
{"grad_bandit_reward_power", 1},
105113

106114
// Other
107115
{"reward_exponent", 1},
@@ -146,7 +154,11 @@ json get_default(std::string base) {
146154
{"n_desired_eval_len", 100},
147155
{"bandit_type", "mcts"}, // mcts, grad
148156
{"grad_bandit_alpha", 0.01},
149-
{"use_eps_greedy_learning", false},
157+
{"use_eps_greedy_learning", true},
158+
{"eps_greedy_epsilon_decay_factor_train", 0.995},
159+
{"eps_greedy_epsilon_decay_factor_actor", 0.995},
160+
{"grad_bandit_init_random", true},
161+
{"grad_bandit_reward_power", 1},
150162

151163
// Other
152164
{"reward_exponent", 1},

src/gradient_bandit.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ GradientBanditSearch::GradientBanditSearch(EnvWrapper orig_env, A2CLearner a2c_a
7575
horizon = std::min(horizon, orig_env.env->max_steps);
7676
double alpha = params["dirichlet_alpha"];
7777
double frac = params["dirichlet_frac"];
78+
bool do_init_random = params["grad_bandit_init_random"];
79+
reward_power = params["grad_bandit_reward_power"];
7880

7981
EnvWrapper env_ = *orig_env.clone();
8082

@@ -115,15 +117,18 @@ GradientBanditSearch::GradientBanditSearch(EnvWrapper orig_env, A2CLearner a2c_a
115117
break;
116118
}
117119

118-
// In case we evaluated a very good path, add missing bandits with random initialization.
120+
// In case we evaluated a very good path, add missing bandits (optionally with random
121+
// initialization).
119122
std::uniform_real_distribution<double> distribution_(0.0, 1.0);
120-
for (int j = 0; j < (horizon - i - 1); ++j) {
121-
std::vector<double> vec(n_actions);
122-
std::generate(
123-
vec.begin(),
124-
vec.end(),
125-
[distribution_, this] () mutable { return distribution_(this->generator); }
126-
);
123+
for (int j = 0; j < (horizon - i); ++j) {
124+
std::vector<double> vec{0.33, 0.33, 0.33};
125+
if (do_init_random) {
126+
std::generate(
127+
vec.begin(),
128+
vec.end(),
129+
[distribution_, this] () mutable { return distribution_(this->generator); }
130+
);
131+
}
127132

128133
// Create the bandit.
129134
auto bandit = SingleGradientBandit(params);
@@ -142,8 +147,8 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
142147

143148
EnvWrapper env = *orig_env.clone();
144149

145-
// TODO Hmm.. It could be that horizon is set HIGHER than the maximum horizon of the
146-
// environment. So, let's only loop until the size of bandits.
150+
// It could be that horizon is set higher than the maximum horizon of the environment.
151+
// So let's only loop until the size of bandits.
147152
int j = i;
148153
for (; j < bandits.size(); ++j) {
149154
std::vector<double> action_probs;
@@ -157,10 +162,15 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
157162
bool done;
158163
std::tie(std::ignore, reward, done) = env.step(action);
159164

165+
reward = std::pow(reward, reward_power);
160166
rewards.push_back(reward);
161167

162-
if (done)
168+
if (done) {
169+
// Since we break, the last ++j of the loop is not executed.
170+
// To keep things consistent later on, let's do it manually.
171+
j += 1;
163172
break;
173+
}
164174
}
165175

166176
std::vector<double> cumulative_rewards;
@@ -171,7 +181,9 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
171181
}
172182
std::reverse(cumulative_rewards.begin(), cumulative_rewards.end());
173183

174-
for (int m = 0; m < j - i; ++m) {
184+
// This had an off by one mistake. Refer to j += 1 a few lines above.
185+
int size = std::min((int) bandits.size() - 1, j - i);
186+
for (int m = 0; m < size; ++m) {
175187
bandits[m + i].update(actions_probs_arr[m], actions[m], cumulative_rewards[m]);
176188
}
177189
}

src/gradient_bandit.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class GradientBanditSearch : public Bandit {
3636
int n_iter;
3737
int horizon;
3838
std::mt19937 generator;
39+
int reward_power;
3940

4041
std::vector<SingleGradientBandit> bandits;
4142
EnvWrapper env;

0 commit comments

Comments
 (0)