In [16]:
import numpy as np

# TODO
## transition prob matrix (s,a,s')
## reward matrix (s,a)

def env(s, a):
    '''
    args: state dim, action dim

    return: transition prob, reward matrix
    '''
    
    p = np.random.random(s*a*s).reshape(s,a,s)
    p = p / p.sum(axis=2, keepdims=True) 
    r = np.random.random(s*a).reshape(s,a)
    
    return p, r

def random_policy(s,a):
    '''
    args: state dim, action dim

    return: random policy
    '''
    pi = np.random.random(s*a).reshape(s,a)
    pi = pi / pi.sum(axis=1, keepdims=True)
    return pi


In [17]:
# state, action, gamma
s= 10
a = 10
gamma = 0.9

# transition prob, reward matrix
P,R = env(s,a)

# random policy
pi = random_policy(s,a)

In [18]:
# linear equation (evaluation)
## v = (I - gamma*P)^-1 * r

# policy reward, policy transition prob
r_pi = np.sum(pi*R, axis=1)
p_pi = np.sum(pi[:,:,np.newaxis] * P, axis=1)

print(p_pi.shape)
print(r_pi.shape)

V = np.linalg.inv(np.eye(s) - gamma*p_pi) @ r_pi

(10, 10)
(10,)


In [19]:
print(V.shape)
print(V)

(10,)
[5.0012556  5.15941529 5.37829811 5.12676508 5.10051395 5.25804683
 5.08696993 5.22286406 4.98872513 5.12514297]


In [22]:
def policy_evaluation_matrix(p, r, policy, gamma=0.9, theta=1e-6):
    """
    Args:
        p: transition probability matrix (S x A x S)
        r: reward matrix (S x A)
        policy: policy matrix (S x A)
        gamma: discount factor
        theta: threshold
        
    Returns:
        V: state-value function (S)
    """
    n_states = p.shape[0]
    
    # Initialize value function
    V = np.zeros(n_states)
    
    while True:
        # Compute expected rewards for each state (S x 1)
        expected_rewards = np.sum(policy * r, axis=1)
        
        # Compute expected next state values (S x S)
        # p_policy: (S x S) matrix where p_policy[s,s'] = sum_a policy(s,a) * p(s,a,s')
        p_policy = np.sum(policy[:, :, np.newaxis] * p, axis=1)
        
        # Compute new value function
        V_new = expected_rewards + gamma * np.dot(p_policy, V)
        
        # Check for convergence
        if np.max(np.abs(V_new - V)) < theta:
            break
            
        V = V_new
        
    return V

policy_evaluation_matrix(P, R, pi)

array([5.00124579, 5.15940548, 5.3782883 , 5.12675526, 5.10050414,
       5.25803702, 5.08696012, 5.22285425, 4.98871532, 5.12513315])

In [23]:
def policy_improvement_matrix(p, r, V, gamma=0.9):
    """
    policy improvement using 
    
    Args:
        p: transition probability matrix (S x A x S)
        r: reward matrix (S x A)
        V: current value function (S)
        gamma: discount factor
        
    Returns:
        new_policy: improved policy matrix (S x A)
    """
    # Q(s,a) = R(s,a) + gamma * sum_s' P(s'|s,a) * V(s')
    Q = r + gamma * np.sum(p * V[np.newaxis, np.newaxis, :], axis=2)

    new_policy = np.zeros_like(Q)
    best_actions = np.argmax(Q, axis=1)
    new_policy[np.arange(len(Q)), best_actions] = 1
    
    return new_policy


policy_improvement_matrix(P, R, V)

array([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [26]:
# policy iteration
def policy_iteration_matrix(p, r, policy, gamma=0.9, theta=1e-6):
    """
    policy iteration 
    
    Args:
        p: transition probability matrix (S x A x S)
        r: reward matrix (S x A)
        gamma: discount factor
        theta: threshold
        
    Returns:
        new_policy: improved policy matrix (S x A)
    """
    while True:
        V = policy_evaluation_matrix(p, r, policy, gamma, theta)
        new_policy = policy_improvement_matrix(p, r, V, gamma)
        
        if new_policy.all() == policy.all():
            break
        
        policy = new_policy
        
    return policy, V

policy_iteration_matrix(P, R, pi)



(array([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]),
 array([9.02434363, 8.99806546, 8.98017772, 8.85479739, 8.96444362,
        9.01732508, 8.96768506, 8.96206667, 8.68237136, 8.90715751]))

In [25]:
def value_iteration_matrix(p, r, gamma=0.9, theta=1e-6):
    """
    value iteration 
    
    Args:
        p: transition probability matrix (S x A x S)
        r: reward matrix (S x A)
        gamma: discount factor
        theta: convergence threshold
        
    Returns:
        V: optimal state-value function (S)
        policy: optimal policy (S x A)
    """
    n_states = p.shape[0]
    n_actions = p.shape[1]
    
    # Initialize value function
    V = np.zeros(n_states)
    
    while True:
        # Q-values for all state-action pairs
        ## Q(s,a) = R(s,a) + gamma * sum_s' P(s'|s,a) * V(s')
        Q = r + gamma * np.sum(p * V[np.newaxis, np.newaxis, :], axis=2)
        
        # new value function 
        V_new = np.max(Q, axis=1)
        
        # convergence
        if np.max(np.abs(V_new - V)) < theta:
            break
            
        V = V_new
    
    policy = np.zeros((n_states, n_actions))
    best_actions = np.argmax(Q, axis=1)
    policy[np.arange(n_states), best_actions] = 1
    
    return V, policy

value_iteration_vectorized(P, R)

(array([9.03717282, 9.01160563, 8.99217841, 8.86719043, 8.99230531,
        9.03093301, 8.98070783, 8.97450407, 8.69424636, 8.92052505]),
 array([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]))