In [None]:
import numpy as np
import scipy.stats as stats
from scipy.special import gamma
import os
import time
import plotly.graph_objects as go
import plotly.colors
from numba import njit, prange
from pathlib import Path


In [None]:
# Parameters
NUM_ALGO = 2             # Algorithms compared per simulation (normal TS, t TS)
MAX_T = 10000000         # Number of precomputed square roots (indices 0..MAX_T-1)
MAX_DF = 10000           # Largest degrees of freedom stored in the t table
NUM_QUANTILES = 99999    # Number of tabulated quantile levels
LIMIT = 100000000        # N * T // LIMIT decides how many batches a run uses

# Locations of the cached lookup tables
sqrts_path = f'tables/first_{MAX_T}_sqrts.npy'
n_quantiles_path = f'tables/n_quantiles_q{NUM_QUANTILES}.npy'
t_quantiles_path = f'tables/t_quantiles_df{MAX_DF}_q{NUM_QUANTILES}.npy'

# Ensure both output directories exist. The lookup tables are written to
# 'tables', so that directory must be created too or np.save fails on a
# fresh checkout.
os.makedirs("data", exist_ok=True)
os.makedirs("tables", exist_ok=True)

# Computes the square-root table only when it is not cached yet
if not Path(sqrts_path).exists():
    np.save(sqrts_path, np.sqrt(np.arange(MAX_T)))

# Computes whichever quantile table is missing. The two tables are
# independent, so a missing normal table does not force the (much larger)
# t table to be rebuilt.
n_table_missing = not Path(n_quantiles_path).exists()
t_table_missing = not Path(t_quantiles_path).exists()

if n_table_missing or t_table_missing:

    # Possible quantile levels for Thompson sampling
    quantiles = np.linspace(0.00001, 0.99999, NUM_QUANTILES).round(5)

    if n_table_missing:
        np.save(n_quantiles_path, stats.norm().ppf(quantiles))

    if t_table_missing:
        # Row i holds the quantiles of a t distribution with i + 1 degrees
        # of freedom; beyond MAX_DF the simulations fall back to the normal
        # table.
        t_table = np.array(
            [stats.t(df).ppf(quantiles) for df in range(1, MAX_DF + 1)],
            dtype=np.float32)
        np.save(t_quantiles_path, t_table)

# Loads quantile tables and square roots
n_quantiles = np.load(n_quantiles_path)
t_quantiles = np.load(t_quantiles_path)
sqrts = np.load(sqrts_path)

In [None]:
# A numba implementation for np.random.choice using the CDF of probabilities
@njit
def weighted_choice(probabilities, size):
    """Samples integer indices according to a probability vector.

    Inverse-CDF sampling: uniform draws are located in the cumulative sum of
    the probabilities with a binary search.

    Args:
        probabilities (np.array): An array of probabilities corresponding to the
         probability of choosing that integer index. Should sum to 1.
        size (int): The number of samples to choose from the probability vector

    Returns:
        np.array: A vector of length (size) of integers in
        [0, len(probabilities) - 1] chosen with the given probabilities
    """
    cumulative_sum = np.cumsum(probabilities)
    choices = np.searchsorted(cumulative_sum, np.random.rand(size))

    # Floating-point rounding can leave the last cumulative value slightly
    # below 1. A uniform draw above it would then map to len(probabilities),
    # which is not a valid outcome, so such draws are clamped to the last index.
    last_index = probabilities.shape[0] - 1
    for i in range(size):
        if choices[i] > last_index:
            choices[i] = last_index

    return choices

