# Evaluación y mejora

<img src="evaluacionymejora.png">
<img src="mejora.png">

# Iteración de política:(del libro de Sutton)

<img src="iteracion de politica.PNG">

Implementar el algoritmo de mejora de iteración de política.

In [1]:
import numpy as np
import pprint
from lib.envs.gridworld import GridworldEnv

In [2]:
pp = pprint.PrettyPrinter(indent=2)
env = GridworldEnv()

In [3]:
env.P[1]

{0: [(1.0, 1, -1.0, False)],
 1: [(1.0, 2, -1.0, False)],
 2: [(1.0, 5, -1.0, False)],
 3: [(1.0, 0, -1.0, True)]}

In [45]:
# Del ejercicio de evaluación de política

def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluar una política dado un ambiente y una descripción completa
    de la dinámica del ambiente.
    
    Argumentos:
        política: matriz de tamaño [S, A] representando la política.
        env: ambiente de OpenAI representadno las probabilidades de transición
        del ambiente. 
        env.P[s][a] es una lista de vectores (probabilidad, próximo_estado, recompensa, done)
        env.nS es el número de estados en el ambiente
        env.nA es el número de acciones en el ambiente
        theta: para la evaluación de la política una vez que la función de valor cambia menos que
        theta para todos los estados
        discount_factor: factor de descuento gama.
        
    Retorna:
        Vector de longitud env.nS que representa la función de valor.
    """
    # Comenzar con función de valor aleatoria
    V = np.zeros(env.nS)
    while True:
        delta = 0
        # Para cada estado realizar un "full backup"
        for s in range(env.nS):
            print(f"\n S = {s}")
            v = 0
            # Fijarse en las posibles próximas acciones
            for a, action_prob in enumerate(policy[s]):
                # Para cada acción, fijarse en los próximos estados
                for  prob, next_state, reward, done in env.P[s][a]:
                    # Calcular el valor esperado
                    print(V[s])
                    print(f"next = {next_state} | V[next] = {V[next_state]}")
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
                    print(V[s])
    
                print(f"s = {s} | a = {a} | v = {V[s]} | v(s) = {v}")
            # Cuál fue el máximo cambio de la función de valor
            delta = max(delta, np.abs(v - V[s]))
            V[s] = v
            print(V[s])
        # Para de evaluar una vez que estamos debajo de un cierto umbral
        if delta < theta:
            break
    return np.array(V)

In [46]:
policy = np.ones([env.nS, env.nA]) / env.nA
v = policy_eval(policy, env)
v


 S = 0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 0 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 1 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 2 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 3 | v = 0.0 | v(s) = 0.0
0.0

 S = 1
0.0
next = 1 | V[next] = 0.0
0.0
s = 1 | a = 0 | v = 0.0 | v(s) = -0.25
0.0
next = 2 | V[next] = 0.0
0.0
s = 1 | a = 1 | v = 0.0 | v(s) = -0.5
0.0
next = 5 | V[next] = 0.0
0.0
s = 1 | a = 2 | v = 0.0 | v(s) = -0.75
0.0
next = 0 | V[next] = 0.0
0.0
s = 1 | a = 3 | v = 0.0 | v(s) = -1.0
-1.0

 S = 2
0.0
next = 2 | V[next] = 0.0
0.0
s = 2 | a = 0 | v = 0.0 | v(s) = -0.25
0.0
next = 3 | V[next] = 0.0
0.0
s = 2 | a = 1 | v = 0.0 | v(s) = -0.5
0.0
next = 6 | V[next] = 0.0
0.0
s = 2 | a = 2 | v = 0.0 | v(s) = -0.75
0.0
next = 1 | V[next] = -1.0
0.0
s = 2 | a = 3 | v = 0.0 | v(s) = -1.25
-1.25

 S = 3
0.0
next = 3 | V[next] = 0.0
0.0
s = 3 | a = 0 | v = 0.0 | v(s) = -0.25
0.0
next = 3 | V[next] = 0.

s = 8 | a = 2 | v = -11.121818717405404 | v(s) = -8.834701726401136
-11.121818717405404
next = 8 | V[next] = -11.121818717405404
-11.121818717405404
s = 8 | a = 3 | v = -11.121818717405404 | v(s) = -11.865156405752487
-11.865156405752487

 S = 9
-11.769337388334861
next = 5 | V[next] = -11.055575936115929
-11.769337388334861
s = 9 | a = 0 | v = -11.769337388334861 | v(s) = -3.0138939840289822
-11.769337388334861
next = 10 | V[next] = -11.053818103924186
-11.769337388334861
s = 9 | a = 1 | v = -11.769337388334861 | v(s) = -6.027348510010029
-11.769337388334861
next = 13 | V[next] = -11.8633986227901
-11.769337388334861
s = 9 | a = 2 | v = -11.769337388334861 | v(s) = -9.243198165707554
-11.769337388334861
next = 8 | V[next] = -11.865156405752487
-11.769337388334861
s = 9 | a = 3 | v = -11.769337388334861 | v(s) = -12.459487267145676
-12.459487267145676

 S = 10
-11.053818103924186
next = 6 | V[next] = -12.459487267145676
-11.053818103924186
s = 10 | a = 0 | v = -11.053818103924186 | v(s

s = 10 | a = 3 | v = -14.841254872811456 | v(s) = -15.105986737285104
-15.105986737285104

 S = 11
-11.64164200518016
next = 7 | V[next] = -16.6095476062058
-11.64164200518016
s = 11 | a = 0 | v = -11.64164200518016 | v(s) = -4.40238690155145
-11.64164200518016
next = 11 | V[next] = -11.64164200518016
-11.64164200518016
s = 11 | a = 1 | v = -11.64164200518016 | v(s) = -7.56279740284649
-11.64164200518016
next = 15 | V[next] = 0.0
-11.64164200518016
s = 11 | a = 2 | v = -11.64164200518016 | v(s) = -7.81279740284649
-11.64164200518016
next = 10 | V[next] = -15.105986737285104
-11.64164200518016
s = 11 | a = 3 | v = -11.64164200518016 | v(s) = -11.839294087167765
-11.839294087167765

 S = 12
-17.554218285668938
next = 8 | V[next] = -16.299408065968635
-17.554218285668938
s = 12 | a = 0 | v = -17.554218285668938 | v(s) = -4.324852016492159
-17.554218285668938
next = 13 | V[next] = -16.2994046327411
-17.554218285668938
s = 12 | a = 1 | v = -17.554218285668938 | v(s) = -8.649703174677434
-17

s = 2 | a = 1 | v = -18.316757564769443 | v(s) = -10.11600930748132
-18.316757564769443
next = 6 | V[next] = -18.43999022968743
-18.316757564769443
s = 2 | a = 2 | v = -18.316757564769443 | v(s) = -14.976006864903178
-18.316757564769443
next = 1 | V[next] = -12.927284548076953
-18.316757564769443
s = 2 | a = 3 | v = -18.316757564769443 | v(s) = -18.457828001922415
-18.457828001922415

 S = 3
-20.147279665155843
next = 3 | V[next] = -20.147279665155843
-20.147279665155843
s = 3 | a = 0 | v = -20.147279665155843 | v(s) = -5.286819916288961
-20.147279665155843
next = 3 | V[next] = -20.147279665155843
-20.147279665155843
s = 3 | a = 1 | v = -20.147279665155843 | v(s) = -10.573639832577921
-20.147279665155843
next = 7 | V[next] = -18.457827998569655
-20.147279665155843
s = 3 | a = 2 | v = -20.147279665155843 | v(s) = -15.438096832220335
-20.147279665155843
next = 2 | V[next] = -18.457828001922415
-20.147279665155843
s = 3 | a = 3 | v = -20.147279665155843 | v(s) = -20.302553832700937
-20.30

-19.35732082434815

 S = 3
-21.227904034211885
next = 3 | V[next] = -21.227904034211885
-21.227904034211885
s = 3 | a = 0 | v = -21.227904034211885 | v(s) = -5.556976008552971
-21.227904034211885
next = 3 | V[next] = -21.227904034211885
-21.227904034211885
s = 3 | a = 1 | v = -21.227904034211885 | v(s) = -11.113952017105943
-21.227904034211885
next = 7 | V[next] = -19.357320824344878
-21.227904034211885
s = 3 | a = 2 | v = -21.227904034211885 | v(s) = -16.20328222319216
-21.227904034211885
next = 2 | V[next] = -19.35732082434815
-21.227904034211885
s = 3 | a = 3 | v = -21.227904034211885 | v(s) = -21.2926124292792
-21.2926124292792

 S = 4
-13.512067428470672
next = 0 | V[next] = 0.0
-13.512067428470672
s = 4 | a = 0 | v = -13.512067428470672 | v(s) = -0.25
-13.512067428470672
next = 5 | V[next] = -17.40124271376152
-13.512067428470672
s = 4 | a = 1 | v = -13.512067428470672 | v(s) = -4.85031067844038
-13.512067428470672
next = 8 | V[next] = -19.29853164042582
-13.512067428470672
s = 4

 S = 13
-19.70767261553694
next = 9 | V[next] = -19.729074335952824
-19.70767261553694
s = 13 | a = 0 | v = -19.70767261553694 | v(s) = -5.182268583988206
-19.70767261553694
next = 14 | V[next] = -13.813702354658375
-19.70767261553694
s = 13 | a = 1 | v = -19.70767261553694 | v(s) = -8.8856941726528
-19.70767261553694
next = 13 | V[next] = -19.70767261553694
-19.70767261553694
s = 13 | a = 2 | v = -19.70767261553694 | v(s) = -14.062612326537035
-19.70767261553694
next = 12 | V[next] = -21.678239522634794
-19.70767261553694
s = 13 | a = 3 | v = -19.70767261553694 | v(s) = -19.732172207195735
-19.732172207195735

 S = 14
-13.813702354658375
next = 10 | V[next] = -17.7713883453056
-13.813702354658375
s = 14 | a = 0 | v = -13.813702354658375 | v(s) = -4.6928470863264
-13.813702354658375
next = 15 | V[next] = 0.0
-13.813702354658375
s = 14 | a = 1 | v = -13.813702354658375 | v(s) = -4.9428470863264
-13.813702354658375
next = 14 | V[next] = -13.813702354658375
-13.813702354658375
s = 14 | a 

-21.825644485566542
s = 12 | a = 1 | v = -21.825644485566542 | v(s) = -10.427434759892897
-21.825644485566542
next = 12 | V[next] = -21.825644485566542
-21.825644485566542
s = 12 | a = 2 | v = -21.825644485566542 | v(s) = -16.133845881284532
-21.825644485566542
next = 12 | V[next] = -21.825644485566542
-21.825644485566542
s = 12 | a = 3 | v = -21.825644485566542 | v(s) = -21.840257002676168
-21.840257002676168

 S = 13
-19.854869519785794
next = 9 | V[next] = -19.865494737012952
-19.854869519785794
s = 13 | a = 0 | v = -19.854869519785794 | v(s) = -5.216373684253238
-19.854869519785794
next = 14 | V[next] = -13.907509634169692
-19.854869519785794
s = 13 | a = 1 | v = -19.854869519785794 | v(s) = -8.943251092795661
-19.854869519785794
next = 13 | V[next] = -19.854869519785794
-19.854869519785794
s = 13 | a = 2 | v = -19.854869519785794 | v(s) = -14.15696847274211
-19.854869519785794
next = 12 | V[next] = -21.840257002676168
-19.854869519785794
s = 13 | a = 3 | v = -19.854869519785794 | 

-17.932871901224864
next = 9 | V[next] = -19.927114310265345
-17.932871901224864
s = 10 | a = 3 | v = -17.932871901224864 | v(s) = -17.93849782303798
-17.93849782303798

 S = 11
-13.949881335810616
next = 7 | V[next] = -19.92794771408127
-13.949881335810616
s = 11 | a = 0 | v = -13.949881335810616 | v(s) = -5.2319869285203175
-13.949881335810616
next = 11 | V[next] = -13.949881335810616
-13.949881335810616
s = 11 | a = 1 | v = -13.949881335810616 | v(s) = -8.969457262472972
-13.949881335810616
next = 15 | V[next] = 0.0
-13.949881335810616
s = 11 | a = 2 | v = -13.949881335810616 | v(s) = -9.219457262472972
-13.949881335810616
next = 10 | V[next] = -17.93849782303798
-13.949881335810616
s = 11 | a = 3 | v = -13.949881335810616 | v(s) = -13.954081718232466
-13.954081718232466

 S = 12
-21.905520262580755
next = 8 | V[next] = -19.921356719305827
-21.905520262580755
s = 12 | a = 0 | v = -21.905520262580755 | v(s) = -5.230339179826457
-21.905520262580755
next = 13 | V[next] = -19.9213567193

-17.9694662869402
next = 1 | V[next] = -13.977203154279323
-17.9694662869402
s = 5 | a = 0 | v = -17.9694662869402 | v(s) = -3.7443007885698307
-17.9694662869402
next = 6 | V[next] = -19.966847404036002
-17.9694662869402
s = 5 | a = 1 | v = -17.9694662869402 | v(s) = -8.986012639578831
-17.9694662869402
next = 9 | V[next] = -19.966847404036002
-17.9694662869402
s = 5 | a = 2 | v = -17.9694662869402 | v(s) = -14.227724490587832
-17.9694662869402
next = 4 | V[next] = -13.977203154279325
-17.9694662869402
s = 5 | a = 3 | v = -17.9694662869402 | v(s) = -17.972025279157663
-17.972025279157663

 S = 6
-19.966847404036002
next = 2 | V[next] = -19.967226483936113
-19.966847404036002
s = 6 | a = 0 | v = -19.966847404036002 | v(s) = -5.241806620984028
-19.966847404036002
next = 7 | V[next] = -19.967226483936113
-19.966847404036002
s = 6 | a = 1 | v = -19.966847404036002 | v(s) = -10.483613241968056
-19.966847404036002
next = 10 | V[next] = -17.972025279157663
-19.966847404036002
s = 6 | a = 2 | 

s = 0 | a = 2 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 3 | v = 0.0 | v(s) = 0.0
0.0

 S = 1
-13.98868215109957
next = 1 | V[next] = -13.98868215109957
-13.98868215109957
s = 1 | a = 0 | v = -13.98868215109957 | v(s) = -3.7471705377748923
-13.98868215109957
next = 2 | V[next] = -19.98372907781665
-13.98868215109957
s = 1 | a = 1 | v = -13.98868215109957 | v(s) = -8.993102807229056
-13.98868215109957
next = 5 | V[next] = -17.98611151439964
-13.98868215109957
s = 1 | a = 2 | v = -13.98868215109957 | v(s) = -13.739630685828965
-13.98868215109957
next = 0 | V[next] = 0.0
-13.98868215109957
s = 1 | a = 3 | v = -13.98868215109957 | v(s) = -13.989630685828965
-13.989630685828965

 S = 2
-19.98372907781665
next = 2 | V[next] = -19.98372907781665
-19.98372907781665
s = 2 | a = 0 | v = -19.98372907781665 | v(s) = -5.245932269454163
-19.98372907781665
next = 3 | V[next] = -21.98209083388304
-19.98372907781665
s = 2 | a = 1 | v = -19.98372907781665 | v(s) = -10.9914549779

s = 9 | a = 2 | v = -19.992513446281016 | v(s) = -14.744991122339158
-19.992513446281016
next = 8 | V[next] = -19.99259905049853
-19.992513446281016
s = 9 | a = 3 | v = -19.992513446281016 | v(s) = -19.99314088496379
-19.99314088496379

 S = 10
-17.993682719429053
next = 6 | V[next] = -19.99314088496379
-17.993682719429053
s = 10 | a = 0 | v = -17.993682719429053 | v(s) = -5.248285221240947
-17.993682719429053
next = 11 | V[next] = -13.995283440626167
-17.993682719429053
s = 10 | a = 1 | v = -17.993682719429053 | v(s) = -8.99710608139749
-17.993682719429053
next = 14 | V[next] = -13.995283440626167
-17.993682719429053
s = 10 | a = 2 | v = -17.993682719429053 | v(s) = -12.745926941554032
-17.993682719429053
next = 9 | V[next] = -19.99314088496379
-17.993682719429053
s = 10 | a = 3 | v = -17.993682719429053 | v(s) = -17.99421216279498
-17.99421216279498

 S = 11
-13.995283440626167
next = 7 | V[next] = -19.993219314799138
-13.995283440626167
s = 11 | a = 0 | v = -13.995283440626167 | v(s


 S = 7
-19.996915755100325
next = 3 | V[next] = -21.996605216739944
-19.996915755100325
s = 7 | a = 0 | v = -19.996915755100325 | v(s) = -5.749151304184986
-19.996915755100325
next = 7 | V[next] = -19.996915755100325
-19.996915755100325
s = 7 | a = 1 | v = -19.996915755100325 | v(s) = -10.998380242960067
-19.996915755100325
next = 11 | V[next] = -13.998034438123128
-19.996915755100325
s = 7 | a = 2 | v = -19.996915755100325 | v(s) = -14.74788885249085
-19.996915755100325
next = 6 | V[next] = -19.997141557233636
-19.996915755100325
s = 7 | a = 3 | v = -19.996915755100325 | v(s) = -19.997174241799257
-19.997174241799257

 S = 8
-19.996633623287842
next = 4 | V[next] = -13.997854638025242
-19.996633623287842
s = 8 | a = 0 | v = -19.996633623287842 | v(s) = -3.7494636595063104
-19.996633623287842
next = 9 | V[next] = -19.996880080708657
-19.996633623287842
s = 8 | a = 1 | v = -19.996633623287842 | v(s) = -8.998683679683474
-19.996633623287842
next = 12 | V[next] = -21.99629467837957
-19.9

-21.99831460700933
next = 8 | V[next] = -19.99859710835714
-21.99831460700933
s = 12 | a = 0 | v = -21.99831460700933 | v(s) = -5.249649277089285
-21.99831460700933
next = 13 | V[next] = -19.99859710835714
-21.99831460700933
s = 12 | a = 1 | v = -21.99831460700933 | v(s) = -10.49929855417857
-21.99831460700933
next = 12 | V[next] = -21.99831460700933
-21.99831460700933
s = 12 | a = 2 | v = -21.99831460700933 | v(s) = -16.248877205930903
-21.99831460700933
next = 12 | V[next] = -21.99831460700933
-21.99831460700933
s = 12 | a = 3 | v = -21.99831460700933 | v(s) = -21.998455857683236
-21.998455857683236

 S = 13
-19.99859710835714
next = 9 | V[next] = -19.998699816130376
-19.99859710835714
s = 13 | a = 0 | v = -19.99859710835714 | v(s) = -5.249674954032594
-19.99859710835714
next = 14 | V[next] = -13.999105949618047
-19.99859710835714
s = 13 | a = 1 | v = -19.99859710835714 | v(s) = -8.999451441437106
-19.99859710835714
next = 13 | V[next] = -19.99859710835714
-19.99859710835714
s = 13 |

next = 13 | V[next] = -19.99946436162498
-13.999627416680784
s = 14 | a = 3 | v = -13.999627416680784 | v(s) = -13.999658642421734
-13.999658642421734

 S = 15
0.0
next = 15 | V[next] = 0.0
0.0
s = 15 | a = 0 | v = 0.0 | v(s) = 0.0
0.0
next = 15 | V[next] = 0.0
0.0
s = 15 | a = 1 | v = 0.0 | v(s) = 0.0
0.0
next = 15 | V[next] = 0.0
0.0
s = 15 | a = 2 | v = 0.0 | v(s) = 0.0
0.0
next = 15 | V[next] = 0.0
0.0
s = 15 | a = 3 | v = 0.0 | v(s) = 0.0
0.0

 S = 0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 0 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 1 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 2 | v = 0.0 | v(s) = 0.0
0.0
next = 0 | V[next] = 0.0
0.0
s = 0 | a = 3 | v = 0.0 | v(s) = 0.0
0.0

 S = 1
-13.999593334560018
next = 1 | V[next] = -13.999593334560018
-13.999593334560018
s = 1 | a = 0 | v = -13.999593334560018 | v(s) = -3.7498983336400045
-13.999593334560018
next = 2 | V[next] = -19.999415364016006
-13.999593334560018
s = 1 | a = 1 | v

-21.99968052427042
s = 12 | a = 0 | v = -21.99968052427042 | v(s) = -5.249933518497821
-21.99968052427042
next = 13 | V[next] = -19.999734073991284
-21.99968052427042
s = 12 | a = 1 | v = -21.99968052427042 | v(s) = -10.499867036995642
-21.99968052427042
next = 12 | V[next] = -21.99968052427042
-21.99968052427042
s = 12 | a = 2 | v = -21.99968052427042 | v(s) = -16.249787168063246
-21.99968052427042
next = 12 | V[next] = -21.99968052427042
-21.99968052427042
s = 12 | a = 3 | v = -21.99968052427042 | v(s) = -21.99970729913085
-21.99970729913085

 S = 13
-19.999734073991284
next = 9 | V[next] = -19.99975354282791
-19.999734073991284
s = 13 | a = 0 | v = -19.999734073991284 | v(s) = -5.249938385706978
-19.999734073991284
next = 14 | V[next] = -13.999830527716897
-19.999734073991284
s = 13 | a = 1 | v = -19.999734073991284 | v(s) = -8.999896017636202
-19.999734073991284
next = 13 | V[next] = -19.999734073991284
-19.999734073991284
s = 13 | a = 2 | v = -19.999734073991284 | v(s) = -14.24982

 S = 10
-17.999905405702904
next = 6 | V[next] = -19.999897292330417
-17.999905405702904
s = 10 | a = 0 | v = -17.999905405702904 | v(s) = -5.249974323082604
-17.999905405702904
next = 11 | V[next] = -13.999929374734325
-17.999905405702904
s = 10 | a = 1 | v = -17.999905405702904 | v(s) = -8.999956666766185
-17.999905405702904
next = 14 | V[next] = -13.999929374734323
-17.999905405702904
s = 10 | a = 2 | v = -17.999905405702904 | v(s) = -12.749939010449765
-17.999905405702904
next = 9 | V[next] = -19.999897292330417
-17.999905405702904
s = 10 | a = 3 | v = -17.999905405702904 | v(s) = -17.99991333353237
-17.99991333353237

 S = 11
-13.999929374734325
next = 7 | V[next] = -19.999898466730553
-13.999929374734325
s = 11 | a = 0 | v = -13.999929374734325 | v(s) = -5.249974616682638
-13.999929374734325
next = 11 | V[next] = -13.999929374734325
-13.999929374734325
s = 11 | a = 1 | v = -13.999929374734325 | v(s) = -8.99995696036622
-13.999929374734325
next = 15 | V[next] = 0.0
-13.99992937473

array([  0.        , -13.99993529, -19.99990698, -21.99989761,
       -13.99993529, -17.9999206 , -19.99991379, -19.99991477,
       -19.99990698, -19.99991379, -17.99992725, -13.99994569,
       -21.99989761, -19.99991477, -13.99994569,   0.        ])

In [5]:
import random

In [7]:
def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """
    Algoritmo de mejora de una política. Evalúa iterativamente y mejora una política 
    hasta que encuentra la política óptima.
    
    Args:
        env: ambiente de OpenAI.
        policy_eval_fn: función de evaluación de política que toma 3 argumentos: policy, env, discount_factor
        discount_factor: factor de descuento gama
        
    Retorna:
        Un tuple (policy, V)
        A tuple (policy, V). 
        policy es la política óptima, una matriz de tamaño [S, A] en que cada estado s contiene una distribución de probabilidad 
        valida sobre el espacio de acciones.
        V es la función de valor para la política óptima.
        
    """
    # Comenzar con política aleatoria
    policy = np.ones([env.nS, env.nA]) / env.nA
    
    while True:
        # Implementar
        
        # evaluar (i.e. calcular V, la función de valor de) la política actual
        V = policy_eval_fn(policy, env, discount_factor)
        
        # mantener una variable que indique si la política es estable (no cambió en este paso)
        politica_estable = True
        
        # por cada estado:
        for s in env.P:
            # encontrar cuál es la acción que tomaría la política actual con más alta probabilidad
            if (all(elem == policy[s][0] for elem in policy[s])):
                a_mas_probable = random.randint(0,3)
            else:
                a_mas_probable = policy[s].argmax()
            #print(f"state = {s}, a mas probable = {a_mas_probable}")
            
            # calcular el valor esperado de cada acción utilizando la función de valor actual y haciendo 'one-step look-ahead'
            # encontrar la acción con mayor valor esperado dado del cálculo anterior
            acciones = []        
            for a in env.P[s]:
                retorno_a = 0
                for transicion in env.P[s][a]:
                    p_transicion = transicion[0]
                    proximo_estado = transicion[1]
                    reward = transicion[2]
                    retorno_a += p_transicion*(reward + discount_factor*V[proximo_estado])
                acciones.append((a, retorno_a))


            mejor_a = max(acciones, key=lambda x: x[1])[0]
            
            # si la acción de la política actual no coincide con la mejor calculada
            # actualizar la política
            # marcar que la política no fue estable en este paso
            if (a_mas_probable != mejor_a):
                politica_estable = False
                pol_s = np.array([0,0,0,0])
                pol_s[mejor_a] = 1
                policy[s] = pol_s
   
        # si la política es estable, devolver la política óptima y la función de valor de esa política
        if (politica_estable):
            return policy, V

In [8]:
policy, v = policy_improvement(env)

In [9]:
print("Distribución de probabilidad de la política:")
print(policy)
print("")

print("Política exhibida la grilla: (0=arriba, 1=derecha, 2=abajo, 3=izquierda)")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

print("Función de valor:")
print(v)
print("")

print("Función de valor exhibida en la grilla:")
print(v.reshape(env.shape))
print("")



Distribución de probabilidad de la política:
[[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]

Política exhibida la grilla: (0=arriba, 1=derecha, 2=abajo, 3=izquierda)
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]

Función de valor:
[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]

Función de valor exhibida en la grilla:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]



In [10]:
# Test de la función de valor
expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1,  0])
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)