<a href="https://colab.research.google.com/github/tianyuehz/Thompson_Sampling_MDP/blob/main/OTS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from typing_extensions import DefaultDict
from numpy.core.fromnumeric import argmax
import matplotlib.pyplot as plt
from tqdm import tqdm

# Environment / Util

In [10]:
class MDP():
  """Randomly generated finite-horizon tabular MDP that records interaction statistics.

  Attributes:
    S, A, H: number of states, number of actions, and steps per episode.
    epLen: stored as given; this class itself never reads it.
    P: transition probabilities, shape (S, A, H, S); every P[s, a, t] is a
      Dirichlet(0.1, ..., 0.1) draw over next states.
    R: deterministic rewards, shape (S, A, H); each entry is either 0 or 0.9.
    curr_s: current state index.
    total_reward: reward accumulated since the last reset().
    emperical_rewards: per (s, a, t) sum of observed rewards.
    observations: per (s, a, t) visit counter, initialised to 1 rather than 0.
    empirical_transition: per (s, a, t, s') count of observed transitions.
  """

  def __init__(self, S, A, H, epLen):
    self.S = S
    self.A = A
    self.H = H
    self.epLen = epLen
    self.P = np.random.dirichlet(np.ones(S)*0.1,(S,A,H))
    self.R = np.ones((S,A,H)) * 0.9
    self.R[np.random.random_sample((S,A,H)) <= 0.9] = 0 # To keep reward sparse
    self.curr_s = np.random.choice(S)
    self.total_reward = 0 # Cumulative reward within one episode

    # Observation statistics
    self.emperical_rewards = np.zeros((S,A,H))
    # Counts start at one, so dividing the reward sums by them is always defined.
    self.observations = np.ones((S,A,H))
    self.empirical_transition = np.zeros((S,A,H,S))

  def step(self, a, t):
    """Play action `a` at step `t` from the current state.

    Samples the next state from P, records the transition in the statistics,
    moves to the next state and adds the reward to `total_reward`.
    """
    s = self.curr_s
    r = self.R[s,a,t]
    next_s = np.random.choice(self.S, p=self.P[s,a,t])
    self.update_stats(s, a, t, next_s, r)
    self.curr_s = next_s
    self.total_reward = self.total_reward + r

  def update_stats(self, s, a, t, next_s, r):
    """Accumulate one observed (s, a, t) -> (next_s, r) transition."""
    self.emperical_rewards[s, a, t] += r
    self.observations[s, a, t] += 1
    self.empirical_transition[s,a,t,next_s] += 1

  def reset(self):
    """Start a new episode: draw a fresh start state and zero the episode reward.

    The observation statistics are intentionally left untouched.
    """
    self.curr_s = np.random.choice(self.S)
    self.total_reward = 0

  def run_fixed_policy(self, p, n_episodes=1000):
    """Roll out a fixed policy without learning (e.g. a random or an optimal one).

    Args:
      p: integer array indexed as p[state, step] giving the action to play.
      n_episodes: number of episodes to run (defaults to 1000).

    Returns:
      List with the total reward of each episode.

    Every step still goes through step(), so the observation statistics are
    updated by these rollouts as well.
    """
    r = []
    for _ in range(n_episodes):
      for t in range(self.H):
        a = p[self.curr_s,t]
        self.step(a,t)
      r.append(self.total_reward)
      # Reset this environment, not whatever global happens to be named `m`.
      self.reset()
    return r
  

In [3]:
# With access to the true rewards and transitions, we can compute the optimal policy exactly
def optimal_policy(m):
  """Compute the greedy policy for the true model of `m` by backward induction.

  Reads m.S, m.A-shaped arrays m.R[s, a, t] and m.P[s, a, t, s'] and the horizon
  m.H. Starting from a zero value after the last step, each step's action values
  are the immediate reward plus the expected value of the next step.

  Returns:
    Integer array of shape (S, H) with the maximising action for each
    (state, step); ties go to the lowest action index.
  """
  n_states, horizon = m.S, m.H
  greedy = np.zeros((n_states, horizon), dtype=int)
  # Column `horizon` stays zero: nothing is earned after the final step.
  value = np.zeros((n_states, horizon + 1))
  for step in reversed(range(horizon)):
    next_value = value[:, step + 1]
    for state in range(n_states):
      action_values = m.R[state, :, step] + np.dot(m.P[state, :, step, :], next_value)
      greedy[state, step] = np.argmax(action_values)
      value[state, step] = action_values.max()
  return greedy

