<a href="https://colab.research.google.com/github/obedotto/rl-value-iteration/blob/main/RL_Value_Iteration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
# NumPy 1.24 removed the deprecated `np.bool` alias, but older dependencies in
# this notebook's stack presumably still reference it (TODO confirm which).
# Only add the shim when the attribute is missing so NumPy >= 2.0's own
# `np.bool` (the NumPy bool scalar type) is not clobbered by the builtin.
if not hasattr(np, "bool"):
    np.bool = bool


In [None]:
# Silence all warnings for the rest of the session.
import warnings ; warnings.filterwarnings('ignore')

# NOTE(review): gym_walk is installed by the gym-walk install cell further down
# the notebook; on a fresh kernel that cell must run first or this import fails.
import gym, gym_walk
import numpy as np

import random
import warnings  # duplicate of the import above; harmless

# Redundant with the blanket 'ignore' filter above; kept as-is.
warnings.filterwarnings('ignore', category=DeprecationWarning)
np.set_printoptions(suppress=True)  # print small floats in fixed-point rather than scientific notation
random.seed(123); np.random.seed(123)  # seed Python and NumPy global RNGs for reproducibility


In [None]:
pip install git+https://github.com/mimoralea/gym-walk#egg=gym-walk

Collecting gym-walk
  Cloning https://github.com/mimoralea/gym-walk to /tmp/pip-install-k2cwrutz/gym-walk_8dfaeb6d356a47a5bcab4d17704d44b8
  Running command git clone --filter=blob:none --quiet https://github.com/mimoralea/gym-walk /tmp/pip-install-k2cwrutz/gym-walk_8dfaeb6d356a47a5bcab4d17704d44b8
  Resolved https://github.com/mimoralea/gym-walk to commit b915b94cf2ad16f8833a1ad92ea94e88159279f5
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='Policy:'):
    """Print a tabular policy as a grid with ``n_cols`` cells per row.

    Args:
        pi: indexable mapping state -> action key (``pi[s]``).
        P: transition model in the gym ``env.P`` layout
           (``P[s][a] -> [(prob, next_state, reward, done), ...]``).
        action_symbols: ``action_symbols[i]`` is printed for action ``i``.
            The default matches FrozenLake's action order
            (0=LEFT, 1=DOWN, 2=RIGHT, 3=UP).
        n_cols: number of cells per printed row.
        title: header line printed before the grid.

    States whose every transition is terminal (``done``) are shown blank.
    """
    print(title)
    arrs = {k: v for k, v in enumerate(action_symbols)}
    n_states = len(P)
    for s in range(n_states):
        a = pi[s]
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(2), arrs[a].rjust(6), end=" ")
        if (s + 1) % n_cols == 0:
            print("|")
    # Close a trailing partial row so subsequent output starts on a new line.
    if n_states % n_cols:
        print("|")

In [None]:
def print_state_value_function(V, P, n_cols=4, prec=3, title='State-value function:'):
    """Print a state-value function as a grid with ``n_cols`` cells per row.

    Args:
        V: indexable mapping state -> value (``V[s]``).
        P: transition model in the gym ``env.P`` layout
           (``P[s][a] -> [(prob, next_state, reward, done), ...]``).
        n_cols: number of cells per printed row.
        prec: number of decimals kept by ``np.round`` before printing.
        title: header line printed before the grid.

    States whose every transition is terminal (``done``) are shown blank.
    """
    print(title)
    n_states = len(P)
    for s in range(n_states):
        v = V[s]
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(2), '{}'.format(np.round(v, prec)).rjust(6), end=" ")
        if (s + 1) % n_cols == 0:
            print("|")
    # Close a trailing partial row so subsequent output starts on a new line.
    if n_states % n_cols:
        print("|")

