In [50]:
import numpy as np
import typing
import abc
import copy

In [51]:
# softmax along last dimension
def softmax(arr: np.ndarray) -> np.ndarray:
  """Numerically stable softmax over the last axis of `arr`.

  Args:
    arr: array (or array-like) of logits with at least one dimension.

  Returns:
    Array of the same shape as `arr` whose entries along the last axis are
    non-negative and sum to 1.
  """
  arr = np.asarray(arr)

  # an empty last axis has nothing to normalise (and np.max would raise on it)
  if arr.shape[-1] == 0:
    return np.exp(arr)

  # softmax is invariant to adding a constant per row; subtracting the row
  # maximum keeps every exponent <= 0 so np.exp cannot overflow to inf
  shifted = arr - np.max(arr, axis=-1, keepdims=True)
  exp = np.exp(shifted)

  # keepdims leaves a trailing axis of size 1 so the division broadcasts
  return exp / np.sum(exp, axis=-1, keepdims=True)

In [52]:
# states and actions are plain integers; they double as indices into the
# transition / reward / Q tables below
State = int
Action = int

class MDP:
  """Tabular Markov decision process with an off-policy Q-learning routine.

  State and action values are used directly as indices into the
  `[S, A, S']` transition and reward tensors and into the `[S, A]` Q-table,
  so they must be valid integer indices for those arrays.

  Attributes set by the constructor:
    states, actions: the lists passed in.
    n_states, n_actions: their lengths.
    discount_factor, init_dist, transition_dist, rewards: stored as given
      (the arrays are not copied).
    q: `[S, A]` array of Q-value estimates, initialised to zero.
    policy: `[S, A]` array of action probabilities, initialised uniform.
  """

  def __init__(
    self,
    states: list[State],
    actions: list[Action],
    discount_factor: float,
    init_dist: np.ndarray,
    transition_dist: np.ndarray, # [S, A, S']
    rewards: np.ndarray, # [S, A, S']
  ):
    """Validate the model arrays and initialise Q-values and policy.

    Args:
      states: list of state indices.
      actions: list of action indices.
      discount_factor: multiplier applied to the bootstrapped next-state value.
      init_dist: `[S]` initial-state probabilities.
      transition_dist: `[S, A, S']` next-state probabilities.
      rewards: `[S, A, S']` reward for each transition.

    Raises:
      AssertionError: if an array has the wrong shape or a probability
        distribution does not sum to 1 (within floating-point tolerance).
    """
    n_states = len(states)
    n_actions = len(actions)

    # check the dimensionality of the numpy arrays and that probs are valid
    assert init_dist.shape == (n_states,), "init_dist must have shape (S,)"
    # tolerance-based comparison: a valid distribution rarely sums to exactly
    # 1.0 in floating point (same check as for transition_dist below)
    assert np.isclose(init_dist.sum(), 1.0), "init_dist must sum to 1"
    assert transition_dist.shape == (n_states, n_actions, n_states), \
      "transition_dist must have shape (S, A, S')"
    assert np.isclose(transition_dist.sum(axis=-1), np.ones((n_states, n_actions))).all(), \
      "transition_dist[s, a] must sum to 1 for every (s, a)"
    assert rewards.shape == (n_states, n_actions, n_states), \
      "rewards must have shape (S, A, S')"

    self.states = states
    self.n_states = n_states
    self.actions = actions
    self.n_actions = n_actions
    self.discount_factor = discount_factor
    self.init_dist = init_dist
    self.transition_dist = transition_dist
    self.rewards = rewards

    # start with zero Q-values and a uniform policy
    self.q = np.zeros((self.n_states, self.n_actions))
    self.policy = np.ones((self.n_states, self.n_actions)) / self.n_actions

  def q_learn(
    self,
    max_iters=10000,
    epsilon=0.1,
    episode_len=100,
    learning_rate=1e-3
  ):
    """Run tabular Q-learning on sampled transitions, then refresh the policy.

    Updates `self.q` in place and overwrites `self.policy` with the softmax
    of the learned Q-values. Uses NumPy's global random state.

    Args:
      max_iters: total number of sampled transitions (across all episodes).
      epsilon: probability of taking a uniformly random action instead of
        the greedy one.
      episode_len: number of steps after which the state is re-drawn from
        `init_dist`.
      learning_rate: step size applied to the TD error.
    """
    for i in range(max_iters):
      # reset the episode every now and then so it doesn't get stuck
      # (i == 0 always triggers this, so `s` is defined before first use)
      if i % episode_len == 0:
        s = np.random.choice(self.states, p=self.init_dist)

      # behavior policy = epsilon-greedy
      if np.random.random() > epsilon:
        # NOTE: argmax returns the first maximiser, so ties (including the
        # all-zero initial Q-table) resolve to the lowest action index
        a = self.q[s].argmax()
      else:
        a = np.random.choice(self.actions)

      # sample next state
      s_next = np.random.choice(self.states, p=self.transition_dist[s, a])
      r = self.rewards[s, a, s_next]

      # compute TD error and update Q value
      delta = r + self.discount_factor * self.q[s_next].max() - self.q[s, a]
      self.q[s, a] += learning_rate * delta

      # the next state becomes the current state
      s = s_next

    # derive the policy by softmaxing along the actions dimension
    self.policy = softmax(self.q)
    

