In [26]:
import gym
import numpy as np

class DynamicProgramming:
    def __init__(self) -> None:
        self.env = gym.make('FrozenLake-v1')
        self.nS = 16
        self.nA = 4
        self.gamma = 1
        self.theta = 1e-4
        self.random_policy = np.ones((self.nS, self.nA)) / self.nA

    def policy_evaluation(self, policy):
        prev_V = np.zeros(self.nS)
        iteration = 0
        while True:
            delta = 0
            iteration += 1
            V = np.zeros(self.nS)
            for state in range(self.nS):
                for action in range(self.nA):
                    for prob, next_state, reward, _ in self.env.P[state][action]:
                        V[state] += policy[state][action] * prob * (reward + self.gamma * prev_V[next_state])
            delta = np.max(np.abs(V - prev_V))
            if delta <= self.theta:
                break
            prev_V = np.copy(V)
        return V, iteration
    
    def policy_improvement(self, V):
        policy = np.zeros((self.nS, self.nA))
        q = np.zeros((self.nS, self.nA))
        for state in range(self.nA):
            for action in range(self.nA):
                for prob, next_state, reward, _ in self.env.P[state][action]:
                    q[state][action] += prob * (reward + self.gamma * V[next_state])

        max_index = np.argmax(q, axis=1)

        for i, optimal_action in enumerate(max_index):
            policy[i][optimal_action] = 1

        return policy

    def policy_iteration(self, policy=None):
        policy = self.random_policy if policy == None else policy
        old_policy = policy
        iteration = 0
        while True:
            V, _ = self.policy_evaluation(policy)
            policy = self.policy_improvement(V)
            iteration += 1
            comparison = policy == old_policy
            if comparison.all() == True:
                break
            old_policy = policy
        return policy, iteration
    
    def value_iteration(self):
        policy = np.zeros((self.nS, self.nA))
        prev_V = np.zeros((self.nS, 1))
        iteration = 0
        while True:
            iteration += 1
            Q = np.zeros((self.nS, self.nA))
            for state in range(self.nS):
                for action in range(self.nA):
                    for prob, next_state, reward, _ in self.env.P[state][action]:
                        Q[state][action] += prob * (reward + self.gamma * prev_V[next_state])
            V = np.max(Q, axis=1)
            if np.max(np.abs(prev_V - V)) < self.theta:
                break
            prev_V = V
            max_index = np.argmax(Q, axis=1)
            for i, optimal_action in enumerate(max_index):
                policy[i][optimal_action]  = 1
        return policy, iteration
    
    
DP = DynamicProgramming()
optimal_policy , iteration = DP.policy_iteration()
optimal_policy2 , iteration2 = DP.value_iteration()

if np.all(optimal_policy == optimal_policy2):
    print("same policy")

print(f"{optimal_policy}\n{iteration}")
print(f"{optimal_policy2}\n{iteration2}")

[[1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]
3
[[1. 1. 1. 0.]
 [1. 1. 1. 1.]
 [1. 0. 1. 1.]
 [1. 0. 0. 1.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 1. 0. 1.]
 [1. 1. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 1. 1. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]
243


  Q[state][action] += prob * (reward + self.gamma * prev_V[next_state])


In [75]:
def decay_schedule(
 init_value, min_value,
 decay_ratio, max_steps,
 log_start=-2, log_base=10):
 decay_steps = int(max_steps * decay_ratio)
 rem_steps = max_steps - decay_steps
 values = np.logspace(
 log_start, 0, decay_steps,
 base=log_base, endpoint=True)[::-1]
 values = (values - values.min()) / \
 (values.max() - values.min())
 values = (init_value - min_value) * values + min_value
 values = np.pad(values, (0, rem_steps), 'edge')
 return values

In [85]:
def sarsa(env, gamma=1.0, init_alpha=0.5, min_alpha=0.01, alpha_decay_ratio=0.5, init_epsilon=1.0, min_epsilon=0.1, epsilon_decay_ratio=0.9, n_episodes=3000):
    nS, nA = env.observation_space.n, env.action_space.n
    pi_track = []
    Q = np.zeros((nS, nA), dtype=np.float64)
    Q_track = np.zeros((n_episodes, nS, nA), dtype=np.float64)
    select_action = lambda state, Q, epsilon: np.argmax(Q[state]) if np.random.random() > epsilon else np.random.randint(len(Q[state]))
    alphas = decay_schedule(
    init_alpha, min_alpha,
    alpha_decay_ratio,
    n_episodes)
    epsilons = decay_schedule(
    init_epsilon, min_epsilon,
    epsilon_decay_ratio,
    n_episodes)
    for e in range(n_episodes):
        state, info = env.reset()
        done = False
        action = select_action(state, Q, epsilons[e])
        while not done:
            next_state, reward, done, _, _ = env.step(action)
            next_action = select_action(next_state,
            Q,
            epsilons[e])
            td_target = reward + gamma * \
            Q[next_state][next_action] * (not done)
            td_error = td_target - Q[state][action]
            Q[state][action] = Q[state][action] + \
            alphas[e] * td_error
            state, action = next_state, next_action
        Q_track[e] = Q
        pi_track.append(np.argmax(Q, axis=1))
    V = np.max(Q, axis=1)
    pi = lambda s: {s:a for s, a in enumerate(\
    np.argmax(Q, axis=1))}[s]
    return Q, V, pi, Q_track, pi_track

sarsa(gym.make('CliffWalking-v0'))

  if not isinstance(terminated, (bool, np.bool8)):


(array([[-1.33125164e+03, -5.42399958e+02, -1.51909211e+03,
         -1.35619042e+03],
        [-1.14934357e+03, -4.60742634e+02, -1.61430928e+03,
         -1.32770762e+03],
        [-1.03507890e+03, -3.78236137e+02, -1.56724335e+03,
         -1.15375700e+03],
        [-8.28530998e+02, -2.90447546e+02, -1.33961459e+03,
         -1.09530415e+03],
        [-8.03368779e+02, -2.35243408e+02, -1.31340167e+03,
         -9.22280605e+02],
        [-6.60978643e+02, -2.14688166e+02, -1.07069522e+03,
         -8.08832837e+02],
        [-4.96563327e+02, -1.60080859e+02, -1.02980839e+03,
         -6.09476371e+02],
        [-4.13914980e+02, -1.13158169e+02, -8.58899174e+02,
         -5.37686646e+02],
        [-2.99244312e+02, -7.20468471e+01, -9.20358069e+02,
         -4.36284927e+02],
        [-2.88674659e+02, -5.36293930e+01, -4.82659683e+02,
         -4.19768422e+02],
        [-1.71872047e+02, -3.75050538e+01, -2.81679656e+02,
         -2.18327594e+02],
        [-1.31475418e+02, -1.37404857e+02, 