# Generalized Advantage Estimation (GAE): Robust advantage estimation #
Resources:
- https://lilianweng.github.io/posts/2018-02-19-rl-overview/#combining-td-and-mc-learning
- DeepMimic: Example-Guided Deep Reinforcement Learning of Physics-Based Character Skills (https://arxiv.org/abs/1804.02717)
   - only look at the Supplementary Material A
- https://github.com/mimoralea/gdrl/blob/master/notebooks/chapter_11/chapter-11.ipynb
- https://xlnwel.github.io/blog/reinforcement%20learning/GAE/

In [1]:
import torch
import numpy as np

In [2]:
# GAE's optimize model
def gae_optimize_model(logpas, rewards, values, gamma, tau):
    """Compute GAE advantages plus the policy and value losses for one trajectory.

    Args:
        logpas: T-1 log-probability terms, one per advantage (array-like or
            tensor; any shape holding T-1 elements, e.g. (T-1,) or (T-1, 1)).
        rewards: length-T array of rewards. Only ``rewards[:-1]`` enters the
            TD errors; the last entry contributes to the discounted returns only.
        values: T value estimates (array-like or tensor; any shape holding
            T elements, e.g. (T,) or (T, 1)).
        gamma: discount factor.
        tau: GAE mixing factor; the TD error k steps ahead is weighted by
            ``(gamma * tau) ** k``.

    Returns:
        Tuple ``(gaes, policy_loss, value_loss)``: a 1-D float32 tensor with the
        T-1 advantages, and two scalar float32 tensors.
    """
    T = len(rewards)
    # gamma**k for k = 0..T-1
    discounts = np.logspace(0, T, num=T, base=gamma, endpoint=False)
    # discounted return from every step t to the end of the trajectory
    returns = np.array([np.sum(discounts[:T-t] * rewards[t:]) for t in range(T)])

    # Flatten to 1-D: a column-shaped input such as (T, 1) would otherwise
    # broadcast against the 1-D returns/advantages into a (T-1, T-1) matrix
    # and silently change the value of the averaged losses.
    logpas = torch.as_tensor(logpas, dtype=torch.float32).reshape(-1)
    values = torch.as_tensor(values, dtype=torch.float32).reshape(-1)

    # numpy view of the value estimates, cut off from autograd
    np_values = values.detach().numpy()
    # (gamma*tau)**k for k = 0..T-2
    tau_discounts = np.logspace(0, T-1, num=T-1, base=gamma*tau, endpoint=False)
    # one-step TD errors: r_t + gamma * V_{t+1} - V_t for t = 0..T-2
    advs = rewards[:-1] + gamma * np_values[1:] - np_values[:-1]
    # GAE at step t = (gamma*tau)-discounted sum of the TD errors from t onwards
    gaes = np.array([np.sum(tau_discounts[:T-1-t] * advs[t:]) for t in range(T-1)])

    # the last step has no advantage, so trim everything to T-1 entries
    values = values[:-1]
    discounts = torch.FloatTensor(discounts[:-1])
    returns = torch.FloatTensor(returns[:-1])
    gaes = torch.FloatTensor(gaes)

    # advantages act as fixed weights on the log-probabilities
    # (the entropy term is intentionally left out here)
    policy_loss = -(discounts * gaes.detach() * logpas).mean()

    # half mean-squared error between discounted returns and value estimates
    value_error = returns - values
    value_loss = value_error.pow(2).mul(0.5).mean()
    return gaes, policy_loss, value_loss

In [3]:
# set dummy inputs and get reference output
# rewards and values hold 10 entries each; logpas holds 9, one per advantage
# (the function multiplies logpas with the T-1 advantages it computes)
rewards = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 10])
logpas = np.array([.2, .2, .2, .2, .2, .2, .2, .2, .2])
values = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
gamma = 0.9
tau = 0.5
# reference results that the step-by-step cells below are compared against
ref_gae, ref_policy_loss, ref_value_loss = gae_optimize_model(logpas, rewards, values, gamma, tau)

In [4]:
# set out the discount vector: gamma**k for k = 0..T-1
# (same construction as the first step of gae_optimize_model)
T = len(rewards)
discounts = np.logspace(0, T, num=T, base=gamma, endpoint=False)
discounts

array([1.        , 0.9       , 0.81      , 0.729     , 0.6561    ,
       0.59049   , 0.531441  , 0.4782969 , 0.43046721, 0.38742049])

In [5]:
# get discounted returns
# for each start step t, weight the remaining rewards by gamma**0, gamma**1, ... and sum
disc_returns = []
for t in range(T):
    # show the weights and the reward tail that are multiplied element-wise
    print(t, discounts[:T-t], rewards[t:])
    disc_returns.append(np.sum(discounts[:T-t] * rewards[t:]))
disc_returns = np.array(disc_returns)
print("discounted returns:")
print(disc_returns)

0 [1.         0.9        0.81       0.729      0.6561     0.59049
 0.531441   0.4782969  0.43046721 0.38742049] [ 0  0  0  0  0  0  0  0  0 10]
1 [1.         0.9        0.81       0.729      0.6561     0.59049
 0.531441   0.4782969  0.43046721] [ 0  0  0  0  0  0  0  0 10]
2 [1.        0.9       0.81      0.729     0.6561    0.59049   0.531441
 0.4782969] [ 0  0  0  0  0  0  0 10]
