In [1]:
import random


class PolicyIteration:
  def __init__(self, states, actions, reward_func=None):
    self.states = states
    self.actions = actions
    self.reward_func = lambda x: x == 1 and 1 or -1 # sample function
    self.values = dict(zip(self.states, [0] * len(self.states))) # all zeroes
    self.transitional_model = dict(zip(self.states, 
                                       [dict(zip(self.actions, [random.random() for a in self.actions])) for s in self.states])) # random transitions
    self.discount = 0.9 # usually denoted as gamma
    self.convergence_threshold = 0.1 # usually denoted as theta
    self.policy = dict(zip(self.states, random.choices(self.actions, k=4))) # random actions
    
  def evaluate_policy(self):
    delta = 0
    for state in self.states:
      action = self.policy[state]
      prob = self.transitional_model[state][action]
      value = self.values[state]
      all_values = sum(self.values)
      new_value = self.reward_func(state) + self.discount * prob * all_values
      self.values[state] = new_value
      delta = max(delta, abs(value - new_value))
    return delta
  
  def improve_policy(self):
    policy_changed = False
    for state in self.states:
      old_action = self.policy[state]
      action_value = [0] * len(self.actions)
      for action in self.actions:
        prob = self.transitional_model[state][action]
        all_values = sum(self.values)
        action_value[action] = prob * all_values
      new_action = action_value.index(max(action_value))
      self.policy[state] = new_action
      policy_changed = policy_changed or (new_action != old_action)
    return policy_changed
  
  def train(self, max_epoch=500):
    epoch = 0
    while epoch < max_epoch:
      epoch += 1
      self.evaluate_policy()
      changed = self.improve_policy()
      if not changed:
        break

In [2]:
solver = PolicyIteration(states=[1,2,3,4], actions=[1,2])

In [3]:
solver.states

[1, 2, 3, 4]

In [4]:
solver.actions

[1, 2]

In [5]:
solver.values

{1: 0, 2: 0, 3: 0, 4: 0}

In [17]:
solver.transitional_model[1][1]

0.8715986035960382

In [8]:
solver.policy

{1: 1, 2: 2, 3: 1, 4: 1}

In [11]:
solver.transitional_model = dict(zip(solver.states, [dict(zip(solver.actions, [random.random() for a in solver.actions])) for s in solver.states]))

In [12]:
solver.evaluate_policy()

7.155220681388876

In [15]:
solver.values

{1: 8.844387432364345,
 2: 3.292557663026728,
 3: 3.263194900964023,
 4: 4.10457255598351}

In [18]:
solver.improve_policy()

IndexError: list assignment index out of range

In [24]:
for state in solver.states:
    for action in solver.actions:
        print(state, action, solver.transitional_model[state][action])

1 1 0.8715986035960382
1 2 0.14836439865606244
2 1 0.7438542360077686
2 2 0.4769508514474142
3 1 0.4736883223293359
3 2 0.36134471794816714
4 1 0.5671747284426122
4 2 0.2892776979191115


In [25]:
solver.policy

{1: 1, 2: 2, 3: 1, 4: 1}

In [32]:
def improve_policy_new(self):
    policy_changed = False
    for state in self.states:
      old_action = self.policy[state]
      action_value = dict(zip(self.actions, [0] * len(self.actions)))
      for action in self.actions:
        prob = self.transitional_model[state][action]
        all_values = sum(self.values)
        action_value[action] = prob * all_values
      new_action = max(action_value, key=action_value.get)
      self.policy[state] = new_action
      policy_changed = policy_changed or (new_action != old_action)
    return policy_changed

PolicyIteration.improve_policy = improve_policy_new
solver.improve_policy()

True

In [35]:
solver.policy = dict(zip(solver.states, random.choices(solver.actions, k=4)))
solver.policy

{1: 1, 2: 1, 3: 2, 4: 2}