<a href="https://colab.research.google.com/github/kretchmar/CS339_2023/blob/main/CoinFlipGame_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Solving Coin Flip Game using Reinforcement Learning

## Game Description

A player has three turns.  During each turn they can flip a coin or pass.  They accumulate money in their pot during each turn.  At the end of three turns, they get to keep the money in the pot.  

- If they "pass" during a turn, then the money that is currently in the pot remains there into the next turn. 
- If they "flip" during a turn, then a random 50/50 coin is flipped.  If it lands heads, then 1 dollar is added to the pot.  If it lands tails, then all the money is removed from the pot.  


Our goal is to figure out how to flip/pass at each turn to maximize our average winnings.  
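
As a reference point for that goal, the expected winnings of any fixed strategy can be computed exactly by enumerating the eight equally likely coin sequences. The sketch below does this for the naive "flip on every turn" strategy. The function name `always_flip_expected_value` is illustrative and not part of the notebook.

```python
from itertools import product

def always_flip_expected_value(turns=3):
    """Exact expected payout when the player flips on every turn."""
    outcomes = list(product("HT", repeat=turns))   # all equally likely sequences
    total = 0
    for seq in outcomes:
        pot = 0
        for coin in seq:
            # heads adds 1 dollar, tails empties the pot
            pot = pot + 1 if coin == "H" else 0
        total += pot
    return total / len(outcomes)

print(always_flip_expected_value())   # 0.875
```

Always flipping therefore averages 0.875 dollars per game, which gives a baseline against which the learned policies can be compared.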

## State
The time steps are $k \in \{0,1,2,3\}$. At time step $k$, $x_k$ is the amount of money in the pot, which can range from 0 to 3. We encode the pair $(k, x_k)$ as a single state index $s$:

$
s = 4k + x_k
$

Thus there are 16 states, although some of them are unreachable in the game (for instance, state $s=2$ corresponds to $k=0$ and $x_0 = 2$, which is impossible since we can't start the game with 2 dollars in the pot).
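
The encoding can be sketched with a pair of helper functions. The names `encode`, `decode`, and `reachable` are hypothetical and do not appear in the notebook, and the reachability test assumes the game rules above: the pot starts at 0 and grows by at most 1 dollar per turn.

```python
def encode(k, x):
    """Map (turn k, pot x) to the state index s = 4*k + x."""
    return 4 * k + x

def decode(s):
    """Recover (turn k, pot x) from a state index s."""
    return divmod(s, 4)

def reachable(s):
    """A state can occur only if the pot does not exceed the turn number."""
    k, x = decode(s)
    return x <= k

print(decode(14))                                # (3, 2)
print(sum(reachable(s) for s in range(16)))      # 10
```

Under that assumption, only 10 of the 16 indices correspond to states that can actually occur.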
## Equations

We let $v[(k,x)]$ represent the value of having $x$ dollars in the pot at turn $k$. The value $v[(0,0)]$ is the expected amount we can earn in this game from the start state, with 0 dollars in the pot. That is the ultimate value we seek.

Let $Q[(k,x),a]$ be the action-value function: the expected value of starting in state $(k,x)$, taking action $a$, and following the policy thereafter.

Let $P[(k,x)]$ be the policy -- the best action choice at state $(k,x)$.   We let 0=pass and 1=flip.  
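
Because the game is so small, these quantities can also be computed exactly by working backward from the final turn. The sketch below is one way to do that, assuming no discounting and using the relations implied by the rules: $v[(3,x)] = x$, $Q[(k,x),0] = v[(k+1,x)]$ for a pass, and $Q[(k,x),1] = \frac{1}{2} v[(k+1,0)] + \frac{1}{2} v[(k+1,x+1)]$ for a flip. The function `backward_induction` and its two-dimensional `(k, x)` indexing are illustrative and differ from the flat state index used in the notebook code.

```python
import numpy as np

def backward_induction(turns=3):
    """Exact v, Q and greedy policy for the coin flip game (0 = pass, 1 = flip)."""
    n_pot = turns + 1
    v = np.zeros((turns + 1, n_pot))
    Q = np.zeros((turns, n_pot, 2))
    v[turns] = np.arange(n_pot)            # after the last turn the player keeps the pot
    for k in range(turns - 1, -1, -1):
        for x in range(k + 1):             # only pots x <= k can occur at turn k
            Q[k, x, 0] = v[k + 1, x]                                 # pass
            Q[k, x, 1] = 0.5 * v[k + 1, 0] + 0.5 * v[k + 1, x + 1]   # flip
            v[k, x] = Q[k, x].max()
    P = Q.argmax(axis=2)                   # ties resolve to 0 (pass)
    return v, Q, P

v, Q, P = backward_induction()
print(v[0, 0])   # 1.0
```

The exact value $v[(0,0)] = 1$ is a useful check on the learned estimates: the average performance of a well-trained policy should hover around 1 dollar per game.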




In [1]:
import numpy as np

## Key Reinforcement Learning and System Functions

In [47]:

#-------------------------------------------------------------
def update (s,a):
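  '''
  Compute the next state given state s and action a
  s = encoded state (turn*4 + pot)
  a = 0 means pass (pot unchanged)
  a = 1 means flip (heads: pot + 1, tails: pot = 0)
  Returns (next_state, reward); the reward is 0 on every transition
  because the pot is only paid out after the last turn
  '''
  turn = s // 4 + 1   # turn index of the next state
  pot = s % 4         # dollars currently in the pot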

  if a == 0:
    return pot + turn*4,0

  if np.random.random() < 0.5:
    return turn*4,0      # tails
  else:
    return pot+1+turn*4,0    # heads

#-------------------------------------------------------------
def e_greedy(s,policy,epsilon):
  '''
  Implements an e-greedy policy.
  With probability epsilon, it returns random action choice
  otherwise returns action choice specified by the policy

  s = current state
  policy = policy function (an array that is indexed by state)
  epsilon (0 to 1) a probability of picking exploratory random action
  '''
  r = np.random.random()
  if r > epsilon:
    return policy[s]
  else:
    return np.random.randint(0,2)

#-------------------------------------------------------------
def init ():
  '''
  Create the Q table (16 states x 2 actions, all zeros)
  and the default policy (flip in every state)
  '''
  Q = np.zeros((16,2))
  P = np.ones(16).astype(int)
  return Q,P

#-------------------------------------------------------------
def SARSA (Q,policy,alpha,epsilon):
  '''
  Perform 1 unit of experience (1 trial, trajectory)
  using the SARSA learning algorithm
  '''
  k=0
  s=0   #  turn 0, pot 0
  state = k*4+s
  action = e_greedy(state,policy,epsilon)
  total_reward = 0
  reward = 0

  #print("\n==== TRIAL ====")
  #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))

  while (k < 3):
    next_state,reward = update(state,action)
    total_reward += reward
    next_action = e_greedy(next_state,policy,epsilon)
    TDerror = reward + Q[next_state,next_action] - Q[state,action]
    Q[state,action] = Q[state,action] + alpha * TDerror
    state = next_state
    action = next_action
    #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))
    k += 1

  # now we need to update last (terminal) state  
  reward = state % 4    # reward is amount in pot
  #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))
  total_reward += reward
  TDerror = reward + 0 - Q[state,action]
  Q[state,action] = Q[state,action] + alpha * TDerror
  return total_reward

#-------------------------------------------------------------
def QLearning (Q,policy,alpha,epsilon):
  '''
  Perform 1 unit of experience (1 trial, trajectory)
  using the Q-learning algorithm
  '''
  k=0
  s=0   #  turn 0, pot 0
  state = k*4+s
  total_reward = 0
  reward = 0

  #print("\n==== TRIAL ====")
  #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))

  while (k < 3):
    action = e_greedy(state,policy,epsilon)
    next_state,reward = update(state,action)
    total_reward += reward

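    # bootstrap from the action the current greedy policy picks in next_state;
    # the policy is refreshed from Q only between batches of trials, so it
    # can lag behind argmax(Q[next_state])
    optimal_action = policy[next_state]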
    TDerror = reward + Q[next_state,optimal_action] - Q[state,action]
    Q[state,action] = Q[state,action] + alpha * TDerror
    state = next_state
    #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))
    k += 1

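  # now we need to update last (terminal) state
  # note: action is still the last action taken at turn 2; no action is
  # taken in a terminal state, so this entry just stores the payout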
  reward = state % 4    # reward is amount in pot
  #print("state,action,reward: ({0:d},{1:d},{2:d})".format(state,action,reward))
  total_reward += reward
  TDerror = reward + 0 - Q[state,action]
  Q[state,action] = Q[state,action] + alpha * TDerror
  return total_reward

#-------------------------------------------------------------
def policy_improvement(Q):
  '''
  Update value function V and policy P based on Q values
  '''
  V = np.max(Q,axis=1)
  P = np.argmax(Q,axis=1)
  return V,P


#-------------------------------------------------------------
def do_trials (Q,policy,n,alpha,epsilon):
  '''
  Perform n trials of SARSA learning
  Returns the average reward per trial
  '''
  R = 0   # total reward
  for i in range(n):
    R += SARSA(Q,policy,alpha,epsilon)

  return R / n  

#-------------------------------------------------------------
def do_QL_trials (Q,policy,n,alpha,epsilon):
  '''
  Perform n trials of Q-learning
  Returns the average reward per trial
  '''
  R = 0   # total reward
  for i in range(n):
    R += QLearning(Q,policy,alpha,epsilon)

  return R / n  


#-------------------------------------------------------------
def assess (policy,trials):
  '''
  Assess the value of the current policy by completing #trials
  using the specified policy (no e-greedy random actions)
  Does not accrue learning experience nor change policy
  Returns the average reward per trial
  '''
  R = 0
  for i in range(trials):
    R += play_game(policy)
  return R / trials

#-------------------------------------------------------------
def play_game (policy):
  '''
  Simulate one trajectory of experience, always taking the
  action specified by the policy (no exploration)
  Return the total reward earned during the trajectory
  '''
  k=0
  s=0   #  turn 0, pot 0
  reward = 0

  while (k < 3):
    a = policy[s]
    s,r = update(s,a)
    k += 1
    reward += r

  # payout: the amount left in the pot after the last turn
  reward += s % 4
  return reward


### Do SARSA Learning
This next block does a real segment of SARSA learning
- Start with initial (blank) learning experience
- Do 10 iterations of 100 trials of SARSA, after each iteration update Policy
- Extract policy, value function and Q values
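
For reference, the core of each SARSA trial is a single temporal-difference update. The sketch below isolates that step in a standalone helper. The name `sarsa_update` and the `gamma` parameter are assumptions added for illustration; the notebook's `SARSA` function applies the same rule inline with no discounting, which corresponds to `gamma=1.0`.

```python
import numpy as np

def sarsa_update(Q, s, a, r, s_next, a_next, alpha, gamma=1.0):
    """Apply one SARSA update to Q in place and return the TD error."""
    # on-policy target: uses the action actually chosen in the next state
    td_error = r + gamma * Q[s_next, a_next] - Q[s, a]
    Q[s, a] += alpha * td_error
    return td_error

Q_demo = np.zeros((16, 2))
Q_demo[5, 1] = 1.0      # pretend the next state-action pair is worth 1
delta = sarsa_update(Q_demo, s=0, a=1, r=0, s_next=5, a_next=1, alpha=0.1)
print(delta, Q_demo[0, 1])
```

With `alpha = 0.1`, the estimate for the starting pair moves one tenth of the way toward the bootstrapped target.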

In [37]:
Q,P = init()
m = 10
n = 100
epsilon = 0.1
alpha = 0.1

for i in range(m):
  print("\n*** Trial {0:d} ***".format(i))
  R = do_trials(Q,P,n,alpha,epsilon)
  V,P = policy_improvement(Q)
  print("Perf: ",R)

print(Q)
print(V)
print(P)






*** Trial 0 ***
Perf:  0.72

*** Trial 1 ***
Perf:  1.01

*** Trial 2 ***
Perf:  1.04

*** Trial 3 ***
Perf:  0.93

*** Trial 4 ***
Perf:  0.94

*** Trial 5 ***
Perf:  0.98

*** Trial 6 ***
Perf:  0.91

*** Trial 7 ***
Perf:  1.04

*** Trial 8 ***
Perf:  0.96

*** Trial 9 ***
Perf:  0.96
[[0.62710214 0.92007939]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.43898676 0.5908592 ]
 [0.84296694 1.03989386]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.4552618 ]
 [0.45942847 0.78402406]
 [1.99035462 0.89143127]
 [0.         0.        ]
 [0.         0.        ]
 [0.71757046 1.        ]
 [1.72982966 2.        ]
 [0.         2.15271139]]
[0.92007939 0.         0.         0.         0.5908592  1.03989386
 0.         0.         0.4552618  0.78402406 1.99035462 0.
 0.         1.         2.         2.15271139]
[1 0 0 0 1 1 0 0 1 1 0 0 0 1 1 1]


### Assessment
We can also assess the current policy by conducting many non-learning trials.
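
Since the averages over 100 trials fluctuate noticeably, it can help to report a standard error alongside the estimate. The following is a minimal self-contained sketch, not the notebook's `assess` function: `simulate_game` and `assess_with_error` are hypothetical names, and the policy is assumed to be an array indexed by the state $s = 4k + x_k$ with 0=pass and 1=flip.

```python
import numpy as np

def simulate_game(policy, rng):
    """Play one game with no exploration and return the final pot."""
    pot = 0
    for turn in range(3):
        if policy[4 * turn + pot] == 1:                  # flip
            pot = 0 if rng.random() < 0.5 else pot + 1   # tails empties, heads adds 1
    return pot

def assess_with_error(policy, trials, seed=0):
    """Return (mean payout, standard error of the mean) over `trials` games."""
    rng = np.random.default_rng(seed)
    payouts = np.array([simulate_game(policy, rng) for _ in range(trials)])
    return payouts.mean(), payouts.std(ddof=1) / np.sqrt(trials)

always_flip = np.ones(16, dtype=int)
mean, sem = assess_with_error(always_flip, 2000)
print("{0:.3f} +/- {1:.3f}".format(mean, sem))
```

The fixed `seed` only makes the sketch repeatable; any seed, or none, would serve for a real assessment.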

In [45]:
value = assess(P,2000)
print(value)


### Do Q Learning
This next block does a real segment of Q learning
- Start with initial (blank) learning experience
- Do 10 iterations of 100 trials of Q Learning, after each iteration update Policy
- Extract policy, value function and Q values
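
For comparison with SARSA, the textbook Q-learning update bootstraps from the largest Q value in the next state, regardless of which action is taken next. The sketch below shows that update in isolation. The helper name `q_learning_update` and the `gamma` and `terminal` parameters are assumptions for illustration; the notebook's `QLearning` function instead looks up the bootstrap action in the current policy array, which matches this rule only while that policy equals the argmax of Q.

```python
import numpy as np

def q_learning_update(Q, s, a, r, s_next, alpha, gamma=1.0, terminal=False):
    """Apply one Q-learning update to Q in place and return the TD error."""
    # off-policy target: best available value in the next state
    target = r if terminal else r + gamma * np.max(Q[s_next])
    td_error = target - Q[s, a]
    Q[s, a] += alpha * td_error
    return td_error

Q_ql = np.zeros((16, 2))
Q_ql[5] = [0.8, 1.4]    # pretend values for the two actions in the next state
delta_ql = q_learning_update(Q_ql, s=0, a=1, r=0, s_next=5, alpha=0.1)
print(delta_ql, Q_ql[0, 1])
```

Here the target is 1.4, the larger of the two next-state values, so the update does not depend on whether the next action is exploratory.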

In [42]:
Q,P = init()
m = 10
n = 100
epsilon = 0.1
alpha = 0.1

for i in range(m):
  print("\n*** Trial {0:d} ***".format(i))
  R = do_QL_trials(Q,P,n,alpha,epsilon)
  V,P = policy_improvement(Q)
  print("Perf: ",R)

print(Q)
print(V)
print(P)





*** Trial 0 ***
Perf:  1.14

*** Trial 1 ***
Perf:  1.05

*** Trial 2 ***
Perf:  0.85

*** Trial 3 ***
Perf:  0.86

*** Trial 4 ***
Perf:  0.97

*** Trial 5 ***
Perf:  0.88

*** Trial 6 ***
Perf:  1.02

*** Trial 7 ***
Perf:  0.89

*** Trial 8 ***
Perf:  1.0

*** Trial 9 ***
Perf:  1.05
[[0.74995479 1.07427487]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.49506624 0.83722155]
 [1.0025149  1.39899835]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.50342282]
 [0.68118056 1.21495614]
 [1.99999621 1.18740748]
 [0.         0.        ]
 [0.         0.        ]
 [0.74581342 1.        ]
 [1.99999921 1.99999913]
 [0.         2.99087024]]
[1.07427487 0.         0.         0.         0.83722155 1.39899835
 0.         0.         0.50342282 1.21495614 1.99999621 0.
 0.         1.         1.99999921 2.99087024]
[1 0 0 0 1 1 0 0 1 1 0 0 0 1 0 1]


### Assessment
We can also assess the current policy by conducting many non-learning trials.

In [43]:
value = assess(P,2000)
print(value)
