<a href="https://colab.research.google.com/github/konkoniknik/RL_BartoSutton/blob/main/RL_BartoSutton_Part1_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Task: Implement several algorithms from Sutton and Barto's *Reinforcement Learning: An Introduction*. These are implementations of Part I (i.e., mainly tabular, non-approximate solution methods).

For our toy experiments we implement a simple gridworld with obstacles and a few terminal squares where the episode ends.


# 0. The world

In [None]:
import numpy as np
import random

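# For each action the intended neighbouring square is reached with probability
# p(s', r | s, a) = p (the `p` argument of `world`; the default p=1 is deterministic).
# The remaining 1 - p is split equally between the two squares on either side of
# the intended one. Rewards come from `matrix`: r = -1 on ordinary squares and
# r = 0 on the terminal squares.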

class world:
  """Rectangular gridworld described by a reward matrix.

  Attributes:
    rewards: 2-D array; ``rewards[s]`` is the reward received on entering ``s``.
    gamma: discount factor.
    p_env: probability that a move ends on the intended square.
    s_start: list holding the single start state ``(0, 0)``.
    s_terminal: the four states forming the 2x2 bottom-right corner.
    a_space: available actions as ``(d_row, d_col)``; ``(0, 0)`` means "stay".
    a_bounds: ``(min_row, min_col, max_row, max_col)`` of the grid.
  """

  def __init__(self, matrix, gamma=0.95, p=1):
    self.rewards = matrix
    print(self.rewards)
    self.gamma = gamma
    self.p_env = p

    last_row, last_col = matrix.shape[0] - 1, matrix.shape[1] - 1

    self.s_start = [(0, 0)]
    self.s_terminal = [(last_row, last_col), (last_row, last_col - 1),
                       (last_row - 1, last_col), (last_row - 1, last_col - 1)]

    self.a_space = [(0, 1), (0, -1), (1, 0), (-1, 0), (0, 0)]
    # Derived from the matrix instead of the former hard-coded (0, 0, 9, 9),
    # so non-10x10 grids get correct bounds.
    self.a_bounds = (0, 0, last_row, last_col)



# Reward grid: every square costs -1 ...
m = np.full((10, 10), -1.0)

# ... except a wall of -100 squares across row 5 (columns 0-7).
# Alternative obstacle layouts that were tried:
#   m[2,:5]=-100
#   m[5:-2,5]=-100
#   m[4,5:]=-100
m[5, :8] = -100

# The 2x2 bottom-right corner is free of charge (these are the terminal squares).
m[-2:, -2:] = 0

w = world(matrix=m)

print(w.s_start, w.s_terminal)

[[  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [-100. -100. -100. -100. -100. -100. -100. -100.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.    0.    0.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.    0.    0.]]
[(0, 0)] [(9, 9), (9, 8), (8, 9), (8, 8)]


###  General **Functions** and Imports

In [None]:
# Print the evaluation
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


def policy_changed(old_policy, new_policy, display_changes=False, iteration=None):
  """Report whether two policy arrays differ.

  Args:
    old_policy: policy array before the update (``policy[i, j]`` is an action).
    new_policy: policy array after the update.
    display_changes: if True, also print every cell whose action changed.
    iteration: label printed in front of the message. When omitted, the
      module-level loop counter ``ii`` is used if it exists. The original
      code read ``ii`` unconditionally and raised ``NameError`` without it.

  Returns:
    True if the policies differ, False otherwise.
  """
  if np.array_equal(old_policy, new_policy):
    return False

  if iteration is None:
    iteration = globals().get("ii", "?")
  print(iteration, "policy changed:")

  if display_changes:
    for i, row in enumerate(old_policy):
      for j, _ in enumerate(row):
        if tuple(old_policy[i, j]) != tuple(new_policy[i, j]):
          print(f"({i}_{j}): old:{old_policy[i,j]}-> {new_policy[i,j]}")
  return True


def calc_min_max_Q_n_greedy(Q, world, policy):
  """Compute per-state max/min action values and make ``policy`` greedy.

  Args:
    Q: dict mapping ``"row_col_drow_dcol"`` keys to action values.
    world: object exposing ``rewards`` (2-D array) and ``a_space``.
    policy: array indexed by state; it is overwritten IN PLACE with the
      greedy action of every state.

  Returns:
    ``(max_Q, policy, min_Q)`` where ``max_Q``/``min_Q`` have the shape of
    ``world.rewards``.

  Actions that would leave the grid are scored with a sentinel of -10000 so
  they are never chosen. The grid limits come from ``world.rewards.shape``
  (the original hard-coded a 10x10 grid).
  """
  n_rows, n_cols = world.rewards.shape[0], world.rewards.shape[1]
  max_Q = np.zeros(world.rewards.shape)
  min_Q = np.zeros(world.rewards.shape)

  for i in range(n_rows):
    for j in range(n_cols):
      s = (i, j)
      state_vals = []
      for a in world.a_space:
        leaves_grid = (
            (i == 0 and a == (-1, 0)) or (j == 0 and a == (0, -1))
            or (i == n_rows - 1 and a == (1, 0))
            or (j == n_cols - 1 and a == (0, 1))
        )
        state_vals.append(-10000 if leaves_grid else Q[f"{i}_{j}_{a[0]}_{a[1]}"])

      best = max(state_vals)
      max_Q[s] = best
      # NOTE(review): on border states the minimum is the -10000 sentinel;
      # min_Q is only used for the progress heatmaps.
      min_Q[s] = min(state_vals)
      # Ties are broken in favour of the first action in world.a_space.
      policy[s] = world.a_space[state_vals.index(best)]

  return max_Q, policy, min_Q



def create_Q_heatmap(max_Q, min_Q):
  """Plot two side-by-side heatmaps: ``max_Q`` (annotated) and ``min_Q``.

  Both arrays are rounded to one decimal before plotting.
  """
  # (values, annotate cells?, colour map, title) for the left and right panel
  panels = (
      (np.round(max_Q, 1), True, 'viridis', 'Heatmap of Array 1'),
      (np.round(min_Q, 1), False, 'magma', 'Heatmap of Array 2'),
  )

  plt.figure(figsize=(16, 6))
  for position, (values, annotate, colour_map, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)  # (rows, columns, subplot number)
    sns.heatmap(values, annot=annotate, cmap=colour_map)
    plt.title(title)

  plt.show()



def initialize_Q(world, zeros_flag=True):
  """Create the tabular action-value dictionary.

  Args:
    world: object exposing ``rewards`` (2-D array), ``a_space`` and
      ``s_terminal``.
    zeros_flag: if True every value starts at 0, otherwise values are drawn
      from a standard normal distribution.

  Returns:
    dict keyed by ``"row_col_drow_dcol"``. Actions that would leave the grid
    get no entry, except in terminal states, where every action gets an entry
    fixed at 0. The grid limits come from ``world.rewards.shape`` (the
    original hard-coded a 10x10 grid).
  """
  Q = {}
  n_rows, n_cols = world.rewards.shape[0], world.rewards.shape[1]
  for i in range(n_rows):
    for j in range(n_cols):
      is_terminal = (i, j) in world.s_terminal
      for a in world.a_space:
        key = f"{i}_{j}_{a[0]}_{a[1]}"
        leaves_grid = (
            (i == 0 and a == (-1, 0)) or (j == 0 and a == (0, -1))
            or (i == n_rows - 1 and a == (1, 0))
            or (j == n_cols - 1 and a == (0, 1))
        )
        if not leaves_grid:
          Q[key] = 0 if zeros_flag else np.random.randn()

        # Q is 0 in terminal states
        if is_terminal:
          Q[key] = 0

  return Q


def initialize_policy(world, init="random"):
  """Build an initial deterministic policy array of shape (rows, cols, 2).

  Args:
    world: object exposing ``rewards`` (2-D array) and ``a_space``.
    init: ``"random"`` draws a uniformly random action per state, replacing
      "up" in the top row by "down" and "left" in the left column by "right";
      ``"down"`` assigns the action ``(1, 0)`` everywhere.

  Returns:
    Integer array whose entry ``[i, j]`` is the action ``(d_row, d_col)``.

  Raises:
    ValueError: for an unknown ``init`` (previously an unbound-local error).
  """
  # Use the `world` argument; the original read the global `w` here.
  n_rows, n_cols = world.rewards.shape[0], world.rewards.shape[1]

  if init == "random":
    random_indices = np.random.choice(len(world.a_space), size=[n_rows, n_cols])
    A = np.array(world.a_space)[random_indices]
    # NOTE(review): only the top and left borders are corrected here; random
    # actions in the bottom row / right column may still point off-grid.
    for i in range(n_rows):
      for j in range(n_cols):
        if i == 0 and tuple(A[i, j]) == (-1, 0):
          A[i, j] = (1, 0)

        if j == 0 and tuple(A[i, j]) == (0, -1):
          A[i, j] = (0, 1)
  elif init == "down":
    print("Down")
    A = np.full((n_rows, n_cols, 2), (1, 0))
  else:
    raise ValueError(f"unknown init mode: {init!r}")

  return A




# Samples one environment step; used by the Monte Carlo and TD sections.
def calc_transition(world, policy, s, display=False, epsilon=0.1):
  """Sample one epsilon-greedy step from state ``s``.

  With probability ``1 - epsilon`` the policy's action is taken, otherwise
  one of the remaining actions is picked uniformly. The move is clamped to
  the grid. With probability ``1 - world.p_env`` the agent additionally slips
  one square sideways (perpendicular to the move); the "stay" action never
  slips.

  Args:
    world: object exposing ``rewards``, ``a_space`` and ``p_env``.
    policy: array with ``policy[s]`` the deterministic action of state ``s``.
    s: current state ``(row, col)``.
    display: if True, print a trace line for the step.
    epsilon: exploration probability.

  Returns:
    ``(new_s, action, reward)`` where ``reward`` is ``world.rewards[new_s]``.
    ``action`` is returned as selected, even if it pointed off-grid.
  """
  # Grid limits come from the reward matrix (the original hard-coded 9).
  limits = (world.rewards.shape[0] - 1, world.rewards.shape[1] - 1)

  random_num = random.randint(1, 100) / 100
  # Epsilon - greedy
  if random_num <= 1 - epsilon:
    action = policy[s]
  else:
    tmp_spaces = world.a_space.copy()
    tmp_spaces.remove(tuple(policy[s]))
    action = tmp_spaces[random.randint(0, len(tmp_spaces) - 1)]

  move = tuple(action)
  # Intended target: the move is ignored along any axis where it would leave the grid.
  new_s_tnt = tuple(pos + step if 0 <= pos + step <= hi else pos
                    for pos, step, hi in zip(s, move, limits))

  random_int = random.randint(1, 100) / 100

  # adding environments randomness: horizontal moves slip along the row axis,
  # vertical moves along the column axis
  if move in ((0, 1), (0, -1)):
    slip_axis = 0
  elif move in ((1, 0), (-1, 0)):
    slip_axis = 1
  else:
    slip_axis = None

  new_s = new_s_tnt
  if slip_axis is not None and random_int > world.p_env:
    direction = 1 if random_int <= world.p_env + (1 - world.p_env) / 2 else -1
    coords = list(new_s_tnt)
    if 0 <= coords[slip_axis] + direction <= limits[slip_axis]:
      coords[slip_axis] += direction
    new_s = tuple(coords)

  if display:
    print("old s", s, "new s", new_s, "policy s",policy[s], "action",action,"random env:", random_int, "random num", random_num, "epsilon", epsilon)

  return new_s, action, world.rewards[new_s]



def policy_visualization(policy):
  """Print the policy as a grid of arrows.

  ``V`` = down, ``A`` = up, ``>`` = right, ``<`` = left, ``O`` = stay.
  The printed grid has the policy's own (rows, cols) shape; the original
  always allocated a 10x10 canvas.
  """
  transform = {(1, 0): "V", (-1, 0): "A", (0, 1): ">", (0, -1): "<", (0, 0): "O"}
  n_rows, n_cols = np.asarray(policy).shape[:2]
  policy_vis = np.full((n_rows, n_cols), "")
  for i in range(n_rows):
    for j in range(n_cols):
      policy_vis[i, j] = transform[tuple(policy[i, j])]

  print(policy_vis)




def episode_visualization(episode, shape=(10, 10)):
  """Print a grid with an ``X`` on every state visited in ``episode``.

  Args:
    episode: sequence of steps whose element ``[1]`` is the state
      ``(row, col)``.
    shape: ``(rows, cols)`` of the grid to draw; defaults to the 10x10 grid
      that was previously hard-coded.
  """
  episode_vis = np.full(shape, " ")
  for step in episode:
    row, col = step[1][0], step[1][1]
    episode_vis[row, col] = "X"

  print(episode_vis)

# 1. Dynamic Programming

## 1.1. Policy Iteration

In [None]:

# State values start as standard-normal noise (alternative: np.zeros(w.rewards.shape)).
V = np.random.normal(loc=0, scale=1, size=w.rewards.shape)

# Largest valid row / column index of the grid.
length = w.rewards.shape[0] - 1
width = w.rewards.shape[1] - 1
bounds = (0, 0, length, width)

# The four moves; the "stay" action is not used by the DP experiments.
action_space = [(0, 1), (0, -1), (1, 0), (-1, 0)]
s = (0, 0)

# Sanity print: values, value of the start state, its right-hand neighbour, bounds.
neighbour = tuple(coord + step for coord, step in zip(s, action_space[0]))
print(V, V[s], neighbour, bounds)

# Create an arbitrary deterministic policy (we let the environment explore for now)
action_space_size = len(action_space)
random_indices = np.random.choice(action_space_size, size=[length + 1, width + 1])
policy = np.take(np.array(action_space), random_indices, axis=0)

print("Policy:", policy[0, 0])


def calc_expected_value_n_states(V, w, action, s):
  """One-step expected backup of taking ``action`` in state ``s``.

  The intended square is reached with probability ``w.p_env``; each of the two
  squares sideways of it (perpendicular to the move, clamped to the grid) with
  probability ``(1 - w.p_env) / 2``. For the "stay" action ``(0, 0)`` all
  three successor states coincide with ``s``.

  Args:
    V: array of state values, indexed by ``(row, col)``.
    w: world exposing ``rewards``, ``p_env`` and ``gamma``.
    action: ``(d_row, d_col)``.
    s: state ``(row, col)``.

  Returns:
    ``(expected_value, (new_s, new_s_adj1, new_s_adj2))``.
  """
  # Limits from the reward matrix (was a hard-coded 9).
  limits = (w.rewards.shape[0] - 1, w.rewards.shape[1] - 1)
  move = tuple(action)
  new_s = tuple(pos + step if 0 <= pos + step <= hi else pos
                for pos, step, hi in zip(s, move, limits))

  def _shifted(axis, delta):
    """Return new_s moved by ``delta`` along ``axis``, or new_s if off-grid."""
    coords = list(new_s)
    if 0 <= coords[axis] + delta <= limits[axis]:
      coords[axis] += delta
    return tuple(coords)

  if move in ((0, 1), (0, -1)):
    new_s_adj1, new_s_adj2 = _shifted(0, 1), _shifted(0, -1)
  elif move in ((1, 0), (-1, 0)):
    new_s_adj1, new_s_adj2 = _shifted(1, 1), _shifted(1, -1)
  else:
    # "stay" (or any unknown action): no sideways slip
    new_s_adj1 = new_s_adj2 = new_s

  # The discount factor lives on the world; the original read an undefined
  # global `gamma`.
  gamma = w.gamma
  p_slip = (1 - w.p_env) / 2

  S = w.p_env * (w.rewards[new_s] + gamma * V[new_s]) \
      + p_slip * (w.rewards[new_s_adj1] + gamma * V[new_s_adj1]) \
      + p_slip * (w.rewards[new_s_adj2] + gamma * V[new_s_adj2])

  return S, (new_s, new_s_adj1, new_s_adj2)


theta= 0.0000001  # policy evaluation stops once the largest update is below this
for ii in range(100):
  print("\n\n ------------------------- ",ii," ------------------------------")
  delta=30  # any value above theta, so the evaluation loop runs at least once

  cnt=0
  # Policy Evaluation: sweep all states in place until the values stop changing
  while delta> theta:

    cnt+=1
    delta=0

    for i, row in enumerate(V):
      for j , element in enumerate(row):

        s=(i,j)
        v_old= V[s]

        # Every terminal state has value 0 (the original tested only
        # w.s_terminal[0] and backed up the other terminal squares).
        if s not in w.s_terminal:
          V[s],_=calc_expected_value_n_states(V, w, policy[s], s)
        else:
          V[s]=0

        delta = np.max([delta, np.abs(V[s]-v_old)])

  ## Policy Improvement: make the policy greedy with respect to V
  policy_stable=True
  for i, row in enumerate(policy):
      for j , element in enumerate(row):
          s=(i,j)
          old_action= policy[s].copy()
          expected_values=[]
          for action in action_space:
            val,_ = calc_expected_value_n_states(V,w,action,s)
            expected_values.append(val)

          # Ties go to the first action in action_space
          max_val= max(expected_values)
          chosen_index=expected_values.index(max_val)
          chosen_action = action_space[chosen_index]
          policy[s]=chosen_action

          if tuple(old_action) != chosen_action:
            policy_stable =False

  # Create the heatmap
  V_rounded= np.round(V,1)

  sns.heatmap(w.rewards, annot=True, cmap='viridis')
  plt.show()

  sns.heatmap(V_rounded,  cmap='viridis')

  # Add titles and labels as needed
  plt.title('Heatmap of 2D Array')
  plt.xlabel('X-axis Label')
  plt.ylabel('Y-axis Label')

  # Show the plot
  plt.show()
  policy_visualization(policy)
  print(V_rounded)

  # Stop as soon as an improvement sweep changes no action
  if policy_stable:
    break





## 1.2. Value *Iteration*

In [None]:
# Random initial state values (alternative: np.zeros(w.rewards.shape)).
V = np.random.normal(loc=0, scale=1, size=w.rewards.shape)

# Largest valid row / column index of the grid.
length = w.rewards.shape[0] - 1
width = w.rewards.shape[1] - 1
bounds = (0, 0, length, width)

# The four moves; "stay" is not part of the DP action set.
action_space = [(0, 1), (0, -1), (1, 0), (-1, 0)]
s = (0, 0)

# Sanity print: values, value of the start state, its right-hand neighbour, bounds.
neighbour = tuple(coord + step for coord, step in zip(s, action_space[0]))
print(V, V[s], neighbour, bounds)

# Create an arbitrary deterministic policy (we let the environment explore for now)
action_space_size = len(action_space)
random_indices = np.random.choice(action_space_size, size=[length + 1, width + 1])
policy = np.take(np.array(action_space), random_indices, axis=0)

print("Policy:", policy[0, 0])


def calc_expected_value_n_states(V, w, action, s):
  """One-step expected backup of taking ``action`` in state ``s``.

  The intended square is reached with probability ``w.p_env``; each of the two
  squares sideways of it (perpendicular to the move, clamped to the grid) with
  probability ``(1 - w.p_env) / 2``. For the "stay" action ``(0, 0)`` all
  three successor states coincide with ``s``.

  Args:
    V: array of state values, indexed by ``(row, col)``.
    w: world exposing ``rewards``, ``p_env`` and ``gamma``.
    action: ``(d_row, d_col)``.
    s: state ``(row, col)``.

  Returns:
    ``(expected_value, (new_s, new_s_adj1, new_s_adj2))``.
  """
  # Limits from the reward matrix (was a hard-coded 9).
  limits = (w.rewards.shape[0] - 1, w.rewards.shape[1] - 1)
  move = tuple(action)
  new_s = tuple(pos + step if 0 <= pos + step <= hi else pos
                for pos, step, hi in zip(s, move, limits))

  def _shifted(axis, delta):
    """Return new_s moved by ``delta`` along ``axis``, or new_s if off-grid."""
    coords = list(new_s)
    if 0 <= coords[axis] + delta <= limits[axis]:
      coords[axis] += delta
    return tuple(coords)

  if move in ((0, 1), (0, -1)):
    new_s_adj1, new_s_adj2 = _shifted(0, 1), _shifted(0, -1)
  elif move in ((1, 0), (-1, 0)):
    new_s_adj1, new_s_adj2 = _shifted(1, 1), _shifted(1, -1)
  else:
    # "stay" (or any unknown action): no sideways slip
    new_s_adj1 = new_s_adj2 = new_s

  # The discount factor lives on the world; the original read an undefined
  # global `gamma`.
  gamma = w.gamma
  p_slip = (1 - w.p_env) / 2

  S = w.p_env * (w.rewards[new_s] + gamma * V[new_s]) \
      + p_slip * (w.rewards[new_s_adj1] + gamma * V[new_s_adj1]) \
      + p_slip * (w.rewards[new_s_adj2] + gamma * V[new_s_adj2])

  return S, (new_s, new_s_adj1, new_s_adj2)


theta= 0.0001  # stop once the largest value change in a sweep is below this
delta=30       # any value above theta, so the loop runs at least once
ii=0

# Value iteration: each sweep replaces V[s] by the best one-step backup
while delta> theta:
  ii+=1
  print("\n\n ------------------------- ",ii," ------------------------------")

  delta=0
  for i, row in enumerate(V):
    for j , element in enumerate(row):
      s=(i,j)
      v_old= V[s]
      # Every terminal state has value 0 (the original tested only
      # w.s_terminal[0] and backed up the other terminal squares).
      if s not in w.s_terminal:
        expected_values=[]
        for action in action_space:
          val,_ = calc_expected_value_n_states(V,w,action,s)
          expected_values.append(val)

        max_val= max(expected_values)
        V[s]=max_val
      else:
        V[s]=0

      delta = np.max([delta, np.abs(V[s]-v_old)])

  # Create the heatmap
  V_rounded= np.round(V,1)

  sns.heatmap(w.rewards, annot=True, cmap='viridis')
  plt.show()

  sns.heatmap(V_rounded,  cmap='viridis')

  # Add titles and labels as needed
  plt.title('Heatmap of 2D Array')
  plt.xlabel('X-axis Label')
  plt.ylabel('Y-axis Label')

  # Show the plot
  plt.show()
  print(V_rounded)


# Read the greedy policy off the converged value function: in every state pick
# the action with the largest expected one-step backup (first one wins ties).
for s in np.ndindex(V.shape):
    action_values = [calc_expected_value_n_states(V, w, candidate, s)[0]
                     for candidate in action_space]
    best_index = action_values.index(max(action_values))
    policy[s] = action_space[best_index]


policy_visualization(policy)




## 2 Monte Carlo


In [None]:
def initialize_returns(world):
  """Return a dict with one empty list per (state, action) pair.

  Keys follow the ``"row_col_drow_dcol"`` convention and cover every cell of
  ``world.rewards`` combined with every action in ``world.a_space``.
  """
  return {
      f"{i}_{j}_{a[0]}_{a[1]}": []
      for i, row in enumerate(world.rewards)
      for j, _ in enumerate(row)
      for a in world.a_space
  }





def generate_episode(world, policy, Q, max_cnt=100000, epsilon=0.2, mode="ES"):
  """Roll out one epsilon-greedy episode and count its state-action visits.

  Episodes that hit the step limit, or that contain two or fewer entries,
  are discarded and regenerated.

  Args:
    world: the gridworld (``rewards``, ``s_start``, ``s_terminal``, ...).
    policy: array with the deterministic action of every state.
    Q: action-value dict; only its keys are used, to initialise the counts.
    max_cnt: step limit of one rollout.
    epsilon: exploration probability passed to ``calc_transition``.
    mode: ``"ES"`` draws a uniformly random start state, anything else starts
      from ``world.s_start[0]``.

  Returns:
    ``(episode, episode_history)``. ``episode[0]`` is the placeholder
    ``(None, s0, policy[s0])``; every later entry is
    ``(reward, state, action)`` for the step taken FROM ``state``.
    ``episode_history`` maps each Q key to its number of occurrences in the
    episode.
  """
  last_row, last_col = world.rewards.shape[0]-1, world.rewards.shape[1]-1

  # handle episodes that are too large: start "over the limit" so the retry
  # loop below runs at least once
  cnt=max_cnt
  episode=[]
  outer_cnt=0
  while cnt>=max_cnt or len(episode)<=2:
    cnt=0
    # generate a random starting state (sized from the grid, not a fixed 10x10)
    if mode=="ES":
      s0=tuple(np.random.randint(world.rewards.shape))
    else:
      s0=world.s_start[0]

    episode=[(None,s0,tuple(policy[s0]))]
    s=s0
    episode_history={key:0 for key in Q.keys()}

    while (s not in world.s_terminal) and cnt<=max_cnt:
      old_s=s
      s,a,r=calc_transition(world,policy,s, epsilon=epsilon)
      a=tuple(int(v) for v in a)

      # An action that would leave the grid has no entry in Q, so it is
      # recorded as the opposite move. The original did this only for the top
      # and left edges, which raised KeyError on the bottom/right edges.
      # NOTE(review): transition_handler records such moves as (0, 0) instead.
      off_grid= ((old_s[0]==0 and a==(-1,0)) or (old_s[0]==last_row and a==(1,0))
                 or (old_s[1]==0 and a==(0,-1)) or (old_s[1]==last_col and a==(0,1)))
      if off_grid:
        a=(-a[0],-a[1])

      episode.append((r,old_s,a))
      episode_history[f"{old_s[0]}_{old_s[1]}_{a[0]}_{a[1]}"]+=1

      cnt+=1

    outer_cnt+=1

    # After many failed rollouts fall back to a fresh random policy. This
    # rebinds only the local name; the caller's policy array is not changed.
    if outer_cnt>1000:
      print("BADPOLICY Reinit Policy")
      policy=initialize_policy(world)

  return episode, episode_history


### 2.1 Exploring Starts & on-policy

In [None]:

Q = initialize_Q(w, True)
policy=initialize_policy(w,init="random")# "down"
returns=initialize_returns(w)

step_size=1000     # step limit per episode, raised every 1000 episodes
val_epsilon=0.4    # exploration rate, decayed towards 0.15
print(Q["0_0_1_0"])
max_Q,min_Q=np.zeros(w.rewards.shape),np.zeros(w.rewards.shape)


for ii in range(100000):
  # NOTE(review): the heading says "Exploring Starts" but episodes start from
  # w.s_start (mode="on-policy"); exploration comes from epsilon only.
  episode, episode_history=generate_episode(w, policy, Q, max_cnt=step_size,mode="on-policy", epsilon=val_epsilon)

  if ii % 1000==0:
    create_Q_heatmap(max_Q, min_Q)
    policy_visualization(policy)
    step_size+=1000
    if val_epsilon>0.15:
        val_epsilon-=0.05

    print(ii,"Policy Simulation:",w.p_env,val_epsilon, step_size)

  # First-visit Monte Carlo control. episode[t] = (reward, state, action) is
  # the step taken FROM `state`, so the return of that pair starts with its
  # own reward. The original accumulated episode[t+1][0] and skipped the last
  # step: every pair missed its own reward and the visit counts were off by
  # one for pairs that also occurred at the final step.
  G=0
  for t  in range(len(episode)-1, 0,-1 ):
    reward, state, action = episode[t]
    G=w.gamma*G+reward
    key=f"{state[0]}_{state[1]}_{action[0]}_{action[1]}"
    episode_history[key]=episode_history[key]-1
    # The count reaches 0 exactly at the earliest occurrence of the pair
    if episode_history[key]<=0:
      returns[key].append(G)
      Q[key]=sum(returns[key])/len(returns[key])

      old_policy=policy.copy()
      max_Q,policy,min_Q=calc_min_max_Q_n_greedy(Q,w,policy)

      if not np.array_equal(old_policy,policy):
        print(ii, "policy changed:")


print(returns,"\n",Q)
create_Q_heatmap(max_Q, min_Q)
policy_visualization(policy)



In [None]:
# Inspect one state-action pair: its value, number of returns, last five returns.
ss = "0_1_0_1"
print(Q[ss], len(returns[ss]), returns[ss][-5:])

-65.82181905831247 32790 [-20.000000000000014, -20.000000000000014, -20.000000000000014, -20.000000000000014, -20.000000000000014]


### 2.2 Off-Policy (not 100% sure it is correct)

In [None]:
Q = initialize_Q(w, True)
C = initialize_Q(w, True)          # cumulative importance weights per pair
Counters=initialize_Q(w, True)     # plain visit counts per pair


pi_policy=initialize_policy(w,init="random")# "down"
max_Q, pi_policy, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)

print(pi_policy.shape)

b_policy=initialize_policy(w,init="random")


returns=initialize_returns(w)   # not filled by this algorithm; printed at the end

step_size=1000

pi_epsilon=0.05 # the target policy is nearly greedy
b_epsilon=0.5

# calc_transition explores uniformly over the actions OTHER than the policy's
# one, so each of them has probability epsilon / (|A| - 1).
n_other_actions=len(w.a_space)-1

max_Q,min_Q=np.zeros(w.rewards.shape),np.zeros(w.rewards.shape)


for ii in range(100000):
  episode, episode_history=generate_episode(w, b_policy, Q, max_cnt=step_size,mode="on-policy", epsilon=b_epsilon)

  if ii % 1000==0:
    create_Q_heatmap(max_Q, min_Q)
    policy_visualization(pi_policy)
    step_size+=1000

    print(ii,"Policy Simulation:",w.p_env, step_size, {k:round(v,2) for k,v in Q.items()})
    print("Counters", Counters)

  # Weighted importance sampling, processed backwards through the episode.
  # episode[t] = (reward, state, action) is the step taken FROM `state`, so
  # the return of that pair starts with its own reward (the original used
  # episode[t+1][0] and never processed the final step).
  W=1
  G=0
  for t  in range(len(episode)-1, 0,-1 ):
    reward, state, action = episode[t]
    G=w.gamma*G+reward
    key=f"{state[0]}_{state[1]}_{action[0]}_{action[1]}"

    C[key]+=W
    Counters[key]+=1
    if C[key]==0:
      print("C of key 0.. Breaking")
      break

    Q[key]+=(W/C[key])*(G-Q[key])

    old_pi_policy=pi_policy.copy()
    max_Q, pi_policy, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)
    policy_changed(old_pi_policy,pi_policy)

    # Calculate importance sampling probability ratios per policy
    current_action=(action[0],action[1])
    greedy_action=tuple(pi_policy[state[0], state[1]])
    behaviour_action=tuple(b_policy[state[0], state[1]])

    if current_action==behaviour_action:
      p_b=1-b_epsilon
    else:
      p_b=b_epsilon/n_other_actions

    if current_action==greedy_action:
      p_pi=1-pi_epsilon
    else:
      p_pi=pi_epsilon/n_other_actions

    W=W*(p_pi/p_b)

    # For a fully deterministic target policy one would stop here as soon as
    # current_action != greedy_action and use W = W * (1 / p_b).

  # Draw a fresh behaviour policy for the NEXT episode. This must happen after
  # the ratios above, which need the policy that generated this episode (the
  # original re-randomised b_policy right after generate_episode).
  b_policy=initialize_policy(w,init="random")


print(returns,"\n",Q)
create_Q_heatmap(max_Q, min_Q)
# Show the learned target policy (the original plotted the stale `policy`
# variable left over from the previous section).
policy_visualization(pi_policy)



In [None]:
# Inspect one state-action pair: its value, number of returns, last five returns.
ss = "7_1_0_1"
print(Q[ss], len(returns[ss]), returns[ss][-5:])

-9.85622450678161 0 []


## TD Learning

In [None]:
def transition_handler(world,policy,s,epsilon, display=False):
  """Sample one step and normalise off-grid actions to "stay".

  Wraps ``calc_transition``. If the selected action would have left the grid
  from ``s``, it is reported as ``(0, 0)``, so the returned action always has
  an entry in the Q table. The grid limits come from ``world.rewards.shape``
  (the original hard-coded a 10x10 grid).

  Returns:
    ``(new_s, a, r)`` as produced by ``calc_transition``, with ``a`` possibly
    replaced by ``(0, 0)``.
  """
  new_s,a,r=calc_transition(world,policy,s, epsilon=epsilon, display=display)

  last_row, last_col = world.rewards.shape[0]-1, world.rewards.shape[1]-1
  move=tuple(a)
  blocked= ((s[0]==0 and move==(-1,0)) or (s[1]==0 and move==(0,-1))
            or (s[0]==last_row and move==(1,0)) or (s[1]==last_col and move==(0,1)))
  if blocked:
    a=(0,0)

  return new_s, a, r





### All: On-Policy (Sarsa), off-policy (Q-learning), expected sarsa

In [None]:
Q=initialize_Q(w, True)
policy=initialize_policy(w,init="random")# "down"
max_Q, policy, min_Q= calc_min_max_Q_n_greedy(Q,w, policy)
policy_visualization(policy)


mode="ES"# "Q" for Q-learning, "sarsa" for Sarsa, anything else: expected Sarsa
print("Mode:",mode)

epsilon=0.15
max_cnt=10000   # step limit per episode
alpha=0.01      # learning rate

for i in range(20000):

  if i%100==0:
    print("Episode: ",i)
    if i%1000==0:
      create_Q_heatmap(max_Q, min_Q)
      policy_visualization(policy)

  # The greedy policy is refreshed once per episode
  max_Q, policy, min_Q= calc_min_max_Q_n_greedy(Q,w, policy)

  s=w.s_start[0]
  new_s,a,r=transition_handler(w,policy,s, epsilon=epsilon)

  cnt=0
  while (new_s not in w.s_terminal) and cnt<=max_cnt:
    # (s, a, r) is the transition being backed up; then step on from new_s
    key = f"{s[0]}_{s[1]}_{a[0]}_{a[1]}"

    s=new_s
    new_s,new_a,new_r=transition_handler(w,policy, s, epsilon=epsilon)

    # Q keys of the state just entered, in a_space order. Built directly
    # instead of scanning every key of Q for a string prefix.
    state_keys=[k for k in (f"{s[0]}_{s[1]}_{aa[0]}_{aa[1]}" for aa in w.a_space) if k in Q]

    if mode=="sarsa":
      new_key = f"{s[0]}_{s[1]}_{new_a[0]}_{new_a[1]}"
      Q[key]+=alpha*(r+w.gamma*Q[new_key] - Q[key])
    elif mode=="Q":
      Q[key]+=alpha*(r+w.gamma* max(Q[k] for k in state_keys) - Q[key])
    else: # expected sarsa
      chosen_key=f"{s[0]}_{s[1]}_{policy[s][0]}_{policy[s][1]}"
      S=0
      for k in state_keys:
        if k!=chosen_key:
          S+=(epsilon/(len(state_keys)-1))*Q[k]
        else:
          S+=(1-epsilon)*Q[k]

      Q[key]+=alpha*(r+w.gamma*S - Q[key])

    a,r= new_a, new_r

    cnt+=1

  # Back up the final transition too. A terminal state has value 0, so the
  # target is just the reward. The original never updated this last pair.
  if new_s in w.s_terminal:
    key = f"{s[0]}_{s[1]}_{a[0]}_{a[1]}"
    Q[key]+=alpha*(r - Q[key])







##  N-Step *Bootstrapping*

### On-Policy (n-step Sarsa)

In [None]:
Q=initialize_Q(w, True)
policy=initialize_policy(w,init="random")# "down"
# calc_min_max_Q_n_greedy overwrites the policy it is given, so a fresh random
# policy is drawn again afterwards.
max_Q, _, min_Q= calc_min_max_Q_n_greedy(Q,w, policy)
policy=initialize_policy(w,init="random")# "down"

init_mode="random"# "start" begins every episode at w.s_start[0]


alpha=0.01
epsilon=0.15
n=5


for i in range(30000):

  if i%100==0:
    print("Episode: ",i)
    if i%1000==0:
      create_Q_heatmap(max_Q, min_Q)
      policy_visualization(policy)

  if init_mode=="start":
    s=w.s_start[0]
  else:
    s=(random.randint(0,9),random.randint(0,9))
    # Re-draw until the start is not terminal. The original compared the
    # tuple with the whole list (`s==w.s_terminal`), which is never true.
    while s in w.s_terminal:
     s=(random.randint(0,9),random.randint(0,9))

  T=10000000000   # "infinity" until the terminal step is known
  t=0
  episode=[]      # episode[k] = (S_k, A_k, R_{k+1})
  new_s,a,r =transition_handler(w,policy,s,epsilon)
  while 1:

    if t<T:
        episode.append((s,a,r))
        if new_s in w.s_terminal:
          T=t+1
        else:
          # advance: (s, a, r) now describe step t+1
          s=new_s
          new_s,a,r =transition_handler(w,policy,s,epsilon)

    T_small= t-n+1   # time step whose estimate is updated now

    if T_small>=0:
      # n-step return: discounted rewards R_{tau+1} .. R_{min(tau+n, T)}
      G=0
      for tt in range(T_small+1, min(T_small+n+1,T+1)):
        G+=(w.gamma**(tt-T_small-1))*episode[tt-1][2]

      # bootstrap from Q(S_{tau+n}, A_{tau+n}), which is the current (s, a)
      if T_small+n<T:
        current_key=f"{s[0]}_{s[1]}_{a[0]}_{a[1]}"
        G+=(w.gamma**n)*Q[current_key]

      update_key=f"{episode[T_small][0][0]}_{episode[T_small][0][1]}_{episode[T_small][1][0]}_{episode[T_small][1][1]}"
      Q[update_key]+=alpha*(G-Q[update_key])
      max_Q, policy, min_Q= calc_min_max_Q_n_greedy(Q,w, policy)

    t+=1
    if T_small == T-1:
      break



### Off-Policy (Too much variance)

In [None]:
Q=initialize_Q(w, True)
pi_policy=initialize_policy(w,init="random")# "down"
# calc_min_max_Q_n_greedy also makes pi_policy greedy in place
max_Q, _, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)

b_policy=initialize_policy(w,init="random")# "down"

init_mode="random"# "start" begins every episode at w.s_start[0]


alpha=0.01
pi_epsilon=0.15
b_epsilon=0.5
n=5
w.gamma=0.9   # note: changes the shared world object for later cells too

# Exploration is uniform over the actions other than the policy's own one
n_other_actions=len(w.a_space)-1

for i in range(30000):

  if i%1==0:
    print("Episode: ",i)
    if i%100==0:
      create_Q_heatmap(max_Q, min_Q)
      policy_visualization(pi_policy)

  if init_mode=="start":
    s=w.s_start[0]
  else:
    s=(random.randint(0,9),random.randint(0,9))
    # Re-draw until the start is not terminal. The original compared the
    # tuple with the whole list (`s==w.s_terminal`), which is never true.
    while s in w.s_terminal:
     s=(random.randint(0,9),random.randint(0,9))

  T=10000000000   # "infinity" until the terminal step is known
  t=0
  episode=[]      # episode[k] = (S_k, A_k, R_{k+1})
  new_s,a,r =transition_handler(w,b_policy,s,b_epsilon)
  while 1:

    if t<T:
        episode.append((s,a,r))
        if new_s in w.s_terminal:
          T=t+1
        else:
          # advance: (s, a, r) now describe step t+1
          s=new_s
          new_s,a,r =transition_handler(w,b_policy,s,b_epsilon)

    T_small= t-n+1   # time step whose estimate is updated now

    if T_small>=0:
      # Importance ratio over the actions A_{tau+1} .. A_{min(tau+n, T-1)}.
      # The original read episode[tt-1], which started the product at A_tau
      # (the action being evaluated) and dropped the last factor.
      ro=1
      for idx in range(T_small+1, min(T_small+n, T-1)+1):
        if idx<len(episode):
          state_i, action_i = episode[idx][0], tuple(episode[idx][1])
        else:
          # step t+1 is not stored yet; it is the current (s, a)
          state_i, action_i = s, tuple(a)

        greedy_action=tuple(pi_policy[state_i[0], state_i[1]])
        behaviour_action=tuple(b_policy[state_i[0], state_i[1]])

        if action_i==behaviour_action:
          p_b=1-b_epsilon
        else:
          p_b=b_epsilon/n_other_actions

        if action_i==greedy_action:
          p_pi=1-pi_epsilon
        else:
          p_pi=pi_epsilon/n_other_actions

        ro=ro*(p_pi/p_b)

      # n-step return: discounted rewards R_{tau+1} .. R_{min(tau+n, T)}
      G=0
      for tt in range(T_small+1, min(T_small+n+1,T+1)):
        G+=(w.gamma**(tt-T_small-1))*episode[tt-1][2]

      # bootstrap from Q(S_{tau+n}, A_{tau+n}), which is the current (s, a)
      if T_small+n<T:
        current_key=f"{s[0]}_{s[1]}_{a[0]}_{a[1]}"
        G+=(w.gamma**n)*Q[current_key]

      update_key=f"{episode[T_small][0][0]}_{episode[T_small][0][1]}_{episode[T_small][1][0]}_{episode[T_small][1][1]}"

      Q[update_key]+=alpha*ro*(G-Q[update_key])
      max_Q, pi_policy, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)

    t+=1
    if T_small == T-1:
      break



### Off-policy without importance sampling (n-step tree backup)

In [None]:
Q=initialize_Q(w, True)
pi_policy=initialize_policy(w,init="random")# "down"
# calc_min_max_Q_n_greedy also makes pi_policy greedy in place
max_Q, _, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)

b_policy=initialize_policy(w,init="random")# "down"

init_mode="start"# anything else draws a random start state


alpha=0.01
pi_epsilon=0.1
b_epsilon=0.3
n=5
episode=[]


def _target_prob(state, action):
  """pi(action | state) for the epsilon-greedy target policy."""
  if tuple(action)==tuple(pi_policy[state[0], state[1]]):
    return 1-pi_epsilon
  return pi_epsilon/(len(w.a_space)-1)


def _expected_q(state, skip=None):
  """Sum of pi(a|state) * Q(state, a) over the actions stored in Q, except ``skip``."""
  total=0
  for aa in w.a_space:
    k=f"{state[0]}_{state[1]}_{aa[0]}_{aa[1]}"
    if k in Q and aa!=skip:
      total+=_target_prob(state, aa)*Q[k]
  return total


for i in range(30000):

  if i%10==0:
    print("Episode: ",i,len(episode))
    if i%100==0:
      create_Q_heatmap(max_Q, min_Q)
      policy_visualization(pi_policy)

  if init_mode=="start":
    s=w.s_start[0]
  else:
    s=(random.randint(0,9),random.randint(0,9))
    # Re-draw until the start is not terminal. The original compared the
    # tuple with the whole list (`s==w.s_terminal`), which is never true.
    while s in w.s_terminal:
     s=(random.randint(0,9),random.randint(0,9))

  T=10000000000   # "infinity" until the terminal step is known
  t=0
  episode=[]      # episode[k] = (S_k, A_k, R_{k+1})
  new_s,a,r =transition_handler(w,b_policy,s,b_epsilon)
  while 1:
    if t<T:
        episode.append((s,a,r))
        if new_s in w.s_terminal:
          T=t+1
        else:
          # advance: (s, a, r) now describe step t+1
          s=new_s
          new_s,a,r =transition_handler(w,b_policy,s,b_epsilon)

    T_small= t-n+1   # time step whose estimate is updated now

    if T_small>=0:
      if t+1>=T:
        # the backup reaches the end of the episode: G = R_T
        G=episode[T-1][2]
      else:
        # G = R_{t+1} + gamma * sum_a pi(a|S_{t+1}) Q(S_{t+1}, a); `s` is
        # S_{t+1} here. The original used `r`, which already holds R_{t+2}.
        G=episode[t][2]+w.gamma*_expected_q(s)

      # Walk back through S_k, k = min(t, T-1) .. tau+1:
      #   G = R_k + gamma * sum_{a != A_k} pi(a|S_k) Q(S_k, a)
      #           + gamma * pi(A_k|S_k) * G
      # R_k is stored at episode[k-1]. The original used episode[k]'s reward,
      # excluded the greedy action instead of the action taken, and always
      # weighted G by 1 - pi_epsilon.
      for k in range(min(T-1, t), T_small,-1):
        s_k, a_k = episode[k][0], tuple(episode[k][1])
        G=episode[k-1][2] + w.gamma*_expected_q(s_k, skip=a_k) + w.gamma*_target_prob(s_k, a_k)*G

      update_key=f"{episode[T_small][0][0]}_{episode[T_small][0][1]}_{episode[T_small][1][0]}_{episode[T_small][1][1]}"

      Q[update_key]+=alpha*(G-Q[update_key])
      max_Q, pi_policy, min_Q= calc_min_max_Q_n_greedy(Q,w, pi_policy)

    t+=1
    if T_small == T-1:
      break

