In [1]:
import numpy as np

In [45]:
# MDP for bandit_walk_five : 
# Deterministic environment (100% action success)
# 3 non-terminal states, 2 terminal states
# The only reward (+1) is at the right-most state
# Episodic environment where agent terminates in the 
# left-most or right-most states T-1-2-3-T
# actions left (0), right (1)
bandit_walk = {
    0: {
        0: [(1.0,0,0.0,True)],
        1: [(1.0,0,0.0,True)]
    },
    1: {
        0: [(1.0,0,0.0,True)],
        1: [(1.0,2,0.0,False)]
    },
    2: {
        0: [(1.0,1,0.0,False)],
        1: [(1.0,3,0.0,False)]
    },
    3: {
        0: [(1.0,2,0.0,False)],
        1: [(1.0,4,1.0,True)]
    },
    4: {
        0: [(1.0,4,0.0,True)],
        1: [(1.0,4,0.0,True)]
    }
}

In [47]:
# MDP for slippery_bandit_walk_five : 
# Stochastic environment (80% action success, 20% backwards)
# 3 non-terminal states, 2 terminal states
# The only reward (+1) is at the right-most state
# Episodic environment where agent terminates in the 
# left-most or right-most states T-1-2-3-T
# actions left (0), right (1)
slippery_bandit_walk = {
    0: {
        0: [(1.0,0,0.0,True)],
        1: [(1.0,0,0.0,True)]
    },
    1: {
        0: [(0.8,0,0.0,True),(0.2,2,0.0,False)],
        1: [(0.8,2,0.0,False),(0.2,0,0.0,True)]
    },
    2: {
        0: [(0.8,1,0.0,False),(0.2,3,0.0,False)],
        1: [(0.8,3,0.0,False),(0.2,1,0.0,False)]
    },
    3: {
        0: [(0.8,2,0.0,False),(0.2,4,1.0,True)],
        1: [(0.8,4,1.0,True),(0.2,2,0.0,False)]
    },
    4: {
        0: [(1.0,4,0.0,True)],
        1: [(1.0,4,0.0,True)]
    }
}

In [48]:
# Random Walk: 
# highly stochastic environment (50% action success, 50% backwards)
# 5 non-terminal states, 2 terminal states
# only reward is still at the right-most cell in the "walk"
# episodic environment, the agent terminates at the left- or right-most cell
# agent starts in state 3 (middle of the walk) T-1-2-3-4-5-T
# actions left (0) or right (1), which don't really make a difference since walk is random
random_walk = {
    0: {
        0: [(1.0, 0, 0.0, True)],
        1: [(1.0, 0, 0.0, True)]
    },
    1: {
        0: [(0.5, 0, 0.0, True), (0.5, 2, 0.0, False)],
        1: [(0.5, 2, 0.0, False), (0.5, 0, 0.0, True)]
    },
    2: {
        0: [(0.5, 1, 0.0, False), (0.5, 3, 0.0, False)],
        1: [(0.5, 3, 0.0, False), (0.5, 1, 0.0, False)]
    },
    3: {
        0: [(0.5, 2, 0.0, False), (0.5, 4, 0.0, False)],
        1: [(0.5, 4, 0.0, False), (0.5, 2, 0.0, False)]
    },
    4: {
        0: [(0.5, 3, 0.0, False), (0.5, 5, 0.0, False)],
        1: [(0.5, 5, 0.0, False), (0.5, 3, 0.0, False)]
    },
    5: {
        0: [(0.5, 4, 0.0, False), (0.5, 6, 1.0, True)],
        1: [(0.5, 6, 1.0, True), (0.5, 4, 0.0, False)]
    },
    6: {
        0: [(1.0, 6, 0.0, True)],
        1: [(1.0, 6, 0.0, True)]
    }
}

In [49]:
print(bandit_walk)

{0: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 0, 0.0, True)]}, 1: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 2, 0.0, False)]}, 2: {0: [(1.0, 1, 0.0, False)], 1: [(1.0, 3, 0.0, False)]}, 3: {0: [(1.0, 2, 0.0, False)], 1: [(1.0, 4, 1.0, True)]}, 4: {0: [(1.0, 4, 0.0, True)], 1: [(1.0, 4, 0.0, True)]}}


In [50]:
print(slippery_bandit_walk)

{0: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 0, 0.0, True)]}, 1: {0: [(0.8, 0, 0.0, True), (0.2, 2, 0.0, False)], 1: [(0.8, 2, 0.0, False), (0.2, 0, 0.0, True)]}, 2: {0: [(0.8, 1, 0.0, False), (0.2, 3, 0.0, False)], 1: [(0.8, 3, 0.0, False), (0.2, 1, 0.0, False)]}, 3: {0: [(0.8, 2, 0.0, False), (0.2, 4, 1.0, True)], 1: [(0.8, 4, 1.0, True), (0.2, 2, 0.0, False)]}, 4: {0: [(1.0, 4, 0.0, True)], 1: [(1.0, 4, 0.0, True)]}}


In [52]:
print(random_walk)