In [None]:
# Draws a reward distribution per arm and pre-samples every arm's rewards
@njit
def compute_rewards(k, R, T):
    """Builds a random bandit instance and pre-samples its rewards.

    Each arm gets a categorical distribution over the rewards 0..R-1, drawn
    from a Dirichlet distribution whose concentration parameters are all 1.

    Args:
        k (int): The number of bandit arms.
        R (int): The range of reward.
        T (int): The number of rounds in the game.

    Returns:
        numpy vector: deltas: A length k vector containing the gap between the
        highest arm mean and each arm's mean.
        numpy matrix: arm_rewards: A k x T int32 matrix; row a holds T
        pre-sampled rewards for arm a.
        numpy vector: arm_true_means: A length k vector with the expected
        reward of each arm.
    """

    # One categorical reward distribution per arm
    reward_probs = np.random.dirichlet(np.ones(R), size=k)

    # Expected reward of an arm: sum over r of r * P(reward = r)
    reward_values = np.arange(R, dtype=np.float64)
    arm_true_means = np.dot(reward_probs, reward_values)

    # Gap of every arm to the best arm
    best_mean = arm_true_means.max()
    deltas = best_mean - arm_true_means

    # Sample all T rewards of each arm up front
    arm_rewards = np.empty((k, T), dtype=np.int32)
    for arm in range(k):
        arm_rewards[arm, :] = weighted_choice(reward_probs[arm], T)

    return deltas, arm_rewards, arm_true_means


In [None]:
@njit
def run_single_simulation(k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts):
    """Runs a single simulation of a multi-arm bandit with NUM_ALGO different
    algorithms.

    Algorithm 0 is Thompson sampling with a normal posterior, algorithm 1 uses
    a t posterior. Both share the same pre-sampled rewards and the same
    per-round quantile indices. Each arm is pulled twice (rounds 0..2k-1)
    before sampling starts.

    Args:
        k (int): The number of arms.
        R (int): The range of reward.
        T (int): The number of rounds.
        n_quantiles (np.array): A vector of quantiles from the normal 
        distribution.
        t_quantiles (np.array): A matrix of quantiles from the t distribution. 
        Size MAX_DF x NUM_QUANTILES.
        sqrts (np.array): A vector of square roots of the indices.

    Returns:
        np.matrix: A NUM_ALGO x T float32 matrix containing the cumulative 
        regrets for Thompson sampling using a normal and t distribution.
    """
    # Lower bound on the running standard deviation so an arm's sampled
    # value never collapses to an exactly deterministic point
    std_floor = 0.000001

    # Get the arm means and rewards
    deltas, arm_rewards, arm_true_means = compute_rewards(k, R, T)
    
    # Pre-sets quantiles for each round and arm
    arm_quantiles = np.random.randint(0, NUM_QUANTILES, (T, k))

    # Gets the quantiles for the normal distribution of each round and arm
    ts_n_quantiles = np.empty((T, k), dtype=n_quantiles.dtype)
    for t in range(T):
        for a in range(k):
            ts_n_quantiles[t, a] = n_quantiles[arm_quantiles[t, a]]

    # Initializes vectors to hold data during simulation
    ts_t_quantiles = np.empty(k) # T distribution quantiles
    arm_means = np.zeros((NUM_ALGO, k)) # Running arm means
    arm_stds = np.zeros((NUM_ALGO, k)) + std_floor # Running arm standard deviations
    n_t = np.zeros((NUM_ALGO, k), dtype=np.int32) # Number of arm pulls
    R_t = np.zeros((NUM_ALGO, T), dtype=np.float32) # Cumulative regrets
    # Running totals are kept in float64: adding small per-round regrets to a
    # large float32 total loses them once the total grows
    cumulative_regret = np.zeros(NUM_ALGO)
    
    # Starts the simulation
    for t in range(T):

        # Initializes array to hold actions for each algorithm in the round
        actions = np.zeros(NUM_ALGO, dtype=np.int32)
        
        # Pull each arm twice
        if t < 2 * k:
            actions[0] = t % k
            actions[1] = t % k
        
        # Then Thompson Sample
        else:

            # For each arm
            for a in range(k):

                # With n pulls the t posterior has n - 1 degrees of freedom,
                # stored in row n - 2 of the table; the table covers up to
                # MAX_DF + 1 pulls
                if n_t[1, a] <= MAX_DF + 1:
                    ts_t_quantiles[a] = (
                        t_quantiles[n_t[1, a] - 2, arm_quantiles[t, a]])

                # Otherwise use a normal approximation
                else:
                    ts_t_quantiles[a] = n_quantiles[arm_quantiles[t, a]]

            # Scales the normal and t distribution algorithms with mean and std
            adjusted_stds_n = arm_stds[0] / sqrts[n_t[0]]
            adjusted_stds_t = arm_stds[1] / sqrts[n_t[1]]
            
            scaled_n_quantiles = (arm_means[0] 
                                  + ts_n_quantiles[t] * adjusted_stds_n)
            scaled_t_quantiles = (arm_means[1] 
                                  + ts_t_quantiles * adjusted_stds_t)

            # Choose the action with the highest sampled mean
            actions[0] = np.argmax(scaled_n_quantiles)
            actions[1] = np.argmax(scaled_t_quantiles)

        # For each algorithm, update the mean and variance of the pulled arm
        for i in range(NUM_ALGO):

            # The j-th pull of an arm always receives that arm's j-th
            # pre-sampled reward
            action = actions[i]
            reward = arm_rewards[action, n_t[i, action]]
            n_t[i, action] += 1
            pulls = n_t[i, action]
            current_mean = arm_means[i, action]

            # Welford update of the running mean and (population) variance:
            #   var_n = var_{n-1} + (delta1 * delta2 - var_{n-1}) / n
            # The "- var_{n-1}" term is what shrinks the old estimate;
            # without it the variance can only grow.
            current_var = arm_stds[i, action] ** 2
            delta1 = reward - current_mean
            new_mean = current_mean + (delta1 / pulls)
            delta2 = reward - new_mean
            new_var = current_var + (delta1 * delta2 - current_var) / pulls

            # Update new mean and standard deviation (never below the floor)
            arm_means[i, action] = new_mean
            arm_stds[i, action] = max(np.sqrt(max(new_var, 0.0)), std_floor)

            # Update the sequence of cumulative regrets
            cumulative_regret[i] += deltas[action]
            R_t[i, t] = cumulative_regret[i]
    
    return R_t


