In [0]:
import numpy as np

# define global variables for the problem
p_h = 0.55
p_l = 0.45

prize = 1000
c_h = 50
c_l = 10

high = 0
low = 1

# number of actions
n_actions = 2
actions = [high, low]

# two outcomes
win = 0
lose = 1

n_outcomes = 2

# number of rounds
d = 3

# states encoded as 0, 1, ..., 2d
states = np.arange(0,2*d+1) 
n_states = len(states)

In [0]:
# the env

def game():
    p = {}

    for s in states:
        p[s] = {a: [] for a in range(n_actions)}

        is_terminal = lambda x: x == 0 or x == 2*d

        for a in range(n_actions):
            p[s][a] = {outcome: [] for outcome in range(n_outcomes)}

            prob = p_h if a == high else p_l
            cost = c_h if a == high else c_l

            if is_terminal(s):
                p[s][a][win] = [(1.0, s, 0.0, True)]
                p[s][a][lose] = [(1.0, s, 0.0, True)]

            else:
                s_win = s + 1
                s_loss = s - 1
                
                if is_terminal(s_win):
                  p[s][a][win] = [(prob, s_win, prize-cost, is_terminal(s_win))]
                else:
                  p[s][a][win] = [(prob, s_win, -cost, is_terminal(s_win))]                  
                p[s][a][lose] = [(prob, s_loss, -cost, is_terminal(s_loss))]   
    return p

In [0]:
# Initial state values - 0s
state_values_init = np.zeros(n_states)

pi_a = np.zeros((n_states,n_actions))

# Assume a random policy
#pi_a = 0.5*np.ones((n_states,n_actions))

def policy(x):
  pi_a = np.zeros((n_states,n_actions))
  for i in range(len(x)):
    pi_a[i+1, x[i]] = 1.0
  return pi_a

pi_a = policy([low, high, high, high, high])

# Undiscounted
gamma = 1.0
theta = 1e-4

In [0]:
# Get values

def get_values(pi_a, state_values_init):
  state_values = state_values_init
  transition_probs = game()
  #iteration_counter = 1
  while True:
    v_old = np.copy(state_values)
    delta = 0.0
    for s in range(1,n_states-1):
      v_s = 0.0
      for a in range(n_actions):
        for outcome in range(n_outcomes):
          current_entry = transition_probs[s][a][outcome][0]
          p_sa = current_entry[0]
          next_s = current_entry[1]
          r = current_entry[2]
          v_s += pi_a[s,a] * p_sa * (r + gamma * v_old[next_s])
          
      state_values[s] = v_s
      delta = np.maximum(delta, np.abs(state_values[s] - v_old[s]))
    #print('After %s iteration(s):\n' % iteration_counter, state_values[1:2*d])
    #print('delta = %s:\n' % delta)
    #iteration_counter += 1
    if delta < theta:
      break
  return state_values[1:2*d]

In [0]:
values = get_values(pi_a, state_values_init)
print(values)

[100.25208672 242.78249907 441.17057127 659.34598858 857.64024305]


In [0]:
import itertools
lst = list(itertools.product([0, 1], repeat=2*d-1))

dtype = [('pol', 'S10'), ('val', float)]
pol_val = np.zeros(32, dtype=dtype)

j = 0

for l in lst[::-1]:
  pol = ''.join(str(e) for e in l)
  l = np.asarray(l)
  pi_a = policy(l)
  values = get_values(pi_a, state_values_init)
  print(pol, values)
  s = pol_val[j]
  s['pol'] = pol
  s['val'] = values[d-1]
  j += 1

11111 [ 30.34619801  87.43592971 183.9557189  341.35442501 594.60952089]
11110 [ 42.97191118 115.4932154  233.67967859 423.79512744 728.08727128]
11101 [ 41.94124104 113.20269343 229.62039186 417.06469263 628.67915082]
11100 [ 63.98155711 162.18123801 316.42132436 560.97726057 803.53749332]
11011 [ 28.25692759  82.79309549 175.72766207 336.71159079 592.52025047]
11010 [ 48.55268945 127.89494402 255.65829726 436.93848834 735.31612126]
11001 [ 48.14372618 126.98599167 254.04748006 434.91836287 636.71330296]
11000 [ 90.08717301 220.19371779 419.23232676 642.04687631 848.12578197]
10111 [  4.0693834   29.04297776 148.73593711 321.481159   585.66655371]
10110 [ 23.03674086  71.19275747 206.40476593 407.48450015 719.11647508]
10101 [ 19.8413948   64.09189181 196.68943412 392.99554629 617.84803932]
10100 [ 54.2712817  140.60284823 301.37040058 549.10915306 797.01003418]
10011 [ -8.44775621   1.22711186 110.67880002 300.00690061 576.00313805]
10010 [ 25.31949326  76.26562243 213.34527479 411.6

In [0]:
print(pol_val)

[(b'11111', 183.9557189 ) (b'11110', 233.67967859)
 (b'11101', 229.62039186) (b'11100', 316.42132436)
 (b'11011', 175.72766207) (b'11010', 255.65829726)
 (b'11001', 254.04748006) (b'11000', 419.23232676)
 (b'10111', 148.73593711) (b'10110', 206.40476593)
 (b'10101', 196.68943412) (b'10100', 301.37040058)
 (b'10011', 110.67880002) (b'10010', 213.34527479)
 (b'10001', 195.20353801) (b'10000', 441.17022592)
 (b'01111', 162.93628698) (b'01110', 213.60410394)
 (b'01101', 206.4050542 ) (b'01100', 295.10978556)
 (b'01011', 141.33710975) (b'01010', 223.73745247)
 (b'01001', 213.34579601) (b'01000', 386.00917187)
 (b'00111', 109.65020916) (b'00110', 169.71788899)
 (b'00101', 152.26885863) (b'00100', 262.33289999)
 (b'00011',  32.94663039) (b'00010', 143.24441477)
 (b'00001',  90.42123823) (b'00000', 370.13401031)]


In [0]:
np.sort(pol_val, order=['val']) 

array([(b'00011',  32.94663039), (b'00001',  90.42123823),
       (b'00111', 109.65020916), (b'10011', 110.67880002),
       (b'01011', 141.33710975), (b'00010', 143.24441477),
       (b'10111', 148.73593711), (b'00101', 152.26885863),
       (b'01111', 162.93628698), (b'00110', 169.71788899),
       (b'11011', 175.72766207), (b'11111', 183.9557189 ),
       (b'10001', 195.20353801), (b'10101', 196.68943412),
       (b'10110', 206.40476593), (b'01101', 206.4050542 ),
       (b'10010', 213.34527479), (b'01001', 213.34579601),
       (b'01110', 213.60410394), (b'01010', 223.73745247),
       (b'11101', 229.62039186), (b'11110', 233.67967859),
       (b'11001', 254.04748006), (b'11010', 255.65829726),
       (b'00100', 262.33289999), (b'01100', 295.10978556),
       (b'10100', 301.37040058), (b'11100', 316.42132436),
       (b'00000', 370.13401031), (b'01000', 386.00917187),
       (b'11000', 419.23232676), (b'10000', 441.17022592)],
      dtype=[('pol', 'S10'), ('val', '<f8')])