In [1]:
import numpy as np
import time

In [2]:
class Agent:
    """KL-UCB bandit agent that picks which shot (arm) to play each ball.

    The reward credited to an arm is ``1 - wicket``: 1 when the batter
    survived the ball, 0 when a wicket fell. Runs are not used, so the agent
    learns to minimise the probability of getting out.

    The first ``num_arms`` calls play arms 0, 1, ..., num_arms - 1 once each
    (warm-up); afterwards the arm with the largest KL-UCB index is played.

    Args:
        num_arms: number of available actions (6 in the original setup).
        c: weight of the ``log(log t)`` term in the exploration budget.
    """

    def __init__(self, num_arms=6, c=3.0):
        self.i = 0               # number of actions chosen so far
        self.last_action = 0     # arm returned by the previous get_action call
        self.c = c

        self.t = np.zeros(num_arms, dtype=np.int32)        # pulls per arm
        self.reward = np.zeros(num_arms, dtype=np.int32)   # summed reward per arm
        # float64 keeps indices that differ only very close to 1.0 distinguishable
        self.ucb_arms = np.zeros(num_arms, dtype=np.float64)

    def kl_divergence(self, p, q):
        """Bernoulli KL divergence KL(Ber(p) || Ber(q)).

        Uses the convention 0 * log(0) = 0, so KL(0, q) = -log(1 - q) and
        KL(1, q) = -log(q). Returns ``inf`` when q is 0 or 1 but p differs
        from q (Ber(p) is then not absolutely continuous w.r.t. Ber(q)).
        """
        if p == q:
            return 0.0
        if q <= 0.0 or q >= 1.0:
            return np.inf
        divergence = 0.0
        if p > 0.0:
            divergence += p * np.log(p / q)
        if p < 1.0:
            divergence += (1.0 - p) * np.log((1.0 - p) / (1.0 - q))
        return divergence

    def solve_q(self, rhs, p_a, tol=1e-6):
        """Return the largest q in [p_a, 1] with KL(p_a, q) <= rhs.

        KL(p_a, q) is increasing in q on [p_a, 1], so the boundary is located
        by bisection to within ``tol``. The returned value always satisfies
        the constraint when ``rhs >= 0`` (it is the lower end of the bracket).

        Args:
            rhs: exploration budget for this arm.
            p_a: empirical mean reward of the arm, in [0, 1].
            tol: width of the final bisection bracket.
        """
        if p_a >= 1:
            return 1.0
        low, high = float(p_a), 1.0
        while high - low > tol:
            mid = 0.5 * (low + high)
            if self.kl_divergence(p_a, mid) <= rhs:
                low = mid
            else:
                high = mid
        return low

    def ucb(self):
        """Recompute the KL-UCB index of every arm into ``self.ucb_arms``.

        Requires every arm to have been pulled at least once, which the
        warm-up phase of ``get_action`` guarantees before this is called.
        """
        log_t = np.log(self.i)
        # log(log t) is negative/undefined for t <= e; clamping its argument at 1
        # keeps the budget >= log(t). No effect once t >= 3 (e.g. default 6 arms).
        budget = log_t + self.c * np.log(max(log_t, 1.0))
        for k in range(len(self.t)):
            p_a = self.reward[k] / self.t[k]
            self.ucb_arms[k] = self.solve_q(budget / self.t[k], p_a)

    def get_action(self, wicket, runs_scored):
        """Record the outcome of the previous ball and choose the next action.

        Args:
            wicket: 1 if the previous ball took a wicket, else 0. Ignored on
                the very first call, since there is no previous ball yet.
            runs_scored: runs off the previous ball; unused by this agent.

        Returns:
            Index of the arm (shot) to play next.
        """
        if self.i == 0:
            action = 0
        else:
            # Reward is survival: 1 when no wicket fell on the previous ball.
            self.reward[self.last_action] += 1 - wicket
            self.t[self.last_action] += 1
            if self.i < len(self.t):
                action = self.i  # warm-up: play each arm once in order
            else:
                self.ucb()
                action = int(np.argmax(self.ucb_arms))

        self.last_action = action
        self.i += 1
        return action


# NOTE(review): the notebook declared an empty nested `class ROLLNUMBER_Q1`
# (presumably the submission name expected by the grader); expose the agent
# under that name instead of the unparsable stub.
ROLLNUMBER_Q1 = Agent



