In [30]:
import numpy as np

In [31]:
# Sample batch of three games, written one record per game and then
# transposed into the three parallel lists used by the cells below.
sample_games = [
    # (states, actions, reward)
    ([1, 2, 3],    [0, 2, 4],    3),  # game1
    ([4, 2, 0, 2], [3, 2, 0, 1], 4),  # game2
    ([3, 1],       [3, 3],       5),  # game3
]

# zip(*...) regroups the records column-wise; list() turns each column
# tuple back into a plain list.
states_batch, actions_batch, rewards_batch = (
    list(column) for column in zip(*sample_games)
)

In [32]:
# find best actions
def select_elites(states_batch, actions_batch, rewards_batch, percentile):
    """Collect the states and actions of the games whose reward makes the cut.

    A game is kept when its reward is greater than or equal to the
    ``percentile``-th percentile of ``rewards_batch``.

    Parameters
    ----------
    states_batch : sequence of iterables
        ``states_batch[i]`` holds the states of game ``i``.
    actions_batch : sequence of iterables
        ``actions_batch[i]`` holds the actions of game ``i``.
    rewards_batch : sequence of numbers
        One reward per game.
    percentile : number
        Percentile handed to ``np.percentile``.

    Returns
    -------
    tuple of (list, list)
        The states and the actions of the kept games, each flattened into a
        single list in the original game order.
    """
    cutoff = np.percentile(rewards_batch, percentile)

    # Positions of the games that reach the cut-off (>=, so ties are kept).
    kept = [pos for pos, game_reward in enumerate(rewards_batch) if game_reward >= cutoff]

    flat_states = [state for pos in kept for state in states_batch[pos]]
    flat_actions = [action for pos in kept for action in actions_batch[pos]]
    return flat_states, flat_actions

In [33]:
# Run select_elites on the sample batch at four percentiles and print each
# (states, actions) pair; the results are kept for the checks in the next cell.
sample_batch = (states_batch, actions_batch, rewards_batch)

test_result_0 = select_elites(*sample_batch, percentile=0)
print(test_result_0)

test_result_30 = select_elites(*sample_batch, percentile=30)
print(test_result_30)

test_result_90 = select_elites(*sample_batch, percentile=90)
print(test_result_90)

test_result_100 = select_elites(*sample_batch, percentile=100)
print(test_result_100)

([1, 2, 3, 4, 2, 0, 2, 3, 1], [0, 2, 4, 3, 2, 0, 1, 3, 3])
([4, 2, 0, 2, 3, 1], [3, 2, 0, 1, 3, 3])
([3, 1], [3, 3])
([3, 1], [3, 3])


In [34]:
# Compare every stored result with the expected (states, actions) lists.
assert (
    np.all(test_result_0[0] == [1, 2, 3, 4, 2, 0, 2, 3, 1])
    and np.all(test_result_0[1] == [0, 2, 4, 3, 2, 0, 1, 3, 3])
), "For percentile 0 you should return all states and actions in chronological order"

assert (
    np.all(test_result_30[0] == [4, 2, 0, 2, 3, 1])
    and np.all(test_result_30[1] == [3, 2, 0, 1, 3, 3])
), "For percentile 30 you should only select states/actions from two first"

assert (
    np.all(test_result_90[0] == [3, 1])
    and np.all(test_result_90[1] == [3, 3])
), "For percentile 90 you should only select states/actions from one game"

assert (
    np.all(test_result_100[0] == [3, 1])
    and np.all(test_result_100[1] == [3, 3])
), "Please make sure you use >=, not >. Also double-check how you compute percentile."

print("Ok!")

Ok!


In [35]:
# find best actions 2
def select_elites_2(states_batch, actions_batch, rewards_batch, percentile):
    """Mask-based variant: flatten states/actions of games that make the cut.

    A game is kept when its reward is greater than or equal to the
    ``percentile``-th percentile of ``rewards_batch``.

    Parameters
    ----------
    states_batch : sequence of iterables
        ``states_batch[i]`` holds the states of game ``i``; games may have
        different lengths.
    actions_batch : sequence of iterables
        ``actions_batch[i]`` holds the actions of game ``i``.
    rewards_batch : sequence of numbers
        One reward per game.
    percentile : number
        Percentile handed to ``np.percentile``.

    Returns
    -------
    tuple of (list, list)
        The states and the actions of the kept games, each flattened into a
        single list in the original game order. Elements are the caller's
        own objects, not NumPy scalars.

    Raises
    ------
    IndexError
        If ``states_batch`` or ``actions_batch`` has fewer entries than the
        position of a kept game in ``rewards_batch``.
    """
    reward_threshold = np.percentile(rewards_batch, percentile)

    # Only the rewards go through NumPy. The per-game sequences can be ragged,
    # and np.array() on ragged nested lists raises ValueError on NumPy >= 1.24,
    # so states and actions stay as plain Python sequences.
    is_elite = np.asarray(rewards_batch) >= reward_threshold

    elite_states = []
    elite_actions = []

    # flatnonzero yields the positions of kept games in ascending order.
    for game_idx in np.flatnonzero(is_elite):
        elite_states.extend(states_batch[game_idx])
        elite_actions.extend(actions_batch[game_idx])

    return elite_states, elite_actions

In [36]:
# Run select_elites_2 on the sample batch at four percentiles and print each
# (states, actions) pair; the results are kept for the checks in the next cell.
sample_batch = (states_batch, actions_batch, rewards_batch)

test_result_0 = select_elites_2(*sample_batch, percentile=0)
print(test_result_0)

test_result_30 = select_elites_2(*sample_batch, percentile=30)
print(test_result_30)

test_result_90 = select_elites_2(*sample_batch, percentile=90)
print(test_result_90)

test_result_100 = select_elites_2(*sample_batch, percentile=100)
print(test_result_100)

([1, 2, 3, 4, 2, 0, 2, 3, 1], [0, 2, 4, 3, 2, 0, 1, 3, 3])
([4, 2, 0, 2, 3, 1], [3, 2, 0, 1, 3, 3])
([3, 1], [3, 3])
([3, 1], [3, 3])


In [37]:
# Compare every stored select_elites_2 result with the expected
# (states, actions) lists.
assert (
    np.all(test_result_0[0] == [1, 2, 3, 4, 2, 0, 2, 3, 1])
    and np.all(test_result_0[1] == [0, 2, 4, 3, 2, 0, 1, 3, 3])
), "For percentile 0 you should return all states and actions in chronological order"

assert (
    np.all(test_result_30[0] == [4, 2, 0, 2, 3, 1])
    and np.all(test_result_30[1] == [3, 2, 0, 1, 3, 3])
), "For percentile 30 you should only select states/actions from two first"

assert (
    np.all(test_result_90[0] == [3, 1])
    and np.all(test_result_90[1] == [3, 3])
), "For percentile 90 you should only select states/actions from one game"

assert (
    np.all(test_result_100[0] == [3, 1])
    and np.all(test_result_100[1] == [3, 3])
), "Please make sure you use >=, not >. Also double-check how you compute percentile."

print("Ok!")

Ok!