In [4]:
def get_avg(r, window=1000):
  """Trailing moving average of `r`, NaN-padded at the front.

  Args:
    r: sequence of numbers (e.g. per-episode rewards).
    window: number of trailing samples averaged per output value.

  Returns:
    List of len(r) floats. Entry i is the mean of r[i-window+1 : i+1]; the
    first window-1 entries are NaN because no full window is available yet.
    If r is shorter than the window, every entry is NaN.

  Raises:
    ValueError: if window is smaller than 1.
  """
  if window < 1:
    raise ValueError("window must be at least 1")
  values = np.asarray(r, dtype=float)
  n = values.size
  averaged = np.full(n, np.nan)
  if n >= window:
    # Window sums as differences of a running total: O(n) instead of
    # re-averaging `window` samples for every position.
    running = np.concatenate(([0.0], np.cumsum(values)))
    averaged[window - 1:] = (running[window:] - running[:-window]) / window
  return averaged.tolist()

# Algorithms

## OTS

In [5]:
# OTS
def OTS(m):
    """Compute one OTS policy from the statistics currently stored in `m`.

    Runs backward induction over the horizon. For every (state, action, step)
    the reward is drawn from a normal distribution centred on the empirical
    mean reward and floored at that mean, and the next-step value is averaged
    under the empirical transition distribution.

    Args:
        m: environment exposing S, A, H, epLen and the arrays
           emperical_rewards, observations and empirical_transition.

    Returns:
        Integer array of shape (S, H) with the greedy action per (state, step)
        with respect to the sampled action values.
    """
    S = m.S
    H = m.H
    A = m.A
    T = m.epLen
    Q = np.zeros((S, A, H))
    V = np.zeros((S, H + 1))  # column H stays zero (no value after the last step)
    delta = 1 / (S * A * (H ** 2) * (T ** 2))
    # Numerator of the variance term; it does not depend on (s, a, t).
    variance_scale = S * (H ** 3) * np.log(1 / delta)
    # Every entry is overwritten below, so no random initialisation is needed.
    policy = np.zeros((S, H), dtype=int)
    for t in range(H - 1, -1, -1):
        for s in range(S):
            for a in range(A):
                # Compute param mu and sigma for the posterior distribution
                n_obs = m.observations[s, a, t]
                mu = m.emperical_rewards[s, a, t] / n_obs
                counts = m.empirical_transition[s, a, t, :]
                total = counts.sum()
                # Normalise counts so they sum to one (a probability vector).
                # Unvisited pairs keep the all-zero vector, i.e. no future value.
                P = counts / total if total > 0 else counts
                sigma = min(H, np.sqrt(variance_scale / n_obs))

                # Sample reward, clip to empirical mean
                r = max(mu, np.random.normal(mu, sigma))

                # Compute Q
                Q[s, a, t] = r + np.dot(P, V[:, t + 1])

            # Update policy to the best arm in this round
            policy[s, t] = np.argmax(Q[s, :, t])
            V[s, t] = Q[s, :, t].max()
    return policy

def run_OTS(m):
  """Run m.epLen episodes on `m`, acting with OTS after a random warm start.

  The first 1000 episodes follow one fixed, randomly drawn policy; every later
  episode recomputes the policy with OTS(m) from the statistics gathered so far.

  Returns:
    List with the total reward of each episode.
  """
  episode_rewards = []
  warmup_policy = np.random.choice(m.A,(m.S,m.H))
  for episode in tqdm(range(m.epLen)):
    # Warm start for the first 1000 episodes, then re-plan before every episode.
    policy = warmup_policy if episode < 1000 else OTS(m)
    for t in range(m.H):
      m.step(policy[m.curr_s,t], t)
    episode_rewards.append(m.total_reward)
    m.reset()
  return episode_rewards

## OTS+