3 [1.       0.9      0.81     0.729    0.6561   0.59049  0.531441] [ 0  0  0  0  0  0 10]
4 [1.      0.9     0.81    0.729   0.6561  0.59049] [ 0  0  0  0  0 10]
5 [1.     0.9    0.81   0.729  0.6561] [ 0  0  0  0 10]
6 [1.    0.9   0.81  0.729] [ 0  0  0 10]
7 [1.   0.9  0.81] [ 0  0 10]
8 [1.  0.9] [ 0 10]
9 [1.] [10]
discounted returns:
[ 3.87420489  4.3046721   4.782969    5.31441     5.9049      6.561
  7.29        8.1         9.         10.        ]


In [6]:
"""tau discounts:
here, tau=0 means one-step return: base=0*tau is 0, so only the first
weight is 1 and every later TD error gets weight 0
"""
tau_discounts = np.logspace(0, T-1, num=T-1, base=0*tau, endpoint=False)
tau_discounts

array([1., 0., 0., 0., 0., 0., 0., 0., 0.])

In [7]:
"""tau discounts:
tau=1 would mean infinite return (every weight 1**k stays 1),
and tau serves to adjust between the two extremes;
here base=1*tau uses the tau set above (0.5), so the weights are tau**k
without the gamma factor
"""
tau_discounts = np.logspace(0, T-1, num=T-1, base=1*tau, endpoint=False)
tau_discounts

array([1.        , 0.5       , 0.25      , 0.125     , 0.0625    ,
       0.03125   , 0.015625  , 0.0078125 , 0.00390625])

In [8]:
# using the saved tau value from above from here
# weights applied to the TD errors: (gamma*tau)**k for k = 0..T-2
tau_discounts = np.logspace(0, T-1, num=T-1, base=gamma*tau, endpoint=False)
print("tau discounts", tau_discounts)

tau discounts [1.         0.45       0.2025     0.091125   0.04100625 0.01845281
 0.00830377 0.00373669 0.00168151]


In [9]:
# the advantage calculation: print the three ingredients of the TD error
# rewards r_t for t = 0..T-2
print(rewards[:-1])
# next-step values V_{t+1}, raw and scaled by gamma
print(values[1:], gamma * values[1:])
# current-step values V_t
print(values[:-1])

[0 0 0 0 0 0 0 0 0]
[1 1 1 1 1 1 1 1 1] [0.9 0.9 0.9 0.9 0.9 0.9 0.9 0.9 0.9]
[1 1 1 1 1 1 1 1 1]


In [10]:
# GAEs are but discounted sum of td errors...
# td errors: R_t + gamma*value_{t+1} - value_t for t=0 to T-2 (T-1 errors in total)
advs = rewards[:-1] + gamma * values[1:] - values[:-1]
print("advantages", advs)

advantages [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1]


In [11]:
# get the GAEs: for each start step t, sum the remaining td errors
# weighted by (gamma*tau)**0, (gamma*tau)**1, ...
gaes = []
for t in range(T-1):
    # show the weights and the td-error tail that are multiplied element-wise
    print(t, tau_discounts[:T-1-t], advs[t:])
    gaes.append(np.sum(tau_discounts[:T-1-t] * advs[t:]))
gaes = np.array(gaes)
print("gaes returns:")
print(gaes)

0 [1.         0.45       0.2025     0.091125   0.04100625 0.01845281
 0.00830377 0.00373669 0.00168151] [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
1 [1.         0.45       0.2025     0.091125   0.04100625 0.01845281
 0.00830377 0.00373669] [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
2 [1.         0.45       0.2025     0.091125   0.04100625 0.01845281
 0.00830377] [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
3 [1.         0.45       0.2025     0.091125   0.04100625 0.01845281] [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
4 [1.         0.45       0.2025     0.091125   0.04100625] [-0.1 -0.1 -0.1 -0.1 -0.1]
5 [1.       0.45     0.2025   0.091125] [-0.1 -0.1 -0.1 -0.1]
6 [1.     0.45   0.2025] [-0.1 -0.1 -0.1]
7 [1.   0.45] [-0.1 -0.1]
8 [1.] [-0.1]
gaes returns:
[-0.1816806  -0.18151245 -0.18113878 -0.18030841 -0.17846312 -0.1743625
 -0.16525    -0.145      -0.1       ]


In [12]:
# the gae and reference gae should match
print(gaes)
print(np.array(ref_gae).squeeze())
# enforce the match instead of comparing the printouts by eye; the reference
# went through float32 tensors, so allow for single-precision rounding
np.testing.assert_allclose(gaes, np.array(ref_gae).squeeze(), rtol=1e-5)

[-0.1816806  -0.18151245 -0.18113878 -0.18030841 -0.17846312 -0.1743625
 -0.16525    -0.145      -0.1       ]
[-0.18168065 -0.18151249 -0.18113883 -0.18030845 -0.17846316 -0.17436254
 -0.16525003 -0.14500004 -0.10000002]


In [13]:
# policy loss recomputed in numpy: negative mean of discount * advantage * log-probability
policy_loss = -np.mean(discounts[:-1] * gaes * logpas)
print(policy_loss)
print(ref_policy_loss)
# enforce the match instead of comparing the printouts by eye; the reference
# is a float32 tensor, so allow for single-precision rounding
np.testing.assert_allclose(policy_loss, ref_policy_loss.item(), rtol=1e-5)

0.02318840929956597
tensor(0.0232)


In [14]:
# value loss recomputed with torch's built-in mean-squared error
criterion = torch.nn.MSELoss()
my_loss = criterion(torch.Tensor(disc_returns[:-1]), torch.Tensor(values[:-1]))
# miguel uses an extra 0.5 compared to regular MSE loss
print(my_loss/2)
print(ref_value_loss)
# enforce the match instead of comparing the printouts by eye
assert torch.allclose(my_loss / 2, ref_value_loss)

tensor(14.5035)
tensor(14.5035)
