In [9]:
import numpy as np
from env import Env, RandomEnv
from reward import random_reward
from _types import Reward, Policy


In [10]:
def optimize(env: Env, reward: Reward, max_iters: int = 10000, tol: float = 1e-5) -> Policy:
  """Find a deterministic greedy policy for `reward` by value iteration.

  Args:
    env: provides `n_s`, `discount` and `transition_dist`, indexed as
      transition_dist[s, a, s_next] (presumably P(s_next | s, a) — TODO confirm).
    reward: array broadcastable against transition_dist, indexed as reward[s, a, s_next].
    max_iters: maximum number of value-iteration sweeps (default 10000).
    tol: stop early once the L2 norm of the change in state values is below this.

  Returns:
    Integer array of length n_s: for each state, the action with the highest
    Q-value under the final value estimate (np.argmax, so ties go to the lowest index).
  """
  state_vals = np.zeros(env.n_s)
  for _ in range(max_iters):
    new_vals = _q_values(env, reward, state_vals).max(axis=1)
    converged = np.linalg.norm(state_vals - new_vals, 2) < tol
    state_vals = new_vals
    if converged:
      break
  # Act greedily w.r.t. the *final* value estimate (the original used the
  # previous sweep's Q-values); this is also well-defined when max_iters == 0.
  return _q_values(env, reward, state_vals).argmax(axis=1)


def _q_values(env: Env, reward: Reward, state_vals: np.ndarray) -> np.ndarray:
  """Q[s, a] = sum_{s_next} transition_dist[s, a, s_next] * (reward[s, a, s_next] + discount * V[s_next])."""
  return (env.transition_dist * (reward + env.discount * state_vals[None, None, :])).sum(axis=2)

In [11]:
# Smoke test: sample a random MDP plus reward, solve it, and display the greedy policy.
r = random_reward(e := RandomEnv())
optimize(e, r)

array([ 0,  0,  0,  0,  7,  0, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0, 10,
        0,  0,  0,  2,  0,  0,  0,  0, 14,  0,  0,  0,  0,  0,  1,  0,  0,
        8,  5,  0,  0,  0, 15,  0, 10,  7,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  7,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, 11,  0, 12,  0,  0,  0,  0,  0,  0,  0, 15,  0,  0,  0,
        0,  0,  0,  4,  0,  0,  0, 14,  0,  0,  0,  0,  0,  0,  0,  0,  8,
        5,  0,  0,  0,  0,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, 12,  0,  0,  0,  0,  0])

# Idea: run one evaluation episode per state, so that every state is used as the start state exactly once

In [12]:
def policy_returns(
  rewards: list[Reward],
  policy: Policy,
  env: Env,
  discount_thresh=1e-5,
) -> list[float]:
  # beyond this point, the discounts get so heavy that it's not worth computing
  steps_per_episode = round(np.log(discount_thresh) / np.log(env.discount))

  num_rs = len(rewards)

  # 2D array, first dim is different reward funcs, second dim is samples
  return_vals = [[] for _ in range(num_rs)]

  for episode_i in range(env.n_s):
    # init state - we want to have one episode for each possible starting state
    s = episode_i
    episode_rewards = [[] for _ in range(num_rs)] # same dims as return_vals

    for _ in range(steps_per_episode):
      # # sample action from policy
      # a = np.random.choice(env.actions, p=policy[s])
      a = policy[s]

      # next state
      s_next = np.random.choice(env.states, p=env.transition_dist[s, a])
      for i, r in enumerate(rewards):
        episode_rewards[i].append(r[s, a, s_next])
      s = s_next
    
    # at the end we compute the discounted return return
    for r_i, r_values in enumerate(episode_rewards): # for each return func
      return_val = 0 # accumulator for the return
      for i, r in enumerate(r_values):
        if i == 0: gamma_i = 1.0
        else: gamma_i *= env.discount
        return_val += gamma_i * r
      return_vals[r_i].append(return_val)

  return [sum(rs) / len(rs) for rs in return_vals]
  
def policy_return(reward: Reward, *args, **kwargs) -> Reward:
  return policy_returns([reward], *args, **kwargs)[0]


In [13]:
# Sanity check: value iteration has no randomness, so 30 re-runs on the same
# (env, reward) pair must produce exactly the same policy.
for i in range(10):
  e = RandomEnv()
  r = random_reward(e)
  policies = np.stack([optimize(e, r) for _ in range(30)])
  assert all(np.array_equal(p, policies[0]) for p in policies)

In [17]:
# check that policy returns are reasonably low-variance:
# evaluate the same (env, policy) pair 30 times and report std / |mean|
deviations = []
for _ in range(30):
  e = RandomEnv()
  r = random_reward(e)
  policy = optimize(e, r)
  Js = np.array([policy_return(r, policy, e)
                  for _ in range(30)])
  mean = Js.mean()
  # near-zero means would blow up std/mean; fall back to the absolute std there
  if abs(mean) < 1: mean = 1
  std = Js.std()
  dev = abs(std/mean)
  deviations.append(dev)
  print(f'{mean=} {std=} {dev=}')
max(deviations)
  

mean=32.42517271610367 std=0.5659563702609113 dev=0.017454228392739884
mean=10.76435090610558 std=0.14315140119765785 dev=0.013298656133224144
mean=11.42376013264463 std=0.22046628957326722 dev=0.019298924961078354
mean=87.68633601550484 std=1.2304171763801754 dev=0.014032028618034754
mean=96.50735990152552 std=1.2189727385892422 dev=0.012630878513649751
mean=50.916927273651886 std=0.719411359637228 dev=0.014129119688836834
mean=191.39973099492278 std=1.5579669754479952 dev=0.008139859796821362
mean=67.14360707532835 std=1.1409119076049759 dev=0.01699211521843007
mean=83.88499562442665 std=0.5251940798961472 dev=0.0062608822470179015
mean=3.962639062114664 std=0.7573540693541323 dev=0.19112365710894952
mean=1 std=0.0 dev=0.0
mean=3.3853230921944184 std=2.906851593632727e-06 dev=8.586629738045066e-07
mean=10.977354849134247 std=0.1891860769315402 dev=0.017234213481443644
mean=1 std=6.338430682868219e-06 dev=6.338430682868219e-06
mean=1 std=3.6411193329891134e-06 dev=3.6411193329891134e-

In [23]:
# check that the best policy > random policy > worst policy
# (each comparison gets 1% slack because returns are Monte Carlo estimates)
for i in range(200):
  if i % 50 == 0: print(f'{i=}')
  e = RandomEnv()
  r = random_reward(e)
  best = optimize(e, r)
  worst = optimize(e, -1*r)
  # uniformly random deterministic policy; randint's upper bound is exclusive,
  # so `e.n_a` (not `e.n_a - 1`) lets every action be chosen
  random_policy = np.random.randint(0, e.n_a, size=e.n_s)
  J_best = policy_return(r, best, e)
  J_worst = policy_return(r, worst, e)
  J_random = policy_return(r, random_policy, e)

  # J + 1% of |J| loosens the inequality in the right direction for either sign of J
  J_best_w_slack = J_best + 0.01 * abs(J_best)
  J_random_w_slack = J_random + 0.01 * abs(J_random)
  if J_best_w_slack < J_random or J_random_w_slack < J_worst:
    print(i, J_best, J_random, J_worst)

i=0
i=50
i=100
i=150
