@@ -75,6 +75,8 @@ GradientBanditSearch::GradientBanditSearch(EnvWrapper orig_env, A2CLearner a2c_a
7575 horizon = std::min (horizon, orig_env.env ->max_steps );
7676 double alpha = params[" dirichlet_alpha" ];
7777 double frac = params[" dirichlet_frac" ];
78+ bool do_init_random = params[" grad_bandit_init_random" ];
79+ reward_power = params[" grad_bandit_reward_power" ];
7880
7981 EnvWrapper env_ = *orig_env.clone ();
8082
@@ -115,15 +117,18 @@ GradientBanditSearch::GradientBanditSearch(EnvWrapper orig_env, A2CLearner a2c_a
115117 break ;
116118 }
117119
118- // In case we evaluated a very good path, add missing bandits with random initialization.
120+ // In case we evaluated a very good path, add missing bandits (optionally with random
121+ // initialization).
119122 std::uniform_real_distribution<double > distribution_ (0.0 , 1.0 );
120- for (int j = 0 ; j < (horizon - i - 1 ); ++j) {
121- std::vector<double > vec (n_actions);
122- std::generate (
123- vec.begin (),
124- vec.end (),
125- [distribution_, this ] () mutable { return distribution_ (this ->generator ); }
126- );
123+ for (int j = 0 ; j < (horizon - i); ++j) {
124+ std::vector<double > vec{0.33 , 0.33 , 0.33 };
125+ if (do_init_random) {
126+ std::generate (
127+ vec.begin (),
128+ vec.end (),
129+ [distribution_, this ] () mutable { return distribution_ (this ->generator ); }
130+ );
131+ }
127132
128133 // Create the bandit.
129134 auto bandit = SingleGradientBandit (params);
@@ -142,8 +147,8 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
142147
143148 EnvWrapper env = *orig_env.clone ();
144149
145- // TODO Hmm.. It could be that horizon is set HIGHER than the maximum horizon of the
146- // environment. So, let's only loop until the size of bandits.
150+ // It could be that horizon is set higher than the maximum horizon of the environment.
151+ // So let's only loop until the size of bandits.
147152 int j = i;
148153 for (; j < bandits.size (); ++j) {
149154 std::vector<double > action_probs;
@@ -157,10 +162,15 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
157162 bool done;
158163 std::tie (std::ignore, reward, done) = env.step (action);
159164
165+ reward = std::pow (reward, reward_power);
160166 rewards.push_back (reward);
161167
162- if (done)
168+ if (done) {
169+ // Since we break, the last ++j of the loop is not executed.
170+ // To keep things consistent later on, let's do it manually.
171+ j += 1 ;
163172 break ;
173+ }
164174 }
165175
166176 std::vector<double > cumulative_rewards;
@@ -171,7 +181,9 @@ GradientBanditSearch::policy(int i, EnvWrapper orig_env, std::vector<float> obs,
171181 }
172182 std::reverse (cumulative_rewards.begin (), cumulative_rewards.end ());
173183
174- for (int m = 0 ; m < j - i; ++m) {
184+ // This had an off by one mistake. Refer to j += 1 a few lines above.
185+ int size = std::min ((int ) bandits.size () - 1 , j - i);
186+ for (int m = 0 ; m < size; ++m) {
175187 bandits[m + i].update (actions_probs_arr[m], actions[m], cumulative_rewards[m]);
176188 }
177189 }
0 commit comments