In [15]:
import jax
import jax.numpy as jnp
from jax import random

In [289]:
class Bandit:
    def __init__(self, k: int):
        self.k = k
    
    def train(self, reward_func, seed: int, steps: int):
        raise NotImplementedException()
        
    def action_value(self):
        raise NotImplementedException()

## Eps-greedy bandit

In [290]:
class EpsGreedyBandit(Bandit):
    def __init__(self, k: int, eps: float):
        super().__init__(k)
        self.q = jnp.zeros(k)
        self.n = jnp.zeros(k)
        self.eps = eps
    
    def __next_key(self):
        self.key = random.split(self.key)[0]
    
    def _action(self, _):
        if random.normal(self.key) < self.eps:
            self.__next_key()
            a = random.randint(self.key, (1,), 0, self.k).item()
        else:
            a = int(jnp.argmax(self.q))
        self.__next_key()
        return a
    
    def train(self, reward_func, seed: int, steps: int):
        self.key = random.PRNGKey(seed)
        for t in range(steps):
            a = self._action(t)
            r = reward_func(a)
            self.n = self.n.at[a].set(self.n[a] + 1)
            self.q = self.q.at[a].set(self.q[a] + (r - self.q[a]) / self.n[a])
            
    def action_value(self):
        return self.q

In [291]:
def reward_func_generator(means, stds, seed: int):
    key = random.PRNGKey(seed)
    def func(a):
        nonlocal key
        mean = means[a]
        std = stds[a]
        value = random.normal(key) * std + mean
        key = random.split(key)[0]
        return value
    return func

In [293]:
reward_seed = 11
reward_func = reward_func_generator([0, 1, 2], [0.1, 0.3, 0.7], reward_seed)

eps = 0.5
bandit_seed = 2
steps = 100

eps_greedy = EpsGreedyBandit(3, eps)
eps_greedy.train(reward_func, bandit_seed, steps)

In [294]:
eps_greedy.action_value()

DeviceArray([-0.00528179,  1.050063  ,  2.0207396 ], dtype=float32)

## Optimistic initial values

In [242]:
class OptimisticInitialValueBandit(EpsGreedyBandit):
    def __init__(self, k: int, eps: float, initial_q: float):
        super().__init__(k, eps)
        
        self.q = initial_q * jnp.ones_like(self.q)

In [249]:
initial_value = 5.
optimistic_bandit = OptimisticInitialValueBandit(3, eps, initial_value)
optimistic_bandit.train(reward_func, bandit_seed, steps)

In [250]:
optimistic_bandit.q

DeviceArray([0.00525586, 0.9327261 , 1.996768  ], dtype=float32)

## Upper-confidence-bound action selection

In [279]:
class UpperConfidenceBoundAction(EpsGreedyBandit):
    def __init__(self, k: int, eps: float, c: float):
        super().__init__(k, eps)
        self.c = c
    
    def __action(self, t):
        self.q_ucb = jnp.zeros_like(self.q)
        for action in range(self.k):
            self.q_ucb = self.q_ucb.at[action].set(self.q[action] + self.c * jnp.sqrt(jnp.ln(t) / self.n[action]))
        return jnp.argmax(self.q_ucb)

In [283]:
c = 0.6
ucb_bandit = UpperConfidenceBoundAction(3, eps, 0.6)
ucb_bandit.train(reward_func, bandit_seed, steps)

In [284]:
ucb_bandit.q

DeviceArray([0.00236601, 0.9279048 , 2.2617695 ], dtype=float32)

## Gradient bandit algorithm

In [295]:
import random as standard_random

In [340]:
class GradientBandit(Bandit):
    def __init__(self, k: int, alpha: float):
        super().__init__(k)
        self.alpha = alpha

        self.h = jnp.ones(self.k) * (1 / self.k)
    
    def probability(self, a):
        return jnp.exp(self.h[a]) / sum([jnp.exp(self.h[a_dot]) for a_dot in range(self.k)])
    
    def train(self, reward_func, seed: int, steps: int):
        actions = list(range(self.k))
        r_sum = 0.
        for t in range(1, steps + 1):
            a = standard_random.choices(actions, self.h)[0]
            r = reward_func(a)
            r_sum += r
            self.h_dot = jnp.array(self.h.copy())
            self.h_dot = self.h_dot.at[a].set(self.h[a] + self.alpha * (r - r_sum / t) * (1 - self.probability(a)))
            for action in actions:
                if a == action:
                    continue
                self.h_dot = self.h_dot.at[action].set(self.h[action] + self.alpha * (r - r_sum / t) * self.probability(action))
            self.h = self.h_dot
    
    def action_value(self):
        return self. h

In [341]:
alpha = .1
gradient_bandit = GradientBandit(3, alpha)

gradient_bandit.train(reward_func, bandit_seed, steps)

In [342]:
gradient_bandit.action_value()

DeviceArray([-0.05833293,  0.14452372,  1.007199  ], dtype=float32)