In [None]:
def probability_success(env, pi, goal_state, n_episodes=100, max_steps=200):
    """Estimate the probability that policy ``pi`` ends an episode in ``goal_state``.

    Python's ``random``, NumPy's global RNG and the environment are reseeded
    with 123 on every call, so repeated calls give the same estimate.

    Args:
        env: gym-style environment. Both the classic API (``reset() -> obs``,
            4-tuple ``step``) and the gym >= 0.26 API (``reset() -> (obs, info)``,
            5-tuple ``step``) are accepted.
        pi: policy, either indexable (``pi[state]``, e.g. the array returned by
            ``value_iteration``) or callable (``pi(state)``).
        goal_state: state counted as a success when an episode ends there.
        n_episodes: number of rollouts.
        max_steps: per-episode step cap.

    Returns:
        Fraction of episodes whose final state equals ``goal_state``.
    """
    choose = pi if callable(pi) else pi.__getitem__
    random.seed(123)
    np.random.seed(123)
    env.reset(seed=123)
    results = []
    for _ in range(n_episodes):
        state = env.reset()
        # gym >= 0.26 returns (obs, info) from reset().
        if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
            state = state[0]
        done, steps = False, 0
        while not done and steps < max_steps:
            transition = env.step(choose(state))
            if len(transition) == 5:
                # gym >= 0.26: (obs, reward, terminated, truncated, info)
                state, _, terminated, truncated, _ = transition
                done = terminated or truncated
            else:
                state, _, done, _ = transition
            steps += 1
        results.append(state == goal_state)
    return np.sum(results) / len(results)

In [None]:
def mean_return(env, pi, n_episodes=100, max_steps=200):
    """Estimate the average undiscounted return obtained by policy ``pi``.

    Python's ``random``, NumPy's global RNG and the environment are reseeded
    with 123 on every call, so repeated calls give the same estimate.

    Args:
        env: gym-style environment. Both the classic API (``reset() -> obs``,
            4-tuple ``step``) and the gym >= 0.26 API (``reset() -> (obs, info)``,
            5-tuple ``step``) are accepted.
        pi: policy, either indexable (``pi[state]``, e.g. the array returned by
            ``value_iteration``) or callable (``pi(state)``).
        n_episodes: number of rollouts.
        max_steps: per-episode step cap.

    Returns:
        Mean over episodes of the summed (undiscounted) rewards.
    """
    choose = pi if callable(pi) else pi.__getitem__
    random.seed(123)
    np.random.seed(123)
    env.reset(seed=123)
    results = []
    for _ in range(n_episodes):
        state = env.reset()
        # gym >= 0.26 returns (obs, info) from reset().
        if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
            state = state[0]
        done, steps = False, 0
        results.append(0.0)
        while not done and steps < max_steps:
            transition = env.step(choose(state))
            if len(transition) == 5:
                # gym >= 0.26: (obs, reward, terminated, truncated, info)
                state, reward, terminated, truncated, _ = transition
                done = terminated or truncated
            else:
                state, reward, done, _ = transition
            results[-1] += reward
            steps += 1
    return np.mean(results)

# Creating the Frozen Lake environment

In [None]:
# Custom 4x4 map for gym's FrozenLake (S = start, F = frozen, H = hole, G = goal).
envdesc = ['SFHH', 'HFFH', 'HHFH', 'HGFH']
env = gym.make('FrozenLake-v1', desc=envdesc)
init_state = env.reset()
# Row-major index of the goal cell (13 for this map), derived from the map so
# editing envdesc cannot silently desynchronise it.
goal_state = ''.join(envdesc).index('G')
# Tabular model from the innermost env: P[s][a] -> [(prob, next_state, reward, done), ...].
P = env.unwrapped.P

In [None]:
# Inspect the transition model: P[s][a] -> list of (prob, next_state, reward, done).
P