In [42]:
class Environment:
  """Simulated innings against which a bandit ``agent`` is scored.

  Six actions (shots) are available. Action ``a`` loses a wicket with
  probability ``p_out[a]``; if no wicket falls it scores ``runs_map[a]`` runs
  with probability ``p_run[a]`` (otherwise 0). Three pseudo-regrets are
  accumulated per innings, measured against: the lowest out-probability
  (``regret_w``), the highest expected runs per ball ``s`` (``regret_s``),
  and the highest ratio ``rho = s / p_out`` (``regret_rho``).

  Args:
    num_balls: number of balls bowled per innings.
    agent: object exposing ``get_action(wicket, runs_scored) -> int``.
  """

  def __init__(self,num_balls,agent):
    self.num_balls = num_balls
    self.agent = agent
    self.__run_time = 0
    self.__total_runs = 0
    self.__total_wickets = 0
    self.__runs_scored = 0
    self.__start_time = 0
    self.__end_time = 0
    self.__regret_w = 0
    self.__regret_s = 0
    self.__wicket = 0
    self.__regret_rho = 0
    self.__p_out = np.array([0.001,0.01,0.02,0.03,0.1,0.3])
    self.__p_run = np.array([1,0.9,0.85,0.8,0.75,0.7])
    self.__action_runs_map = np.array([0,1,2,3,4,6])
    # expected runs per ball for each action
    self.__s = (1-self.__p_out)*self.__p_run*self.__action_runs_map
    self.__rho = self.__s/self.__p_out

  def __get_action(self):
    """Ask the agent for an action, timing the call and validating the result.

    Raises:
      ValueError: if the agent returns an index outside the action range
        (a negative index would otherwise silently wrap around).
    """
    self.__start_time = time.perf_counter()
    action = self.agent.get_action(self.__wicket,self.__runs_scored)
    self.__end_time = time.perf_counter()
    self.__run_time += self.__end_time - self.__start_time
    if not 0 <= action < len(self.__p_out):
      raise ValueError(f"agent returned invalid action {action!r}")
    return action

  def __get_outcome(self, action):
    """Sample ``(wicket, runs)`` for one ball played with ``action``."""
    pout = self.__p_out[action]
    prun = self.__p_run[action]
    wicket = np.random.choice(2,1,p=[1-pout,pout])[0]
    runs = 0
    if wicket == 0:
      runs = self.__action_runs_map[action]*np.random.choice(2,1,p=[1-prun,prun])[0]
    return wicket, runs

  def innings(self):
    """Play ``num_balls`` balls and return this innings' statistics.

    Every counter (runs, wickets, regrets, agent run time, last outcome) is
    reset first, so repeated calls report per-innings values instead of
    mixing cumulative regrets with per-innings runs/wickets.

    Returns:
      (regret_w, regret_s, regret_rho, total_runs, total_wickets, run_time)
    """
    self.__total_runs = 0
    self.__total_wickets = 0
    self.__runs_scored = 0
    self.__wicket = 0
    self.__regret_w = 0
    self.__regret_s = 0
    self.__regret_rho = 0
    self.__run_time = 0

    # best-arm reference values are constant; compute them once
    best_p_out = np.min(self.__p_out)
    best_s = np.max(self.__s)
    best_rho = np.max(self.__rho)

    for _ in range(self.num_balls):
      action = self.__get_action()
      self.__wicket, self.__runs_scored = self.__get_outcome(action)
      self.__total_runs += self.__runs_scored
      self.__total_wickets += self.__wicket
      self.__regret_w += self.__p_out[action] - best_p_out
      self.__regret_s += best_s - self.__s[action]
      self.__regret_rho += best_rho - self.__rho[action]
    return self.__regret_w,self.__regret_s,self.__regret_rho, self.__total_runs, self.__total_wickets, self.__run_time


In [62]:
# Environment samples from numpy's global RNG; seed it so the innings is reproducible.
SEED = 0
np.random.seed(SEED)

agent = Agent()
environment = Environment(1000, agent)
regret_w, regret_s, regret_rho, total_runs, total_wickets, run_time = environment.innings()

print(regret_w, regret_s, regret_rho, total_runs, total_wickets, run_time)

0.45499999999999996 2929.4750000000445 88813.20000000077 7 2 0.0887610912322998