In [None]:
@njit(parallel=True)
def run_parallel_simulation(N, k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts):
    """Runs N independent bandit simulations inside a numba prange loop.

    Args:
        N (int): The number of simulations.
        k (int): The number of arms.
        R (int): The range of reward.
        T (int): The number of rounds.
        n_quantiles (np.array): A vector of quantiles from the normal 
        distribution.
        t_quantiles (np.array): A matrix of quantiles from the t distribution. 
        Size MAX_DF x NUM_QUANTILES.
        sqrts (np.array): A vector of square roots of the indices.

    Returns:
        np matrix: A NUM_ALGO x T x N float32 matrix; slice [:, :, n] holds
        the cumulative regrets of simulation n.
    """

    all_regrets = np.empty((NUM_ALGO, T, N), dtype=np.float32)

    # Every iteration writes only its own slice along the last axis
    for run in prange(N):
        all_regrets[:, :, run] = run_single_simulation(
            k, R, T, n_quantiles, t_quantiles, sqrts)

    return all_regrets

In [None]:
def run_simulations(N, k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts):
    """Runs N simulations of the multi-arm bandits. Splits simulations into
    batches for better performance with high values of N and T.

    The mean and standard deviation over all N simulations are saved to
    data/k{k}_R{R}_N{N}_T{T}_Means and data/k{k}_R{R}_N{N}_T{T}_Stds.

    Args:
        N (int): The number of simulations.
        k (int): The number of arms.
        R (int): The range of reward.
        T (int): The number of rounds.
        n_quantiles (np.array): A vector of quantiles from the normal 
        distribution.
        t_quantiles (np.array): A matrix of quantiles from the t distribution. 
        Size MAX_DF x NUM_QUANTILES.
        sqrts (np.array): A vector of square roots of the indices.

    Returns:
        np matrix: R_t_mean: A NUM_ALGO x T matrix with the mean cumulative
        regret of each algorithm across the N simulations.
        np matrix: R_t_std: A NUM_ALGO x T matrix with the (population)
        standard deviation of the cumulative regret across the N simulations.
    """

    # Computes number of batches to run simulations in. Never more batches
    # than simulations, otherwise some batches would be empty.
    batches = min(N * T  // LIMIT, N)

    start = time.time()

    if batches > 1:

        # Running mean and sum of squared deviations over all simulations
        R_t_mean = np.zeros((NUM_ALGO, T))
        R_t_m2 = np.zeros((NUM_ALGO, T))  

        # Spread N over the batches so that no simulation is dropped when N
        # is not a multiple of the batch count: the first `extra` batches
        # run one additional simulation.
        base_size, extra = divmod(N, batches)
        done = 0  # Simulations merged into the running statistics so far

        print(f"\nStarting {batches} batches.\n")

        # For each batch
        for i in range(batches):
            
            batch_start = time.time()
            batch_size = base_size + (1 if i < extra else 0)
            
            # Run batch simulation and calculate mean and std for current batch
            R_t_all = run_parallel_simulation(
                        batch_size, k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts)
            batch_mean = R_t_all.mean(axis=2)
            batch_std = R_t_all.std(axis=2)
            batch_variance = batch_std**2

            # Merge the batch into the running statistics (pairwise update
            # for groups of possibly different sizes)
            total = done + batch_size
            delta = batch_mean - R_t_mean
            R_t_mean += delta * (batch_size / total)
            R_t_m2 += (batch_variance * batch_size
                       + delta**2 * (done * batch_size / total))
            done = total

            # Print logging info
            batch_end = time.time()
            elapsed = round(batch_end - batch_start)
            remaining = round(((elapsed * batches)-(elapsed * (i + 1))) / 60, 2)
            print(f"Batch {i + 1} complete, {elapsed} seconds.")
            print(f"Remaining: {remaining} minutes.\n")

        # Calculate the final standard deviation from the accumulated R_t_m2
        R_t_std = np.sqrt(R_t_m2 / N)

    # If one batch is needed, run the simulation as normal
    else:
        R_t_all = run_parallel_simulation(
                                N, k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts)
        R_t_mean = R_t_all.mean(axis=2)
        R_t_std = R_t_all.std(axis=2)

    # Save data
    np.save(f"data/k{k}_R{R}_N{N}_T{T}_Means", R_t_mean)
    np.save(f"data/k{k}_R{R}_N{N}_T{T}_Stds", R_t_std)

    # Print logging info
    end = time.time()
    elapsed = round((end - start) / 60, 2)
    print(f"Completion Time: {elapsed} minutes\n")

    return R_t_mean, R_t_std

In [None]:
@njit
def run_single_simulation_animation(k, R, T, n_quantiles = n_quantiles, 
                                   t_quantiles = t_quantiles, 
                                   sqrts = sqrts):
    """Runs a single simulation of a multi-arm bandit with NUM_ALGO different
    algorithms and records the posterior parameters after every round.

    Same procedure as run_single_simulation, except that the running arm
    means and the scaled standard deviations (std / sqrt(pulls)) are stored
    for each round so the posteriors can be animated.

    Args:
        k (int): The number of arms.
        R (int): The range of reward.
        T (int): The number of rounds.
        n_quantiles (np.array): A vector of quantiles from the normal 
        distribution.
        t_quantiles (np.array): A matrix of quantiles from the t distribution. 
        Size MAX_DF x NUM_QUANTILES.
        sqrts (np.array): A vector of square roots of the indices.

    Returns:
        np.matrix: R_t: A NUM_ALGO x T matrix containing the cumulative 
        regrets for Thompson sampling using a normal and t distribution.
        np.array: all_arm_means: A NUM_ALGO x k x T array of the running arm
        means after each round.
        np.array: all_arm_stds: A NUM_ALGO x k x T array of the running arm
        standard deviations divided by sqrts[number of pulls].
        np.array: deltas: A length k vector of gaps to the best arm mean.
        np.array: arm_true_means: A length k vector of true arm means.
    """
    # Lower bound on the running standard deviation; it also keeps the
    # densities drawn from the recorded values finite
    std_floor = .001

    all_arm_means = np.empty((NUM_ALGO, k, T))
    all_arm_stds = np.empty((NUM_ALGO, k, T))

    # Get the arm means and rewards
    deltas, arm_rewards, arm_true_means = compute_rewards(k, R, T)
    
    # Pre-sets quantiles for each round and arm
    arm_quantiles = np.random.randint(0, NUM_QUANTILES, (T, k))

    # Gets the quantiles for the normal distribution of each round and arm
    ts_n_quantiles = np.empty((T, k), dtype=n_quantiles.dtype)
    for t in range(T):
        for a in range(k):
            ts_n_quantiles[t, a] = n_quantiles[arm_quantiles[t, a]]

    # Initializes vectors to hold data during simulation
    ts_t_quantiles = np.empty(k) # T distribution quantiles
    arm_means = np.zeros((NUM_ALGO, k)) # Running arm means
    arm_stds = np.zeros((NUM_ALGO, k)) + std_floor # Running arm standard deviations
    n_t = np.zeros((NUM_ALGO, k), dtype=np.int32) # Number of arm pulls
    R_t = np.zeros((NUM_ALGO, T), dtype=np.float32) # Cumulative regrets
    # Running totals are kept in float64: adding small per-round regrets to a
    # large float32 total loses them once the total grows
    cumulative_regret = np.zeros(NUM_ALGO)
    
    # Starts the simulation
    for t in range(T):

        # Initializes array to hold actions for each algorithm in the round
        actions = np.zeros(NUM_ALGO, dtype=np.int32)
        
        # Pull each arm twice
        if t < 2 * k:
            actions[0] = t % k
            actions[1] = t % k
        
        # Then Thompson Sample
        else:

            # For each arm
            for a in range(k):

                # With n pulls the t posterior has n - 1 degrees of freedom,
                # stored in row n - 2 of the table; the table covers up to
                # MAX_DF + 1 pulls
                if n_t[1, a] <= MAX_DF + 1:
                    ts_t_quantiles[a] = (
                        t_quantiles[n_t[1, a] - 2, arm_quantiles[t, a]])

                # Otherwise use a normal approximation
                else:
                    ts_t_quantiles[a] = n_quantiles[arm_quantiles[t, a]]

            # Scales the normal and t distribution algorithms with mean and std
            adjusted_stds_n = arm_stds[0] / sqrts[n_t[0]]
            adjusted_stds_t = arm_stds[1] / sqrts[n_t[1]]
            scaled_n_quantiles = (arm_means[0] 
                                  + ts_n_quantiles[t] * adjusted_stds_n)
            scaled_t_quantiles = (arm_means[1] 
                                  + ts_t_quantiles * adjusted_stds_t)

            # Choose the action with the highest sampled mean
            actions[0] = np.argmax(scaled_n_quantiles)
            actions[1] = np.argmax(scaled_t_quantiles)

        # For each algorithm, update the mean and variance of the pulled arm
        for i in range(NUM_ALGO):

            # The j-th pull of an arm always receives that arm's j-th
            # pre-sampled reward
            action = actions[i]
            reward = arm_rewards[action, n_t[i, action]]
            n_t[i, action] += 1
            pulls = n_t[i, action]
            current_mean = arm_means[i, action]

            # Welford update of the running mean and (population) variance:
            #   var_n = var_{n-1} + (delta1 * delta2 - var_{n-1}) / n
            # The "- var_{n-1}" term is what shrinks the old estimate;
            # without it the variance can only grow.
            current_var = arm_stds[i, action] ** 2
            delta1 = reward - current_mean
            new_mean = current_mean + (delta1 / pulls)
            delta2 = reward - new_mean
            new_var = current_var + (delta1 * delta2 - current_var) / pulls

            # Update new mean and standard deviation (never below the floor)
            arm_means[i, action] = new_mean
            arm_stds[i, action] = max(np.sqrt(max(new_var, 0.0)), std_floor)

            # Record the posterior parameters of every arm after this round.
            # NOTE(review): for an arm with zero pulls this divides by
            # sqrts[0]; with the default table that entry is 0, so the
            # recorded value is presumably inf until the arm is pulled.
            all_arm_means[i,:,t] = arm_means[i,:]
            all_arm_stds[i,:,t] = arm_stds[i,:] / sqrts[n_t[i, :]]

            # Update the sequence of cumulative regrets
            cumulative_regret[i] += deltas[action]
            R_t[i, t] = cumulative_regret[i]

    return R_t, all_arm_means, all_arm_stds, deltas, arm_true_means


### Estimating Mean and Variance of Regret by Monte Carlo 

In [None]:
# Monte Carlo sweep over every (number of arms, reward range) combination.
# run_simulations saves the mean and std arrays of each combination in the
# data folder.
# Can take several hours to run

N = 10_000
T = 1_000_000
ks = list(range(3, 11))
Rs = [10, 100, 1_000, 10_000]

for k in ks:
    for R in Rs:
        run_simulations(N, k, R, T)

In [None]:
# Monte Carlo estimate for one configuration of arms and reward range.
# When the run is split into batches, time-remaining info is printed.

N, T = 1000, 10000
k, R = 3, 10

R_t_mean, R_t_std = run_simulations(N, k, R, T)

In [None]:
# Mean cumulative regret of both algorithms, one line per algorithm
x_vals = np.arange(1, T + 1)
fig_mean = go.Figure()

for label, series in (('Normal Distribution', R_t_mean[0]),
                      ('T Distribution', R_t_mean[1])):
    fig_mean.add_trace(
        go.Scatter(x=x_vals, y=series, mode='lines', name=label))

# Legend sits in the lower right corner on a semi-transparent background
legend_style = dict(x=0.72, y=0.07, bgcolor="rgba(255, 255, 255, 0.25)")

fig_mean.update_layout(
    title=f"Mean Regret Across {N} Trials",
    title_x=0.5,
    xaxis_title="Round",
    yaxis_title="Mean Regret",
    width=800,
    height=400,
    template="plotly",
    legend=legend_style,
)
fig_mean.show()

In [None]:
# Plot the spread of the regret across trials.
# The plotted arrays are standard deviations (R_t_std), so the title and the
# y axis are labelled as standard deviation rather than variance.
fig_var = go.Figure()
x_vals = np.arange(1,T+1)

fig_var.add_trace(go.Scatter(
    x=x_vals, y=R_t_std[0], mode = 'lines', name = 'Normal Distribution'))

fig_var.add_trace(go.Scatter(
    x=x_vals, y=R_t_std[1], mode = 'lines', name = 'T Distribution'))

fig_var.update_layout(
    title=f"Standard Deviation of Regret across {N} Trials",
    title_x = 0.5,
    xaxis_title="Round",
    yaxis_title="Regret Standard Deviation",
    width=800,
    height=400,
    template="plotly",
    legend=dict(
        x=0.72,
        y=0.07,
        bgcolor="rgba(255, 255, 255, 0.25)" 
    )
)
fig_var.show()

In [None]:
# Run a single simulation
# (N is not used by this cell; it is kept so the notebook's global state is
# the same as before)
N = 1000
T = 10000
k = 3
R = 10

R_t = run_single_simulation(k, R, T)

# Plot the cumulative regret of this one run for both algorithms
fig_mean = go.Figure()
x_vals = np.arange(1,T+1)

fig_mean.add_trace(go.Scatter(x=x_vals, y=R_t[0], mode = 'lines', name = 'Normal Distribution',showlegend = True))
fig_mean.add_trace(go.Scatter(x=x_vals, y=R_t[1], mode = 'lines', name = 'T Distribution', showlegend = True))

# The figure shows both posteriors and a single run, so the title and y axis
# describe cumulative regret rather than a normal-only mean
fig_mean.update_layout(
    title="Cumulative Regret of Thompson Sampling (Single Run)",
    title_x = 0.5,
    xaxis_title="Round",
    yaxis_title="Cumulative Regret",
    width=800,
    height=400,
    template="plotly",
    legend=dict(
        x=0.72,
        y=0.07,
        bgcolor="rgba(255, 255, 255, 0.25)" 
    )
)

fig_mean.show()


##### Get Data for Animation

In [None]:
# Runs and displays the results of the single trial
T = 10000
k = 3
R = 10

# Runs a single simulation to be animated.
# means / stds hold the per-round posterior parameters (NUM_ALGO x k x T);
# arm_means receives the true mean reward of each arm.
R_t, means, stds, deltas, arm_means = run_single_simulation_animation(k, R, T)

# Plot the cumulative regret of this one trial for both algorithms
fig_mean = go.Figure()
x_vals = np.arange(1,T+1)

fig_mean.add_trace(go.Scatter(
    x=x_vals, y=R_t[0], mode = 'lines', name = 'Normal Distribution'))

fig_mean.add_trace(go.Scatter(
    x=x_vals, y=R_t[1], mode = 'lines', name = 'T Distribution'))

# A single trial is plotted, so the labels say cumulative regret, not mean
fig_mean.update_layout(
    title="Cumulative Regret of Thompson Sampling Algorithms (Single Run)",
    title_x = 0.5,
    xaxis_title="Round",
    yaxis_title="Cumulative Regret",
    width=800,
    height=400,
    template="plotly",
    legend=dict(
        x=0.72,
        y=0.07,
        bgcolor="rgba(255, 255, 255, 0.25)" 
    )
)
fig_mean.show()

##### Animate the Normal Distribution

In [None]:
# Animation of the normal posteriors of the arm means (algorithm 0).
# The arm labels and colors are hard-coded for three arms.
n_mus = means[0]
n_stds = stds[0]

x = np.linspace(0, 11, 1000)
colors = plotly.colors.DEFAULT_PLOTLY_COLORS[:3]
arm_labels = [1, 2, 3]

# Placement shared by every "Round" annotation
annotation_style = dict(x=0.5, y=1.1, xref="paper", yref="paper",
                        showarrow=False, font=dict(size=16))

fig = go.Figure()

# One frame per round index in [5, 500)
frames = []
for t in range(5, 500):

    # Dashed vertical lines at the true mean of each arm
    traces = [
        go.Scatter(
            x=[arm_mean, arm_mean],
            y=[0, .25],
            mode='lines',
            line=dict(dash='dash', width=2, color=color),
            name=f'Arm {arm} True Mean',
            showlegend=False,
        )
        for arm, arm_mean, color in zip(arm_labels, arm_means, colors)
    ]

    # Normal density of each arm's posterior at this round
    for arm, color, mu, std in zip(arm_labels, colors,
                                   n_mus[:, t], n_stds[:, t]):
        norm_const = 1 / (std * np.sqrt(2 * np.pi))
        z = (x - mu) / std
        y = norm_const * np.exp(-0.5 * z ** 2)
        traces.append(go.Scatter(
            x=x, y=y, line=dict(color=color), mode='lines',
            name=f'Arm {arm} N({mu:.2f}, {std**2:.2f})'))

    # Annotation showing the (1-based) round of this frame
    annotation = go.layout.Annotation(text=f"Round: {t+1}",
                                      **annotation_style)

    frames.append(go.Frame(
        data=traces, name=str(t),
        layout=go.Layout(annotations=[annotation])))

# The figure initially shows the traces of the first frame
for first_trace in frames[0].data:
    fig.add_trace(first_trace)

# Annotation matching the first frame (t = 5)
initial_annotation = go.layout.Annotation(text="Round: 6",
                                          **annotation_style)

# Play / Stop buttons driving the animation
play_button = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 30, "redraw": True},
                 "fromcurrent": True}])