{0: {0: [(0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 4, 0.0, True)],
  1: [(0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 4, 0.0, True),
   (0.3333333333333333, 1, 0.0, False)],
  2: [(0.3333333333333333, 4, 0.0, True),
   (0.3333333333333333, 1, 0.0, False),
   (0.3333333333333333, 0, 0.0, False)],
  3: [(0.3333333333333333, 1, 0.0, False),
   (0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 0, 0.0, False)]},
 1: {0: [(0.3333333333333333, 1, 0.0, False),
   (0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 5, 0.0, False)],
  1: [(0.3333333333333333, 0, 0.0, False),
   (0.3333333333333333, 5, 0.0, False),
   (0.3333333333333333, 2, 0.0, True)],
  2: [(0.3333333333333333, 5, 0.0, False),
   (0.3333333333333333, 2, 0.0, True),
   (0.3333333333333333, 1, 0.0, False)],
  3: [(0.3333333333333333, 2, 0.0, True),
   (0.3333333333333333, 1, 0.0, False),
   (0.3333333333333333, 0, 0.0, False)]},
 2: {0

# Value Iteration Algorithm

In [None]:
def value_iteration(P, gamma=1.0, theta=1e-10):
    """Solve a tabular MDP with (in-place) value iteration.

    Args:
        P: transition model in the gym ``env.P`` layout,
           ``P[s][a] -> [(prob, next_state, reward, done), ...]``.
           States are assumed to be ``0 .. len(P)-1``; action keys must be ints.
        gamma: discount factor.
        theta: stop once the largest change of any V[s] within a sweep is below this.

    Returns:
        (V, pi): ``V`` is the float64 state-value array; ``pi[s]`` is the action
        key that is greedy with respect to the returned ``V`` (ties go to the
        first action in ``P[s]``).
    """
    n_states = len(P)
    V = np.zeros(n_states, dtype=np.float64)

    def q_values(s):
        # One-step lookahead for every action of s. A transition flagged `done`
        # ends the episode, so it must not bootstrap from V[next_state].
        return [sum(prob * (reward + gamma * V[next_s] * (not done))
                    for prob, next_s, reward, done in P[s][a])
                for a in P[s]]

    while True:
        delta = 0.0
        for s in range(n_states):
            v_new = max(q_values(s))
            delta = max(delta, abs(v_new - V[s]))
            V[s] = v_new  # in-place (Gauss-Seidel) update: later states see it this sweep
        if delta < theta:
            break

    # Extract the greedy policy from the converged values, mapping the argmax
    # position back to the actual action key of P[s].
    pi = np.zeros(n_states, dtype=int)
    for s in range(n_states):
        actions = list(P[s])
        pi[s] = actions[int(np.argmax(q_values(s)))]
    return V, pi


In [None]:
# Solve the MDP with value iteration (discount factor 0.99): optimal values and policy.
V_best_v, pi_best_v = value_iteration(P, 0.99)


In [None]:
# Identification header, then the policy grid found by value iteration.
header_lines = ("Name: DHARMARAJ S Register Number: 212222240025",
                'Optimal policy and state-value function (VI):')
print(*header_lines, sep="\n")
print_policy(pi_best_v, P)

Name: DHARMARAJ S Register Number: 212222240025
Optimal policy and state-value function (VI):
Policy:
| 00      ^ | 01      < |           |           |
|           | 05      > | 06      < |           |
|           |           | 10      < |           |
|           |           | 14      < |           |


In [None]:
# Empirical evaluation of the VI policy: goal-reaching rate (in %) and mean return.
success_pct = probability_success(env, pi_best_v, goal_state=goal_state) * 100
avg_return = mean_return(env, pi_best_v)
report = 'Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'
print(report.format(success_pct, avg_return))

Reaches goal 5.00%. Obtains an average undiscounted return of 0.0500.


In [None]:
# Show V(s) for every state with a non-terminal transition, rounded to 4 decimals.
print_state_value_function(V=V_best_v, P=P, prec=4)

State-value function:
| 00 0.0399 | 01 0.0411 |           |           |
|           | 05 0.0436 | 06 0.0909 |           |
|           |           | 10 0.2319 |           |
|           |           | 14 0.6117 |           |
