In [1]:
import random
from dro import OnlineDRO
import numpy as np

In [2]:
def gen_pmf(actions, epsilon, exploit_idx):
    """Return an epsilon-greedy pmf as a list of length `actions`.

    Every entry starts at epsilon / actions; the entry at `exploit_idx`
    then receives the leftover 1 - epsilon on top of its uniform share.
    """
    share = epsilon / actions
    probs = [share for _ in range(actions)]
    # Plain list indexing, so a bad index raises IndexError and negative
    # indices count from the end, exactly like item assignment would.
    probs[exploit_idx] = probs[exploit_idx] + (1 - epsilon)
    return probs


def sample_custom_pmf(pmf):
    """Draw one index from a (possibly unnormalized) probability mass function.

    The weights are rescaled to sum to 1 and a single uniform draw is taken
    from NumPy's global RNG (so `np.random.seed` makes the result
    reproducible). The first index whose cumulative probability exceeds the
    draw is returned.

    Args:
        pmf: Sequence of non-negative weights with a positive sum.

    Returns:
        int: The sampled index.

    Raises:
        ValueError: If the weights do not have a positive (non-NaN) sum.
    """
    total = sum(pmf)
    if not total > 0:  # also rejects NaN
        raise ValueError(f"pmf must have a positive sum, got {total}")
    scale = 1 / total
    probs = [x * scale for x in pmf]
    draw = np.random.random()
    cumulative = 0.0
    last_positive = None
    for index, prob in enumerate(probs):
        cumulative += prob
        if cumulative > draw:
            return index
        if prob > 0:
            last_positive = index
    # Floating-point rounding can leave the cumulative sum a hair below 1.0,
    # so a draw very close to 1 may not be covered. Fall back to the last
    # index that actually carries probability mass instead of failing.
    if last_positive is None:
        raise Exception("can't sample")
    return last_positive

In [13]:

# Scenario 1: the logging policy explores around the best arm (3); arm 0 then becomes the best arm (see p1_reward).

class DiscountedAverage:
    """Exponentially discounted running mean of a stream of rewards.

    On every update, both the accumulated reward sum and the observation
    count are multiplied by `tau` before the new observation is added. Older
    rewards are therefore down-weighted geometrically. With `tau = 1` this is
    the ordinary arithmetic mean.
    """

    def __init__(self, tau):
        """
        Args:
            tau: Discount factor applied to past observations on each update
                (presumably in (0, 1]; not validated here).
        """
        self.tau = tau
        self.sumr = 0  # discounted sum of rewards
        self.n = 0     # discounted number of observations

    def update(self, r):
        """Discount the history by `tau`, then add reward `r` with weight 1."""
        self.sumr = self.tau * self.sumr + r
        self.n = self.tau * self.n + 1

    def current(self):
        """Return the discounted mean, or NaN if nothing has been observed yet."""
        if self.n == 0:
            # No observations: the mean is undefined (avoid ZeroDivisionError).
            return float('nan')
        return self.sumr / self.n

N_ARMS = 4
BEST_ARM = 3

# Logging (behavior) policy: epsilon-greedy around BEST_ARM with heavy
# exploration. Despite its name, `policy` is the pmf that *generates* the
# logged actions in `run` (importance weights there are 1 / policy[action]).
LOGGING_EPSILON = 0.9
policy = gen_pmf(N_ARMS, LOGGING_EPSILON, BEST_ARM)

# Target policy under evaluation. Despite its name, `log` is the policy being
# scored; with epsilon = 0 it always picks BEST_ARM.
# Change TARGET_EPSILON from 0 to 0.1 and you'll see that the lb actually shrinks.
TARGET_EPSILON = 0
log = gen_pmf(N_ARMS, TARGET_EPSILON, BEST_ARM)

# Reward for each arm: before the change arm 3 is best, after it arm 0 is best.
# NOTE(review): p1_reward is not passed to `run` in this cell.
p0_reward = [0.1, 0.3, 0.3, 0.4]
p1_reward = [1, 0.3, 0.3, 0.1]

count = 1000  # number of simulated rounds
def run(policy, eval, r, c):
    """Simulate `c` logged bandit rounds and print progress of two estimators.

    Each round draws a logged action from `policy` and an action from the
    target pmf `eval`, both via `sample_custom_pmf`. The importance weight is
    1 / policy[a] when the two actions agree and 0 otherwise. For normalized
    pmfs, its expectation given the logged action a is eval[a] / policy[a].
    The weight and reward are fed to an `OnlineDRO.OnlineCressieReadLB`
    estimator, and the reward alone to a `DiscountedAverage`. Progress is
    printed roughly every 10% of the rounds.

    Args:
        policy: Logging (behavior) pmf that generates the observed actions.
        eval: Target pmf being evaluated. The name is kept for caller
            compatibility, although it shadows the builtin inside this function.
        r: Reward for each arm, indexed by action.
        c: Number of rounds to simulate.
    """
    np.random.seed(seed=10)  # fixed seed so every call replays the same draws
    ci = OnlineDRO.OnlineCressieReadLB(alpha=0.05, tau=0.999)
    avg = DiscountedAverage(tau=0.999)
    # Report about every 10% of *this* run. Derive the step from `c` rather
    # than a global, and clamp it so runs shorter than 10 rounds don't take a
    # modulo by zero.
    step = max(c // 10, 1)
    for i in range(c):
        p_action = sample_custom_pmf(policy)
        e_action = sample_custom_pmf(eval)

        w = 1 / policy[p_action] if p_action == e_action else 0
        ci.update(c=1, w=w, r=r[p_action])
        avg.update(r[p_action])
        ci.recomputeduals()
        lb = ci.duals[0][0]
        m = avg.current()
        if i % step == step - 1:
            print(f'[{i}] lb: {lb} mean: {m}')


# Evaluate the target pmf `log` using data logged by `policy` under the pre-change rewards.
run(policy=policy, eval=log, r=p0_reward, c=count)

10


[99] lb: 0.16796888339914703 mean: 0.26975835418006716
[199] lb: 0 mean: 0.2769446132834465
[299] lb: 0.25001547984195366 mean: 0.2760558877450576
[399] lb: 0 mean: 0.2821628429496422
[499] lb: 0 mean: 0.28758211394319344
[599] lb: 0 mean: 0.2863197574109059
[699] lb: 0 mean: 0.28768856769275714
[799] lb: 0 mean: 0.29084603194487235
[899] lb: 0 mean: 0.29148221993670476
[999] lb: 0 mean: 0.29030934087612753


10