stop_button = dict(
    label="Stop",
    method="animate",
    args=[[None], {"frame": {"duration": 0, "redraw": False},
                   "mode": "immediate",
                   "transition": {"duration": 0}}])

# Configure the animation settings
fig.update(frames=frames)
fig.update_layout(
    annotations=[initial_annotation],
    title="Arm Mean Distributions - Normal",
    title_x=0.4,
    xaxis_title="Mean Reward",
    yaxis_title="Density",
    legend=dict(
        x=0.75,
        y=0.97,
        bgcolor="rgba(255, 255, 255, 0.25)"),
    updatemenus=[dict(type="buttons", showactive=False,
                      buttons=[play_button, stop_button])]
)

# Show the animated plot
fig.show()

##### Animate the T distribution

In [None]:
# Reconstructs a per-round pull counter for the t algorithm from the recorded
# means: a round counts as a pull of an arm when that arm's running mean
# differs from the previous round. Column 0 starts at zero.
# NOTE(review): a pull whose reward equals the current running mean leaves
# the mean unchanged and is therefore not counted.
mean_changed = means[1][:, 1:] != means[1][:, :-1]
dfs = np.zeros_like(means[1])
dfs[:, 1:] = np.cumsum(mean_changed, axis=1)

# Posterior parameters of the t algorithm (algorithm 1)
n_mus = means[1]
n_stds = stds[1]

