In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

In [3]:
class MultiArmedBandit:
    """K-armed Bernoulli bandit whose arm means are drawn uniformly from [0, 1).

    Parameters
    ----------
    K : int
        Number of arms.
    rng : numpy.random.Generator, optional
        Random source for both the arm means and the pulls. When omitted, the
        legacy global ``np.random`` state is used (as before), so
        ``np.random.seed(...)`` makes runs reproducible. Pass
        ``np.random.default_rng(seed)`` for an isolated, reproducible stream.

    Attributes
    ----------
    K : int
        Number of arms.
    true_means : numpy.ndarray, shape (K,)
        Success probability of each arm.
    rng : numpy.random.Generator or None
        The random source supplied by the caller (None = global state).
    """

    def __init__(self, K, rng=None):
        self.K = K
        self.rng = rng
        # ``random(K)`` draws from the same stream as the former ``np.random.rand(K)``.
        self.true_means = self._random().random(K)

    def _random(self):
        """Return the active random source: the supplied rng, else global ``np.random``."""
        return np.random if self.rng is None else self.rng

    def pull(self, arm):
        """Pull ``arm`` once; return 1 with probability ``true_means[arm]``, else 0."""
        return self._random().binomial(1, self.true_means[arm])

    def pull_sum(self, arm, n):
        """Return the total reward of ``n`` independent pulls of ``arm``.

        A sum of n Bernoulli(p) draws is distributed Binomial(n, p). One
        binomial draw therefore replaces n calls to :meth:`pull`, with the same
        distribution, and keeps large-K / small-epsilon experiments tractable.
        """
        return int(self._random().binomial(n, self.true_means[arm]))

In [4]:
def naive_algorithm(bandit, epsilon, delta):
    """Naive (epsilon, delta)-PAC best-arm identification.

    Every arm is pulled n = ceil((2 / epsilon**2) * ln(2K / delta)) times, and
    the arm with the highest empirical mean is returned. By Hoeffding's
    inequality and a union bound over the K arms, all empirical means are
    within epsilon/2 of their true means with probability >= 1 - delta. In that
    event the returned arm is epsilon-optimal.

    Parameters
    ----------
    bandit : object
        Must expose ``K`` (number of arms) and ``pull(arm)``. If it also has
        ``pull_sum(arm, n)`` (total reward of n pulls), that method is used
        instead of n separate ``pull`` calls.
    epsilon : float
        Accuracy parameter, must be > 0.
    delta : float
        Failure probability, must be in (0, 1).

    Returns
    -------
    chosen_arm : int
        Index of the arm with the highest empirical mean.
    empirical_mean : float
        Empirical mean of ``chosen_arm``.
    total_samples : int
        Total number of pulls made (K * n).

    Raises
    ------
    ValueError
        If ``epsilon <= 0`` or ``delta`` is not in (0, 1).
    """
    if epsilon <= 0:
        raise ValueError("epsilon must be positive")
    if not 0 < delta < 1:
        raise ValueError("delta must be in (0, 1)")

    K = bandit.K
    # ceil rather than int(): truncation can land one pull below the Hoeffding bound.
    n_samples = math.ceil((2 / epsilon ** 2) * math.log(2 * K / delta))
    pull_sum = getattr(bandit, "pull_sum", None)

    empirical_means = np.zeros(K)
    for arm in range(K):
        if pull_sum is not None:
            total_reward = pull_sum(arm, n_samples)
        else:
            total_reward = sum(bandit.pull(arm) for _ in range(n_samples))
        empirical_means[arm] = total_reward / n_samples

    chosen_arm = int(np.argmax(empirical_means))
    return chosen_arm, empirical_means[chosen_arm], K * n_samples

