# Sticky_TaS_fast, Sticky_TaS, Sticky_TaS_old

This notebook compares three implementations of the same algorithm, Sticky_TaS. The comparison focuses on two aspects:

1. Check whether their actions are always the same, given the same environment.
2. Compare the running speed of these implementations

## Check whether their actions are the same

In [1]:
from time import time
from time import perf_counter

import numpy as np
from tqdm import tqdm

from agent import Sticky_TaS_fast, Sticky_TaS, Sticky_TaS_old
from env import Environment_Gaussian

In [2]:
# Run the three implementations on the same reward vector, with a different
# random permutation of the arms in each experiment. Check that they pull the
# same arm every round, stop in the same round and output the same arm.
K = 20
xi = 0.5
Delta = 4
rlist = np.zeros(K)
rlist[-1] = xi + Delta  # exactly one arm has a mean above the threshold xi
delta = 0.0001
log1_over_delta = 1000  # passed explicitly to every agent

n_exp = 100
for exp_id in tqdm(range(n_exp)):
    rlist_temp = rlist.copy()
    np.random.seed(exp_id)
    np.random.shuffle(rlist_temp)

    env = Environment_Gaussian(rlist=rlist_temp, K=K, random_seed=exp_id)
    agent_sas = Sticky_TaS(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)
    agent_sas_fast = Sticky_TaS_fast(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)
    agent_sas_old = Sticky_TaS_old(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)

    # Reset every experiment so a value left over from a previous experiment
    # is never compared.
    output_arm = output_arm_fast = output_arm_old = None
    while (not agent_sas.stop) or (not agent_sas_fast.stop) or (not agent_sas_old.stop):
        arm_sas_fast = agent_sas_fast.action()
        arm_sas = agent_sas.action()
        arm_sas_old = agent_sas_old.action()
        assert arm_sas_fast == arm_sas, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"Sticky_TaS_fast pulled {arm_sas_fast}, Sticky_TaS pulled {arm_sas}"
        )
        assert arm_sas_old == arm_sas, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"Sticky_TaS_old pulled {arm_sas_old}, Sticky_TaS pulled {arm_sas}"
        )
        assert agent_sas.stop == agent_sas_fast.stop, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"stop flags of Sticky_TaS and Sticky_TaS_fast differ"
        )
        assert agent_sas_old.stop == agent_sas_fast.stop, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"stop flags of Sticky_TaS_old and Sticky_TaS_fast differ"
        )

        # All agents observe the same reward, so their internal states stay comparable.
        reward = env.response(arm_sas_fast)

        # The value returned by the last observe() call is the agent's output arm.
        output_arm = agent_sas.observe(reward)
        output_arm_fast = agent_sas_fast.observe(reward)
        output_arm_old = agent_sas_old.observe(reward)

    assert output_arm == output_arm_fast, (
        f"exp {exp_id}, output different arms: "
        f"Sticky_TaS {output_arm}, Sticky_TaS_fast {output_arm_fast}"
    )
    assert output_arm_old == output_arm_fast, (
        f"exp {exp_id}, output different arms: "
        f"Sticky_TaS_old {output_arm_old}, Sticky_TaS_fast {output_arm_fast}"
    )

print("Didn't detect inconsistent action")

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [11:15<00:00,  6.75s/it]

Didn't detect inconsistent action





In [4]:
# Run the implementations on randomly generated reward vectors. Every arm mean
# is drawn uniformly from [0, 1), then one random arm is raised to xi + Delta.
# Check that the implementations pull the same arm every round, stop in the
# same round and output the same arm.
K = 10
xi = 0.5
Delta = 4
delta = 0.0001
log1_over_delta = 1000  # passed explicitly to every agent

n_exp = 100
for exp_id in tqdm(range(n_exp)):
    # Seed first: the draws below (uniform, then randint) determine the instance.
    np.random.seed(exp_id)
    rlist_temp = np.random.uniform(low=0.0, high=1.0, size=K)
    rlist_temp[np.random.randint(low=0, high=K)] = xi + Delta

    env = Environment_Gaussian(rlist=rlist_temp, K=K, random_seed=exp_id)
    agent_sas = Sticky_TaS(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)
    agent_sas_fast = Sticky_TaS_fast(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)
    agent_sas_old = Sticky_TaS_old(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)

    # Reset every experiment so the final comparison never uses a value
    # left over from the previous experiment.
    output_arm = output_arm_fast = output_arm_old = None
    while (not agent_sas.stop) or (not agent_sas_fast.stop) or (not agent_sas_old.stop):
        arm_sas_fast = agent_sas_fast.action()
        arm_sas = agent_sas.action()
        arm_sas_old = agent_sas_old.action()
        assert arm_sas_fast == arm_sas, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"Sticky_TaS_fast pulled {arm_sas_fast}, Sticky_TaS pulled {arm_sas}"
        )
        assert arm_sas_old == arm_sas, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"Sticky_TaS_old pulled {arm_sas_old}, Sticky_TaS pulled {arm_sas}"
        )
        assert agent_sas.stop == agent_sas_fast.stop, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"stop flags of Sticky_TaS and Sticky_TaS_fast differ"
        )
        assert agent_sas_old.stop == agent_sas_fast.stop, (
            f"exp {exp_id}, round {agent_sas.t} inconsistent: "
            f"stop flags of Sticky_TaS_old and Sticky_TaS_fast differ"
        )

        # All agents observe the same reward, so their internal states stay comparable.
        reward = env.response(arm_sas_fast)

        # The value returned by the last observe() call is the agent's output arm.
        output_arm = agent_sas.observe(reward)
        output_arm_fast = agent_sas_fast.observe(reward)
        output_arm_old = agent_sas_old.observe(reward)

    assert output_arm == output_arm_fast, (
        f"exp {exp_id}, output different arms: "
        f"Sticky_TaS {output_arm}, Sticky_TaS_fast {output_arm_fast}"
    )
    assert output_arm_old == output_arm_fast, (
        f"exp {exp_id}, output different arms: "
        f"Sticky_TaS_old {output_arm_old}, Sticky_TaS_fast {output_arm_fast}"
    )

