In [1]:
import numpy as np
from TMDP import TMDP
from River_swim import River

from algorithms import *
from model_functions import *

import matplotlib.pyplot as plt
import autograd.numpy as np
from autograd import grad, jacobian, elementwise_grad

#np.set_printoptions(precision=4)
import math
from utils import *

# --- Experiment configuration -------------------------------------------
# NOTE: `np` in this cell is `autograd.numpy`, because the later
# `import autograd.numpy as np` rebinds the name first bound by
# `import numpy as np`.

# Size passed to River as its first argument and used as the length of the
# `mu` and `xi` vectors below (presumably the number of states -- confirm
# against River).
nS = 8
# Presumably the number of actions; not referenced anywhere else in this cell.
nA = 2
# Seed handed to River and to both TMDP constructors.
seed = 2184109
# Value passed as `gamma=` to River and TMDP (presumably the discount factor).
gamma = .9
# Length-nS vector whose entries are all 1/nS; second positional argument of River.
mu = np.ones(nS) * 1/nS
river = River(nS, mu, gamma=gamma, small=5, large=1000, seed=seed)
# Default `tau` used for `tmdp` below.
# NOTE(review): the sweep cell further down reuses the name `tau` as its loop
# variable, so this value is overwritten once that cell has run.
tau = 0.3
# Length-nS vector whose entries are all 1/nS; second positional argument of TMDP.
xi = np.ones(nS) * 1/nS
# Same River instance wrapped twice: once with tau=0.3 and once with tau=0.
tmdp = TMDP(river, xi, tau=tau, gamma=gamma, seed=seed)
mdp = TMDP(river, xi, tau=0., gamma=gamma, seed=seed)
# NOTE(review): the original comment here read "argnum=5 corresponds to the
# position of 'tau'". It presumably referred to an autograd call such as
# grad(compute_d_from_tau, argnum=5) that is no longer in this cell -- confirm
# or remove. In the call compute_d_from_tau(mu, P_mat, xi, pi, gamma, tau)
# used later in this notebook, `tau` is indeed the argument at index 5.



In [2]:
def numerical_gradient(func, mu, P_mat, xi, pi, gamma, tau, h=1e-5):
    """Estimate the derivative of ``func`` with respect to ``tau`` by central differences.

    ``func`` is called twice with the positional signature
    ``func(mu, P_mat, xi, pi, gamma, tau)``: first at ``tau + h`` and then at
    ``tau - h``. All arguments other than ``tau`` are forwarded unchanged.

    Parameters
    ----------
    func : callable
        Function of six positional arguments whose last argument is ``tau``.
    mu, P_mat, xi, pi, gamma
        Forwarded to ``func`` as-is; they are not inspected here.
    tau : float
        Point at which the derivative is estimated.
    h : float, optional
        Half-width of the symmetric difference stencil (default ``1e-5``).

    Returns
    -------
    ``(func(..., tau + h) - func(..., tau - h)) / (2 * h)``; the result has
    whatever type and shape the subtraction and division of ``func``'s return
    values produce.

    Notes
    -----
    ``func`` is evaluated at ``tau - h`` even when that lies below ``tau``'s
    intended range; no clamping is done.
    """
    def evaluate_at(shifted_tau):
        # Only the last argument varies between the two evaluations.
        return func(mu, P_mat, xi, pi, gamma, shifted_tau)

    # Keep the forward evaluation first so the call order matches the
    # single-expression form.
    forward_value = evaluate_at(tau + h)
    backward_value = evaluate_at(tau - h)
    stencil_width = 2 * h
    return (forward_value - backward_value) / stencil_width

In [3]:
# Sweep `tau` from 0.99 down to 0 and, for each value, print three vectors:
# the output of compute_d_from_tau, the output of compute_grad_d, and a
# finite-difference derivative of compute_d_from_tau with respect to tau.
#
# NOTE(review): in the recorded output of this cell the second and third
# printed vectors match in magnitude but have opposite signs (e.g. tau 0.99:
# 1.1455e-01 vs -1.1455e-01). Either compute_grad_d differentiates with
# respect to a different variable (possibly 1 - tau) or one of the two has a
# sign error -- confirm against model_functions.
#
# NOTE(review): the loop variable reuses the notebook-level names `tau` and
# `tmdp`, so after this cell runs they hold the values from the last iteration
# (tau = 0.) rather than the ones set in the setup cell.
taus = [0.99, 0.98, 0.95, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.]
for tau in taus:
    print("tau", tau)
    # Rebuild the TMDP around the same River instance for the current tau.
    tmdp = TMDP(river, xi, tau=tau, gamma=gamma, seed=seed)
    # bellman_optimal_q returns a mapping; only its "Q" entry is used here.
    Q_star = bellman_optimal_q(tmdp.P_mat_tau, tmdp.reward, tmdp.gamma)["Q"]
    # Policy derived from Q_star (presumably greedy -- confirm in get_policy).
    pi = get_policy(Q_star)
    # `d` is computed from P_mat and tau, not from the precomputed P_mat_tau.
    d = compute_d_from_tau(tmdp.mu, tmdp.P_mat, tmdp.xi, pi, tmdp.gamma, tmdp.tau)
    print(d)
    # Closed-form counterpart of the finite-difference estimate below; note
    # the different argument order (P_mat, P_mat_tau, xi, mu, ...) and that
    # tau is not passed explicitly.
    grad_d = compute_grad_d(tmdp.P_mat, tmdp.P_mat_tau, tmdp.xi, tmdp.mu, pi, tmdp.gamma)
    print(grad_d)
    # Central difference with the default step h=1e-5. For the last sweep
    # value (tau = 0.) this evaluates compute_d_from_tau at tau = -1e-5;
    # whether that function is meaningful for negative tau is not visible
    # here -- TODO confirm.
    numerical_grad = numerical_gradient(compute_d_from_tau, tmdp.mu, tmdp.P_mat, tmdp.xi, pi, tmdp.gamma, tmdp.tau)
    print(numerical_grad)


tau 0.99
[0.12613522 0.12499992 0.12499088 0.12398719 0.12466067 0.12499969
 0.12567582 0.12455061]
[ 1.14552371e-01 -2.46136584e-05 -1.82332607e-03 -1.01311246e-01
 -3.41175586e-02 -6.25206657e-05  6.76642338e-02 -4.48773396e-02]
[-1.14552371e-01  2.46136667e-05  1.82332607e-03  1.01311246e-01
  3.41175586e-02  6.25206643e-05 -6.76642338e-02  4.48773396e-02]
tau 0.98
[0.12729123 0.12499934 0.12496353 0.12297377 0.12431763 0.12499874
 0.1263533  0.12410246]
[ 1.16659617e-01 -9.84948092e-05 -3.64832840e-03 -1.01373498e-01
 -3.44913324e-02 -1.28668567e-04  6.78330142e-02 -4.47523087e-02]
[-1.16659617e-01  9.84948158e-05  3.64832840e-03  1.01373498e-01
  3.44913324e-02  1.28668565e-04 -6.78330142e-02  4.47523087e-02]
tau 0.95
[0.13088957 0.12498973 0.12477184 0.1199297  0.12326563 0.12499164
 0.1283962  0.12276569]
[ 0.12331271 -0.00061637 -0.00913377 -0.10156652 -0.03565148 -0.00035023
  0.06836787 -0.04436221]
[-0.12331271  0.00061637  0.00913377  0.10156652  0.03565148  0.00035023
 -0.