In [6]:
# OTS+
def OTS_plus(m):
    """Compute one OTS+ policy from the statistics currently stored in `m`.

    Same backward induction as OTS, except that the sampled reward is floored
    at mu + 2 * sigma instead of at the empirical mean mu.

    Args:
        m: environment exposing S, A, H, epLen and the arrays
           emperical_rewards, observations and empirical_transition.

    Returns:
        Integer array of shape (S, H) with the greedy action per (state, step)
        with respect to the sampled action values.
    """
    S = m.S
    H = m.H
    A = m.A
    T = m.epLen
    Q = np.zeros((S, A, H))
    V = np.zeros((S, H + 1))  # column H stays zero (no value after the last step)
    delta = 1 / (S * A * (H ** 2) * (T ** 2))
    # Numerator of the variance term; it does not depend on (s, a, t).
    variance_scale = S * (H ** 3) * np.log(1 / delta)
    # Every entry is overwritten below, so no random initialisation is needed.
    policy = np.zeros((S, H), dtype=int)
    for t in range(H - 1, -1, -1):
        for s in range(S):
            for a in range(A):
                # Compute param mu and sigma for the posterior distribution
                n_obs = m.observations[s, a, t]
                mu = m.emperical_rewards[s, a, t] / n_obs
                counts = m.empirical_transition[s, a, t, :]
                total = counts.sum()
                # Normalise counts so they sum to one (a probability vector).
                # Unvisited pairs keep the all-zero vector, i.e. no future value.
                P = counts / total if total > 0 else counts
                sigma = min(H, np.sqrt(variance_scale / n_obs))

                # Sample reward, clip to upper confidence bound
                ucb = mu + 2 * sigma
                r = max(ucb, np.random.normal(mu, sigma))

                # Compute Q
                Q[s, a, t] = r + np.dot(P, V[:, t + 1])

            # Update policy to the best arm in this round
            policy[s, t] = np.argmax(Q[s, :, t])
            V[s, t] = Q[s, :, t].max()
    return policy

def run_OTS_plus(m):
  """Run m.epLen episodes on `m`, acting with OTS+ after a random warm start.

  The first 1000 episodes follow one fixed, randomly drawn policy; every later
  episode recomputes the policy with OTS_plus(m) from the statistics gathered
  so far.

  Returns:
    List with the total reward of each episode.
  """
  r = []
  p_rand = np.random.choice(m.A,(m.S,m.H))
  for i in tqdm(range(m.epLen)):
    # Warm start for the first 1000 episodes; afterwards update the policy
    # from the new stats with OTS+ (not plain OTS) before every episode.
    p = p_rand if i < 1000 else OTS_plus(m)
    # Sampling
    for t in range(m.H):
      a = p[m.curr_s,t]
      m.step(a,t)
    r.append(m.total_reward)
    m.reset()
  return r

## RLSVI TODO

In [22]:
def RLSVI(m):
  """Placeholder for the RLSVI baseline; not implemented yet, returns None."""
  return None
# NARL
# UBEV
# UCB-VI

## UBEV TODO

In [None]:
def UBEV(m):
  """Placeholder for the UBEV baseline; not implemented yet, returns None."""
  return None

## NARL TODO

In [None]:
def NARL(m):
  """Placeholder for the NARL baseline; not implemented yet, returns None."""
  return None

# Experiment

In [11]:
# Experiment configuration: states, actions, steps per episode, and the
# number of episodes each learning run plays.
S, A, H = 3, 3, 5
epLen = 1_000_000
m = MDP(S, A, H, epLen)

In [None]:
# Sanity check: optimal should perform better than random
p1 = np.random.choice(m.A,(m.S,m.H))
p2 = optimal_policy(m)
r1 = m.run_fixed_policy(p1)
r2 = m.run_fixed_policy(p2)

# NOTE(review): every run here shares `m`, so each learner starts from the
# statistics accumulated by the runs before it — confirm this is intended.
r_ots = run_OTS(m)
r_ots_plus = run_OTS_plus(m)

# TODO: save results properly to avoid re-running
# Fixed policies are drawn as flat reference lines at their mean episode reward.
# (NumPy has no `np.avg`; `np.mean` is the function that computes the average.)
plt.plot(np.full(epLen, np.mean(r1)),label="random")
plt.plot(np.full(epLen, np.mean(r2)),label="optimal")
plt.plot(get_avg(r_ots),label="OTS")
plt.plot(get_avg(r_ots_plus),label="OTS+")
plt.legend()

 92%|█████████▏| 923558/1000000 [26:21<01:51, 684.55it/s]