# Stimare Q function ottimale
## 1. Quanto tempo ci mette a convergere a Q*, al variare di tau
## 2. Distanza tra Q* e Q_t
## 3. Distanza tra Q_0 e la Q_0 appresa tramite i valori di tau
## 4. Cercare di integrare i bound trovati

### Estrarre policy da Q
### Confrontare due diverse Q valutando diverse metriche 
### Calcolare Q*
### Tentativo di implementazione di un curriculum
### Plot functions

In [1]:
import numpy as np
from TMDP import TMDP
from river_swim import River

from algorithms import Q_learning, eps_greedy, SARSA, get_policy


In [2]:
# Build a TMDP on top of the River environment, using tau = 0.9.
gamma, tau = 0.9, 0.9
river = River(gamma)
# Uniform vector with one entry per state: every entry equals 1 / river.nS.
xi = np.ones(river.nS) / river.nS
tmdp = TMDP(river, xi, tau, gamma)


In [3]:
# Run Q_learning on the current `tmdp`, starting from an all-zero Q-table.
M = 1000  # last argument of Q_learning (presumably the step budget -- confirm)
ret = 0  # not read anywhere in this cell
Q = np.zeros((tmdp.nS, tmdp.nA))
s = tmdp.reset()

# Pick the first action among those allowed in the start state. The third
# argument 1. is presumably epsilon (i.e. a fully random choice) -- confirm
# against eps_greedy.
first_actions = tmdp.allowed_actions[s.item()]
a = eps_greedy(s, Q, 1., first_actions)

Q = Q_learning(tmdp, s, a, Q, M)
print(Q)

[[2295.37932017 1771.46170777]
 [1726.62245786 2045.61375673]
 [2151.43111671 1832.26507719]
 [1997.94487062 2002.17853929]
 [1834.03067288 2320.61042952]
 [2056.1261064  3683.61124103]]


In [4]:
# Same experiment on a TMDP built with tau = 0, with a 10x larger M.
tmdp_0 = TMDP(river, xi, 0, gamma)
s = tmdp_0.reset()
print(tmdp_0.tau)

M = 10000  # last argument of Q_learning (presumably the step budget -- confirm)
ret = 0  # not read anywhere in this cell
Q = np.zeros((tmdp_0.nS, tmdp_0.nA))

# First action drawn among the actions allowed in the start state.
first_actions = tmdp_0.allowed_actions[s.item()]
a = eps_greedy(s, Q, 1., first_actions)

Q = Q_learning(tmdp_0, s, a, Q, M)
print(Q)

0
[[50.         43.27587434]
 [45.         39.98967386]
 [40.5        36.43253298]
 [36.44998982 32.81060079]
 [ 0.          0.        ]
 [ 0.          0.        ]]


In [27]:
# Sweep tau from 1.0 down to 0 and learn an independent Q-table for each value.
#
# Each tau is rounded to one decimal: the raw expression 1 - i*0.1 accumulates
# binary floating-point error (e.g. 0.29999999999999993 instead of 0.3), which
# would otherwise leak into both the TMDP parameter and the printed labels.
taus = [round(1 - i * 0.1, 1) for i in range(10)]
taus.append(0)  # exact integer 0, so later `tau == 0` checks are reliable

# Create the environment first, so that xi is sized from the river built in
# this cell rather than from whatever `river` an earlier cell left behind.
gamma = 0.9
river = River(gamma)
xi = np.ones(river.nS) / river.nS  # uniform vector, one entry per state

M = 10000  # same value for every tau, so it is set once outside the loop
Qs = []

for tau in taus:
    tmdp = TMDP(river, xi, tau, gamma)
    s = tmdp.reset()
    # A fresh all-zero table per tau: nothing is carried over between runs.
    Q = np.zeros((tmdp.nS, tmdp.nA))
    ret = 0  # not read in this cell; kept because other cells may rely on it
    a = eps_greedy(s, Q, 1., tmdp.allowed_actions[s.item()])
    Q = Q_learning(tmdp, s, a, Q, M)
    Qs.append({"tau": tau, "Q_function": Q, "env": tmdp})

# Report every learned table, in the same order as `taus`.
for entry in Qs:
    print("Tau:", entry['tau'])
    print(entry['Q_function'])

Tau: 1.0
[[2124.93125414 1426.75457524]
 [1282.34522533 2048.21470816]
 [2048.09712304 1297.13325371]
 [1293.37125046 2020.42231722]
 [2116.91019724 1335.44819086]
 [1621.55150565 3723.7081449 ]]
Tau: 0.9
[[1633.47114911 1270.13214021]
 [1352.07539124 1646.19680538]
 [1282.930951   1711.75002275]
 [1637.92856007 1312.66684541]
 [1770.0713257  1397.78026753]
 [1409.56219363 3709.70044264]]
Tau: 0.8
[[1779.11406064 2268.59196207]
 [2310.67685398 2243.80579636]
 [2335.36709694 2189.31621153]
 [1926.95504549 2427.32261647]
 [2399.35624651 1763.69978777]
 [1762.78809147 4790.96143248]]
Tau: 0.7
[[1082.49928273 1904.03419369]
 [1809.19616388 1580.23054097]
 [1850.4348031  1356.34555757]
 [1341.38592883 1798.80929933]
 [1644.6028088  2053.13914765]
 [1381.75655136 3575.34417401]]
Tau: 0.6
[[2477.42839534 2648.94972141]
 [1508.97376608 2514.4012775 ]
 [2497.83786425 2375.60366383]
 [2487.84558699  897.6957089 ]
 [1237.26483953 3120.77412311]
 [ 933.89073887 5431.46831509]]
Tau: 0.5
[[ 914.3485

In [28]:
# Extract the policies from the last two Q-tables stored in Qs and show them.
second_last_Q = Qs[-2]['Q_function']
last_Q = Qs[-1]['Q_function']

pi = get_policy(second_last_Q)
pi_prime = get_policy(last_Q)

print(pi)
print(pi_prime)

[1, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 0]


In [31]:
# Curriculum over decreasing values of tau: a single Q-table is passed from
# one stage to the next instead of being re-initialised for every tau.
tmdp = TMDP(river, xi, 1, gamma)  # used here only to size the initial table
Q = np.zeros((tmdp.nS, tmdp.nA))

for tau in taus:
    tmdp = TMDP(river, xi, tau, gamma)
    s = tmdp.reset()
    # The tau == 0 stage gets a 10x larger M than the other stages.
    M = 100000 if tau == 0 else 10000
    ret = 0  # not read anywhere in this cell
    a = eps_greedy(s, Q, 1., tmdp.allowed_actions[s.item()])
    Q = Q_learning(tmdp, s, a, Q, M)


# Compare against the last table stored in Qs.
scratch_Q = Qs[-1]['Q_function']
print("Q function learned from scratch:\n", scratch_Q)
print("Q function learned with transfer learning:\n", Q)

pi = get_policy(scratch_Q)
pi_prime = get_policy(Q)

print(pi)
print(pi_prime)


Q function learned from scratch:
 [[50.         44.28999532]
 [45.         42.31003587]
 [40.5        35.33281656]
 [36.44999989 28.45936146]
 [32.11308791  0.        ]
 [ 0.          0.        ]]
Q function learned with transfer learning:
 [[ 499.28890249 1028.31069227]
 [ 567.95346674 1726.16298339]
 [ 892.75431572 2841.44220435]
 [1894.15901516 4153.01197098]
 [3472.76210858 6413.56818659]
 [4349.12037286 9946.95186267]]
[0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1]