print("Didn't detect inconsistent action")

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [03:36<00:00,  2.16s/it]

Didn't detect inconsistent action





## Compare the execution speed of these three implementations

In the following, we record the execution time of each implementation. Instead of passing $\delta$ to each algorithm, we directly pass $\log\frac{1}{\delta}$. The difference in running speed between these implementations only becomes significant when $K$ is large and $\delta$ is very small (around $\delta=\exp(-100)$ or below). The cell below uses $K=100$ and $\log\frac{1}{\delta}=100$.

In [6]:
from env import Environment_Gaussian
from tqdm import tqdm
from time import time

# Benchmark: for each implementation, run n_exp experiments on the same
# instance. Record the stopping round, whether the output arm is correct,
# and the elapsed time of the interaction loop.
K = 100
xi = 0.5
Delta = 4
rlist = np.zeros(K)
rlist[-1] = xi + Delta  # exactly one arm has a mean above the threshold xi

# delta is still passed to the constructors, but it has no effect here
# (original note: "useless here"). The algorithms receive log(1/delta)
# directly through log1_over_delta.
delta = 0.0001
log1_over_delta = 100  # i.e. delta = exp(-100)
n_exp = 100
shuffle_arms = False  # set True to randomly permute the arms in each experiment

result_dict = dict()          # algorithm name -> stopping rounds, one per experiment
execution_time_dict = dict()  # algorithm name -> elapsed seconds, one per experiment
for alg_class in [Sticky_TaS_fast, Sticky_TaS, Sticky_TaS_old]:
    algname = alg_class.__name__
    stop_time_ = np.zeros(n_exp)
    output_arm_ = list()  # output arm of each experiment (None if none was returned)
    correctness_ = np.ones(n_exp)
    execution_time_ = np.zeros(n_exp)
    for exp_id in tqdm(range(n_exp)):
        # Reversing moves the single above-threshold arm to index 0 (arm 1).
        rlist_temp = rlist[::-1].copy()
        if shuffle_arms:
            np.random.seed(exp_id)
            np.random.shuffle(rlist_temp)
        # Arms are reported 1-based, hence the + 1.
        answer_set = list(np.where(rlist_temp > xi)[0] + 1)

        env = Environment_Gaussian(rlist=rlist_temp, K=K, random_seed=exp_id)
        agent = alg_class(K=K, delta=delta, xi=xi, logC=1, log1_over_delta=log1_over_delta)

        # Reset every experiment so a stale output never reaches the correctness check.
        output_arm = None
        # perf_counter is monotonic and high-resolution, unlike wall-clock time().
        time_start = perf_counter()
        while not agent.stop:
            arm = agent.action()
            reward = env.response(arm)
            output_arm = agent.observe(reward)
            if output_arm is not None:
                break
        execution_time_[exp_id] = perf_counter() - time_start

        output_arm_.append(output_arm)
        stop_time_[exp_id] = agent.t
        if output_arm not in answer_set:
            correctness_[exp_id] = 0
    mean_stop_time = np.mean(stop_time_)
    mean_success = np.mean(correctness_)
    mean_execution_time = np.mean(execution_time_)

    result_dict[algname] = stop_time_
    execution_time_dict[algname] = execution_time_
    print(f"For algorithm {algname}, ")
    print(f"mean stop time is {mean_stop_time}")
    print(f"correctness rate is {mean_success}")
    print(f"execution time is {mean_execution_time}")

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.47it/s]


For algorithm Sticky_TaS_fast, 
mean stop time is 601.14
correctness rate is 1.0
execution time is 0.07271615743637085


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.80it/s]


For algorithm Sticky_TaS, 
mean stop time is 601.14
correctness rate is 1.0
execution time is 0.1690439486503601


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:52<00:00,  1.92it/s]

For algorithm Sticky_TaS_old, 
mean stop time is 601.14
correctness rate is 1.0
execution time is 0.5171171426773071



