In [41]:
import numpy as np

num_of_states = 3
gamma = 0.9

def create_distribution(n, temp=1.0):
    logits = np.random.randn(n)
    exp_logits = np.exp(logits / temp)
    prob = exp_logits / exp_logits.sum() 
    return prob


def create_MRP():
    P = np.zeros((num_of_states, num_of_states))
    for s in range(num_of_states):
        P[s, :] = create_distribution(num_of_states)
    r = np.random.randn(num_of_states, 1)
    return P, r

def policy_evaluation(P, r, gamma):
    bellman_operator = lambda v: r + gamma * P @ v
    v = np.random.randn(P.shape[0], 1)  # Initialize value function
    error = float('inf')
    while error > 1e-5:
        v_prev = v
        # Update value function for state s
        v = bellman_operator(v_prev)
        error = np.mean(np.abs(v-v_prev))
        #print(error)
    return v


In [42]:
num_of_actions = 2

def create_MDP():
    P = np.zeros((num_of_states, num_of_actions, num_of_states))
    for s in range(num_of_states):
        for a in range(num_of_actions):
            P[s, a, :] = create_distribution(num_of_states)
    r = np.random.randn(num_of_states, num_of_actions)
    return P, r

def initial_policy():
    pi = np.zeros((num_of_states, num_of_actions))
    for s in range(num_of_states):
        pi[s, :] = create_distribution(num_of_actions)
    return pi

def policy_iteration(P, r, gamma, tol=1e-5):
    pi = initial_policy()  # Initialize policy
    r = np.reshape(r, (num_of_states * num_of_actions, 1))
    q = np.zeros((num_of_states, num_of_actions))  # Initialize Q-values
    q_prev = np.copy(q)  # Previous Q-values for convergence checking

    while True:
        # next 1: Policy Evaluation
        P_pi = np.zeros((num_of_states, num_of_actions, num_of_states, num_of_actions))
        
        # next 2: Policy Improvement
        for s in range(num_of_states):
            for a in range(num_of_actions):
                for s_prime in range(num_of_states):
                    for a_prime in range(num_of_actions):
                        P_pi[s, a, s_prime, a_prime] = P[s, a, s_prime] * pi[s_prime, a_prime]
        
        P_pi = np.reshape(P_pi, (num_of_states * num_of_actions, num_of_states * num_of_actions))
        q = policy_evaluation(P_pi, r, gamma)
        print(q)
        q = np.reshape(q, (num_of_states, num_of_actions))
        
        # Check for convergence: if the difference between q and q_prev is less than tolerance
        if np.max(np.abs(q - q_prev)) < tol:
            break  # Convergence achieved
        
        q_prev = np.copy(q)
        
        # Policy improvement: Update policy based on the new Q-values
        pi = np.zeros((num_of_states, num_of_actions))
        for s in range(num_of_states):
            pi[s, np.argmax(q[s, :])] = 1



In [43]:
P, r = create_MDP()
policy_iteration(P, r, gamma)


[[-2.69117907]
 [-1.25173211]
 [-1.94673296]
 [-2.08200641]
 [-2.4079183 ]
 [-4.29199762]]
[[3.71861577]
 [5.10363459]
 [4.76977167]
 [4.55449481]
 [4.03144054]
 [2.28712057]]
[[3.71861358]
 [5.10363239]
 [4.76976948]
 [4.55449262]
 [4.03143835]
 [2.28711838]]