# Plot grid, figure and one color per arm (hard-coded for three arms)
x = np.linspace(0, 11, 1000)
colors = plotly.colors.DEFAULT_PLOTLY_COLORS[:3]
fig = go.Figure()

def t_distribution_pdf(x, mu, sigma, df):
    """Density of a location-scale Student t distribution.

    Args:
        x (np.array or float): Points at which to evaluate the density.
        mu (float): Location of the distribution.
        sigma (float): Scale of the distribution.
        df (float): Degrees of freedom.

    Returns:
        np.array or float: The density evaluated at x.
    """
    # Local import so this definition does not depend on an edit to the
    # notebook's import cell
    from scipy.special import gammaln

    # The ratio gamma((df + 1) / 2) / gamma(df / 2) is evaluated in log
    # space: gamma() itself overflows to inf once its argument exceeds
    # roughly 171, which turned the density into NaN for large df.
    log_gamma_ratio = gammaln((df + 1) / 2) - gammaln(df / 2)
    scale_factor = np.exp(log_gamma_ratio) / (np.sqrt(df * np.pi) * sigma)

    # Compute the PDF
    y = scale_factor * (1 + ((x - mu) ** 2) / (df * sigma ** 2)) ** (-(df + 1) / 2)
    return y

# Builds one animation frame per round index in [5, 500) showing the t
# posterior of each arm together with the true arm means
frames = []
for t in range(5, 500):
    
    # Create traces for each distribution at the current time step
    traces = []

    # Dashed vertical lines at the true mean of each arm
    for arm, arm_mean, color in zip([1, 2, 3], arm_means, colors):
        dashed_line = go.Scatter(
        x=[arm_mean, arm_mean],
        y=[0, .25],  # Extend the line over a reasonable y-range (adjust if needed)
        mode='lines',
        line=dict(dash='dash', width=2, color = color),
        name=f'Arm {arm} True Mean',
        showlegend=False
        )
        traces.append(dashed_line)

    # t density of each arm; the legend labels the curve as a t distribution
    # (it previously read "N(...)", which denotes a normal)
    for mu, std, arm, df, color in zip(n_mus[:, t], n_stds[:, t], [1, 2, 3], dfs[:,t], colors):
        y = t_distribution_pdf(x,mu,std,df)
        trace = go.Scatter(x=x, y=y,line=dict(color=color), 
                           mode='lines', name=f'Arm {arm} t({mu:.2f}, {std**2:.2f})')
        traces.append(trace)
    
    # Add an annotation to display the current round. Rounds are shown
    # 1-based (t + 1) so the first frame agrees with the initial "Round: 6"
    # annotation below.
    annotation = go.layout.Annotation(
        text=f"Round: {t+1}",
        x=0.5, y=1.1, xref="paper", yref="paper", showarrow=False,
        font=dict(size=16)
    )

    # Append frame with current traces
    frames.append(go.Frame(data=traces, name=str(t), layout=go.Layout(annotations=[annotation])))

# Add traces for the first frame
for trace in frames[0].data:
    fig.add_trace(trace)

# Add the annotation for the first frame (t = 5, displayed as round 6)
initial_annotation = go.layout.Annotation(
    text=f"Round: 6",  # Adjust to match the first round value
    x=0.5, y=1.1, xref="paper", yref="paper", showarrow=False,
    font=dict(size=16)
)

# Configure the animation settings
fig.update(frames=frames)
fig.update_layout(
    annotations=[initial_annotation],
    title="Arm Mean Distributions - t",
    title_x= 0.4,
    xaxis_title="Mean Reward",
    yaxis_title="Density",
    legend=dict(
        x=0.75,
        y=0.97,
        bgcolor="rgba(255, 255, 255, 0.25)"),
    updatemenus=[dict(type="buttons", showactive=False,
                      buttons=[
                          dict(label="Play",
                               method="animate",
                               args=[None, {"frame": {"duration": 30, "redraw": True},
                                            "fromcurrent": True}]),
                          dict(label="Stop",
                               method="animate",
                               args=[[None], {"frame": {"duration": 0, "redraw": False},
                                              "mode": "immediate",
                                              "transition": {"duration": 0}}])
                      ])]
)

# Show the animated plot
fig.show()