In [5]:
def median_elimination(bandit, epsilon, delta):
    """Median Elimination for (epsilon, delta)-PAC best-arm identification.

    The algorithm starts from eps_1 = epsilon / 4 and delta_1 = delta / 2. In
    each phase l, every surviving arm is pulled
    ceil((4 / eps_l**2) * ln(3 / delta_l)) times, and the worse half (by
    empirical mean) is discarded. The parameters are then updated as
    eps_{l+1} = 3/4 * eps_l and delta_{l+1} = delta_l / 2, and phases repeat
    until a single arm remains.

    Parameters
    ----------
    bandit : object
        Must expose ``K`` (number of arms) and ``pull(arm)``. If it also has
        ``pull_sum(arm, n)`` (total reward of n pulls), that method is used
        instead of n separate ``pull`` calls.
    epsilon : float
        Accuracy parameter, must be > 0.
    delta : float
        Failure probability, must be in (0, 1).

    Returns
    -------
    best_arm : int
        The last surviving arm.
    empirical_mean : float
        Empirical mean of ``best_arm`` from the final phase. If K == 1, no
        pulls are made and this is the 0.0 placeholder.
    total_samples : int
        Total number of pulls over all phases.
    phase_history : list of int
        ``phase_history[i]`` is the number of arms removed in phase i + 1.

    Raises
    ------
    ValueError
        If ``epsilon <= 0`` or ``delta`` is not in (0, 1).
    """
    if epsilon <= 0:
        raise ValueError("epsilon must be positive")
    if not 0 < delta < 1:
        raise ValueError("delta must be in (0, 1)")

    arms = list(range(bandit.K))
    epsilon_l = epsilon / 4
    delta_l = delta / 2
    pull_sum = getattr(bandit, "pull_sum", None)

    total_samples = 0
    phase_history = []
    empirical_means = np.zeros(bandit.K)

    while len(arms) > 1:
        # ceil rather than int(): truncation can land one pull below the bound.
        n_samples = math.ceil((4 / epsilon_l ** 2) * math.log(3 / delta_l))

        for arm in arms:
            if pull_sum is not None:
                total_reward = pull_sum(arm, n_samples)
            else:
                total_reward = sum(bandit.pull(arm) for _ in range(n_samples))
            empirical_means[arm] = total_reward / n_samples

        total_samples += n_samples * len(arms)

        # Keep the top ceil(n/2) arms by empirical mean. Without ties this is
        # exactly {a : p_hat_a >= median}. A threshold rule, however, keeps
        # every arm tied at the median (e.g. two arms with equal counts), so
        # the loop could stop shrinking and never terminate. Python's sort is
        # stable even with reverse=True, so ties go to the lowest arm index.
        ranked = sorted(arms, key=lambda a: empirical_means[a], reverse=True)
        new_arms = sorted(ranked[: (len(arms) + 1) // 2])

        phase_history.append(len(arms) - len(new_arms))

        arms = new_arms
        epsilon_l *= 0.75
        delta_l *= 0.5

    return arms[0], empirical_means[arms[0]], total_samples, phase_history

In [None]:
# Empirical check of the (epsilon, delta)-PAC guarantee. Over `runs` freshly
# drawn bandits, count how often each algorithm returns an arm whose TRUE mean
# is within epsilon of the best arm's true mean. In expectation, each should
# succeed in at least (1 - delta) * runs of them.
SEED = 0
np.random.seed(SEED)  # bandit means and pulls come from the global NumPy RNG

K = 1000
epsilon = 0.05
delta = 0.05
runs = 100

naive_success = 0
mea_success = 0

for _ in tqdm(range(runs)):
    bandit = MultiArmedBandit(K)
    true_best = np.max(bandit.true_means)

    arm_n, _, _ = naive_algorithm(bandit, epsilon, delta)
    arm_m, _, _, _ = median_elimination(bandit, epsilon, delta)

    if bandit.true_means[arm_n] >= true_best - epsilon:
        naive_success += 1

    if bandit.true_means[arm_m] >= true_best - epsilon:
        mea_success += 1

print("Naive ε-optimal count:", naive_success)
print("MEA ε-optimal count:", mea_success)
print("Target (1 - δ) · runs:", (1 - delta) * runs)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Total pulls made by each algorithm as the number of arms K grows, with
# epsilon and delta held fixed.
epsilon = 0.05
delta = 0.05

# K = 100, 10100, ..., 90100 (range's stop value is exclusive).
arms_list = list(range(100, 100001, 10000))

naive_samples, mea_samples = [], []
for K in tqdm(arms_list):
    bandit = MultiArmedBandit(K)
    naive_samples.append(naive_algorithm(bandit, epsilon, delta)[2])
    mea_samples.append(median_elimination(bandit, epsilon, delta)[2])

fig, ax = plt.subplots(figsize=(8, 6))
for series_label, sample_counts in (("Naive", naive_samples), ("MEA", mea_samples)):
    ax.plot(arms_list, sample_counts, label=series_label)
ax.set_xlabel("Number of Arms (K)")
ax.set_ylabel("Sample Complexity")
ax.set_title("Sample Complexity vs Number of Arms")
ax.legend()
plt.show()

In [None]:
# How far each algorithm's reported empirical mean lands from the true best
# mean on one large bandit. First epsilon is swept (delta fixed), then delta
# (epsilon fixed).
# NOTE: one run per setting, so the curves are noisy. At epsilon = 0.01 the
# naive algorithm needs hundreds of thousands of pulls per arm, which is only
# tractable with batched sampling.
K = 500000
bandit = MultiArmedBandit(K)

epsilon_values = [0.1, 0.05, 0.01]
delta_values = [0.1, 0.05, 0.01]
FIXED_EPSILON = 0.05  # epsilon used while delta is swept
FIXED_DELTA = 0.05    # delta used while epsilon is swept

true_best = np.max(bandit.true_means)


def _empirical_mean_errors(bandit, true_best, settings):
    """Run both algorithms on `bandit` for each (epsilon, delta) pair in `settings`.

    Returns two lists (naive, MEA) containing |reported empirical mean - true_best|.
    Loop variables stay local, so the notebook's `epsilon` / `delta` globals
    are not overwritten.
    """
    naive_errors, mea_errors = [], []
    for eps, dlt in settings:
        _, emp_n, _ = naive_algorithm(bandit, eps, dlt)
        _, emp_m, _, _ = median_elimination(bandit, eps, dlt)
        naive_errors.append(abs(emp_n - true_best))
        mea_errors.append(abs(emp_m - true_best))
    return naive_errors, mea_errors


def _plot_errors(x_values, naive_errors, mea_errors, xlabel, title):
    """Plot both algorithms' errors against the swept parameter on a new figure."""
    fig, ax = plt.subplots()
    ax.plot(x_values, naive_errors, label="Naive")
    ax.plot(x_values, mea_errors, label="MEA")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Empirical Mean Error")
    ax.set_title(title)
    ax.legend()
    plt.show()


# --- Vary epsilon ---
naive_errors_eps, mea_errors_eps = _empirical_mean_errors(
    bandit, true_best, [(eps, FIXED_DELTA) for eps in epsilon_values]
)
_plot_errors(epsilon_values, naive_errors_eps, mea_errors_eps, "Epsilon", "Error vs Epsilon")

# --- Vary delta ---
naive_errors_delta, mea_errors_delta = _empirical_mean_errors(
    bandit, true_best, [(FIXED_EPSILON, dlt) for dlt in delta_values]
)
_plot_errors(delta_values, naive_errors_delta, mea_errors_delta, "Delta", "Error vs Delta")

In [None]:
# Number of arms discarded in each phase of a single Median Elimination run.
K = 1000
epsilon = 0.05
delta = 0.05

bandit = MultiArmedBandit(K)
phase_history = median_elimination(bandit, epsilon, delta)[3]

phase_numbers = range(1, len(phase_history) + 1)
fig, ax = plt.subplots()
ax.plot(phase_numbers, phase_history)
ax.set_xlabel("Phase")
ax.set_ylabel("Number of Arms Removed")
ax.set_title("MEA Phase Elimination History")
plt.show()