{0: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 0, 0.0, True)]}, 1: {0: [(0.5, 0, 0.0, True), (0.5, 2, 0.0, False)], 1: [(0.5, 2, 0.0, False), (0.5, 0, 0.0, True)]}, 2: {0: [(0.5, 1, 0.0, False), (0.5, 3, 0.0, False)], 1: [(0.5, 3, 0.0, False), (0.5, 1, 0.0, False)]}, 3: {0: [(0.5, 2, 0.0, False), (0.5, 4, 0.0, False)], 1: [(0.5, 4, 0.0, False), (0.5, 2, 0.0, False)]}, 4: {0: [(0.5, 3, 0.0, False), (0.5, 5, 0.0, False)], 1: [(0.5, 5, 0.0, False), (0.5, 3, 0.0, False)]}, 5: {0: [(0.5, 4, 0.0, False), (0.5, 6, 1.0, True)], 1: [(0.5, 6, 1.0, True), (0.5, 4, 0.0, False)]}, 6: {0: [(1.0, 6, 0.0, True)], 1: [(1.0, 6, 0.0, True)]}}


In [53]:
def policy_evaluation(pi, P, gamma=1.0, theta=1e-10):
    state_space = len(P)
    prev_state_values = np.zeros(state_space)
    
    while True:
        current_state_values = np.zeros(state_space)
        for state in range(state_space):
            for prob, next_state, reward, done in P[state][pi[state]]:
                current_state_values[state] += prob * (reward + gamma * prev_state_values[next_state] * (not done))
        if np.max(np.abs(prev_state_values - current_state_values)) < theta:
            break
        prev_state_values = current_state_values.copy()
    return current_state_values
    

In [87]:
policy_always_right = {0:1,1:1,2:1,3:1,4:1}
policy_always_left = {0:0,1:0,2:0,3:0,4:0}

In [88]:
state_values_always_right = policy_evaluation(policy_always_right, bandit_walk, gamma=0.99)
state_values_always_left = policy_evaluation(policy_always_left, bandit_walk, gamma=0.99)

In [89]:
print(f"ALWAYS RIGHT : {state_values_always_right}")
print(f"ALWAYS LEFT : {state_values_always_left}")

ALWAYS RIGHT : [0.     0.9801 0.99   1.     0.    ]
ALWAYS LEFT : [0. 0. 0. 0. 0.]


In [90]:
state_values_always_right = policy_evaluation(policy_always_right, slippery_bandit_walk, gamma=0.99)
state_values_always_left = policy_evaluation(policy_always_left, slippery_bandit_walk, gamma=0.99)

In [91]:
print(f"ALWAYS RIGHT : {state_values_always_right}")
print(f"ALWAYS LEFT : {state_values_always_left}")

ALWAYS RIGHT : [0.         0.73111101 0.92311996 0.98277775 0.        ]
ALWAYS LEFT : [0.         0.01142361 0.057695   0.24569444 0.        ]


In [92]:
def policy_improvement(V,P, gamma=1.0):
    state_space = len(P)
    action_space = len(P[0])
    
    Q = np.zeros((state_space, action_space))
    
    for state in range(state_space):
        for action in range(action_space):
            for prob, next_state, reward, done in P[state][action]:
                Q[state][action] = prob * (reward + gamma * V[next_state] * (not done))
    
    new_pi = {s:a for s,a in enumerate(np.argmax(Q, axis=1))}
    return new_pi

In [93]:
improved_policy_always_right = policy_improvement(state_values_always_left,slippery_bandit_walk, gamma=0.99)

In [94]:
print(improved_policy_always_right)

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0}


In [95]:
def policy_iteration(P, gamma=1.0, theta=1e-10):
    state_space = len(P)
    action_space = len(P[0])
    
    random_actions = [np.random.choice(action_space) for _ in range(state_space)]
    pi = {s:a for s,a in enumerate(random_actions)}
    print(f"RANDOM POLICY : {pi}")
    
    while True:
        old_pi = pi.copy()
        V = policy_evaluation(pi, P, gamma, theta)
        pi = policy_improvement(V,P,gamma)
        if old_pi == pi:
            break
            
    return V, pi

In [96]:
state_values, policy = policy_iteration(bandit_walk, gamma=0.99)

RANDOM POLICY : {0: 0, 1: 0, 2: 1, 3: 1, 4: 0}


In [97]:
print(f"STATE VALUES : {state_values}, POLICY : {policy}")

STATE VALUES : [0.     0.9801 0.99   1.     0.    ], POLICY : {0: 0, 1: 1, 2: 1, 3: 1, 4: 0}


In [107]:
def value_iteration(P, gamma=1.0, theta=1e-10):
    state_space = len(P)
    action_space = len(P[0])
    
    V = np.zeros(state_space, dtype=np.float64)
    while True:
        Q = np.zeros((state_space, action_space), dtype=np.float64)
        for state in range(state_space):
            for action in range(action_space):
                for prob, next_state, reward, done in P[state][action]:
                    Q[state][action] += prob * (reward + gamma * V[next_state] * (not done))
        if np.max(np.abs(V - np.max(Q, axis=1))) < theta:
            break
            
        V = np.max(Q, axis=1)
        pi = {s:a for s,a in enumerate(np.argmax(Q, axis=1))}
        
        return V,pi

In [110]:
state_value, policy = value_iteration(bandit_walk, gamma=0.99)
print(f"STATE VALUES : {state_values}, POLICY : {policy}")

STATE VALUES : [0.     0.9801 0.99   1.     0.    ], POLICY : {0: 0, 1: 0, 2: 0, 3: 1, 4: 0}