In [53]:
# testing MDP class: a 3-state, 2-action toy problem

# transition probabilities, indexed [s, a, s']
toy_transition_dist = np.array([
  [ # s=0
    [0.1, 0.8, 0.1], # a=0
    [0.1, 0.1, 0.8], # a=1
  ],
  [ # s=1
    [0.1, 0.1, 0.8], # a=0
    [0.8, 0.1, 0.1], # a=1
  ],
  [ # s=2
    [0.8, 0.1, 0.1], # a=0
    [0.1, 0.8, 0.1], # a=1
  ],
])

# rewards here depend only on (s, s'), so write one row per state and
# duplicate it across both actions to get the [s, a, s'] tensor
toy_rewards_per_state = np.array([
  [0, 5, -2], # s=0
  [0, 1, -2], # s=1
  [0, 30, 0], # s=2
])
toy_rewards = np.repeat(toy_rewards_per_state[:, np.newaxis, :], 2, axis=1)

mdp = MDP(
  states=[0, 1, 2],
  actions=[0, 1],
  discount_factor=0.9,
  init_dist=np.array([0.8, 0.1, 0.1]),
  transition_dist=toy_transition_dist,
  rewards=toy_rewards,
)
mdp.q_learn()
mdp.policy

array([[0.99891739, 0.00108261],
       [0.01010049, 0.98989951],
       [0.96327667, 0.03672333]])

In [55]:
# two random MDPs sharing dynamics and start distribution but with
# independently drawn reward tensors
n_s, n_a = 10, 4
states = list(range(n_s))
actions = list(range(n_a))

tensor_shape = (n_s, n_a, n_s)  # [s, a, s']
r1 = np.random.random(tensor_shape)
r2 = np.random.random(tensor_shape)
# softmax turns the random scores into valid distributions over the last axis
transitions = softmax(np.random.random(tensor_shape))
init_dist = softmax(np.random.random(n_s))

mdp1 = MDP(
  states=states,
  actions=actions,
  discount_factor=0.9,
  init_dist=init_dist,
  transition_dist=transitions,
  rewards=r1,
)
mdp1.q_learn()

mdp2 = MDP(
  states=states,
  actions=actions,
  discount_factor=0.9,
  init_dist=init_dist,
  transition_dist=transitions,
  rewards=r2,
)
mdp2.q_learn()

print(mdp1.policy)
print(mdp2.policy)

# TODO: compare regrets with rollouts

[[0.32697404 0.22533149 0.22460465 0.22308982]
 [0.35319649 0.21521136 0.21594674 0.21564541]
 [0.34755482 0.2176787  0.2174106  0.21735588]
 [0.34887035 0.21880836 0.21553474 0.21678655]
 [0.3372785  0.21883396 0.22279749 0.22109005]
 [0.33701409 0.22135978 0.22097848 0.22064765]
 [0.32182285 0.22664733 0.22596094 0.22556888]
 [0.3428513  0.21851634 0.2197266  0.21890575]
 [0.33973438 0.21924564 0.22127436 0.21974561]
 [0.30696116 0.23028301 0.23012593 0.2326299 ]]
[[0.3259688  0.2240804  0.22521227 0.22473853]
 [0.33935622 0.22084368 0.22023361 0.21956649]
 [0.34575425 0.2171866  0.21629388 0.22076528]
 [0.34179073 0.21863267 0.21909769 0.2204789 ]
 [0.31649603 0.22788084 0.22814348 0.22747966]
 [0.33254871 0.22151188 0.22274687 0.22319253]
 [0.31805146 0.22821878 0.2265468  0.22718297]
 [0.3454997  0.21764732 0.21936651 0.21748647]
 [0.34438323 0.21855957 0.21859761 0.21845959]
 [0.32073122 0.22628882 0.22592132 0.22705864]]
