In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, Layout, HBox, VBox, Output, IntSlider
from IPython.display import display, clear_output

# Ensure that matplotlib plots are rendered inline in the notebook
%matplotlib inline

class UCB1Visualizer:
    """Interactive notebook widget that simulates UCB1 on Bernoulli arms.

    The user enters the number of arms, their success probabilities and the
    number of rounds. Pressing "Run Simulation" precomputes the whole run
    (with a fixed random seed, so identical inputs replay identically), after
    which the buttons and the slider step through the rounds. For the
    selected round the widget shows the empirical mean of every arm with its
    confidence radius, and the cumulative (expected) regret so far.

    Constructing an instance displays the UI immediately.
    """

    # Seed used for every simulation run.
    _SEED = 1234

    def __init__(self):
        """Build the widgets, wire the event handlers and display the UI."""
        # Input widgets
        self.num_arms_input = widgets.Text(
            value='3',
            description='Number of Arms:',
            layout=Layout(width='200px')
        )

        self.means_input = widgets.Text(
            value='0.2,0.4,0.8',
            description='Means of Arms:',
            layout=Layout(width='400px')
        )

        self.num_rounds_input = widgets.Text(
            value='1000',
            description='Number of Rounds:',
            layout=Layout(width='200px')
        )

        self.run_button = widgets.Button(
            description='Run Simulation',
            button_style='success'
        )

        # Navigation widgets (disabled until explicitly enabled)
        self.prev_button = widgets.Button(
            description='Previous Round',
            disabled=True
        )

        self.next_button = widgets.Button(
            description='Next Round',
            disabled=True
        )

        self.round_info = widgets.HTML(
            value="<b>Round:</b> 0"
        )

        # Slider for jumping directly to a round
        self.round_slider = widgets.IntSlider(
            value=0,
            min=0,
            max=100,
            step=1,
            description='Select Round:',
            disabled=True,
            continuous_update=False,
            layout=Layout(width='600px')
        )

        # Text panels describing the displayed round
        self.arm_reward_info = widgets.HTML(
            value="<b>Arm Pulled:</b> N/A<br><b>Reward Observed:</b> N/A"
        )

        self.counts_info = widgets.HTML(
            value="<b>Arm Counts:</b> N/A"
        )

        self.output = Output()

        # Layout setup
        input_box = VBox([
            HBox([self.num_arms_input, self.means_input, self.num_rounds_input]),
            self.run_button
        ])

        nav_box = HBox([
            self.prev_button,
            self.next_button,
            self.round_info
        ])

        slider_box = VBox([
            self.round_slider
        ])

        info_box = VBox([
            self.arm_reward_info,
            self.counts_info
        ])

        # Combine all UI components
        self.ui = VBox([input_box, nav_box, slider_box, self.output, info_box])

        # True while update_slider() is writing to the slider, so that the
        # change notifications caused by that write are ignored.
        self._syncing_slider = False

        # Event handlers
        self.run_button.on_click(self.run_simulation)
        self.prev_button.on_click(self.prev_round)
        self.next_button.on_click(self.next_round)
        self.round_slider.observe(self.slider_changed, names='value')

        # Simulation state (filled in by run_simulation)
        self.num_arms = 0
        self.means = []
        self.num_rounds = 0
        self.simulation_done = False
        self.current_round = 0
        self.selections = []          # arm index pulled in each round
        self.rewards = []             # reward observed in each round
        self.counts = []              # final pull count per arm
        self.values = []              # final empirical mean per arm
        self.cumulative_rewards = []  # running sum of rewards per round
        self.conf_radii = []          # per round: confidence radius of each arm
        self.cumulative_regret = []   # running expected regret per round
        self.best_mean = 0.0

        display(self.ui)

    @staticmethod
    def _parse_inputs(num_arms_text, means_text, num_rounds_text):
        """Parse and validate the three text inputs.

        Returns:
            A ``(num_arms, means, num_rounds)`` tuple.

        Raises:
            ValueError: if a field is not numeric, the number of means does
                not match the number of arms, a mean lies outside [0, 1]
                (NaN included), or there are fewer rounds than arms.
        """
        num_arms = int(num_arms_text)
        means = [float(part.strip()) for part in means_text.split(',')]
        num_rounds = int(num_rounds_text)

        if len(means) != num_arms:
            raise ValueError("Number of means must match the number of arms.")

        # Written as a positive range test so that NaN is rejected as well.
        if not all(0 <= m <= 1 for m in means):
            raise ValueError("Means should be between 0 and 1.")

        if num_rounds < num_arms:
            raise ValueError("Number of rounds should be at least equal to the number of arms.")

        return num_arms, means, num_rounds

    @staticmethod
    def _confidence_radius(t, count):
        """Return sqrt(2 ln t / count), or infinity for an arm never pulled."""
        if count == 0:
            return float('inf')
        return float(np.sqrt((2 * np.log(t)) / count))

    def _simulate(self):
        """Precompute every round of the run and store the per-round traces.

        Reads ``num_arms``, ``means``, ``num_rounds`` and ``rng``; overwrites
        ``best_mean``, ``selections``, ``rewards``, ``counts``, ``values``,
        ``cumulative_rewards``, ``conf_radii`` and ``cumulative_regret``.
        """
        # Best mean is the reference for the regret calculation
        self.best_mean = max(self.means)

        self.selections = []
        self.rewards = []
        self.counts = [0] * self.num_arms
        self.values = [0.0] * self.num_arms
        self.cumulative_rewards = []
        self.conf_radii = []
        self.cumulative_regret = []
        total_reward = 0.0
        total_mean = 0.0

        for t in range(1, self.num_rounds + 1):
            if t <= self.num_arms:
                # Initial sweep: pull each arm once, in index order
                arm = t - 1
            else:
                # Pull the arm with the largest upper confidence bound
                ucb_values = [
                    value + self._confidence_radius(t, count)
                    for value, count in zip(self.values, self.counts)
                ]
                arm = int(np.argmax(ucb_values))

            # Draw a Bernoulli reward for the chosen arm
            reward = int(self.rng.binomial(1, self.means[arm]))

            # Incremental update of the chosen arm's empirical mean
            self.counts[arm] += 1
            self.values[arm] += (reward - self.values[arm]) / self.counts[arm]

            self.selections.append(arm)
            self.rewards.append(reward)
            total_reward += reward
            self.cumulative_rewards.append(total_reward)

            # Confidence radius of every arm after this round. Arms not yet
            # reached by the initial sweep get an infinite radius.
            self.conf_radii.append(
                [self._confidence_radius(t, count) for count in self.counts]
            )

            # Expected regret: best arm's expected total minus the expected
            # total of the arms actually pulled.
            expected_best = self.best_mean * t
            total_mean += self.means[arm]
            self.cumulative_regret.append(expected_best - total_mean)

    def _stats_up_to(self, upto):
        """Return ``(counts, means)`` per arm over the first ``upto`` rounds."""
        counts = [0] * self.num_arms
        values = [0.0] * self.num_arms
        for arm, reward in zip(self.selections[:upto], self.rewards[:upto]):
            counts[arm] += 1
            values[arm] += (reward - values[arm]) / counts[arm]
        return counts, values

    def run_simulation(self, b):
        """Button handler: validate the inputs, run UCB1 and show round 0.

        On invalid input an error message is printed and the state of any
        previous run is left untouched, so navigation keeps working.
        """
        with self.output:
            clear_output(wait=True)
            try:
                num_arms, means, num_rounds = self._parse_inputs(
                    self.num_arms_input.value,
                    self.means_input.value,
                    self.num_rounds_input.value,
                )
            except ValueError as e:
                print(f"Input Error: {e}")
                return

            # Commit the inputs only once all of them are known to be valid.
            self.num_arms = num_arms
            self.means = means
            self.num_rounds = num_rounds
            self.rng = np.random.default_rng(self._SEED)

            self._simulate()

            # Start the navigation at round 0
            self.current_round = 0
            self.update_buttons()
            self.update_slider()
            self.update_plot()
            print("Simulation completed!")

    def prev_round(self, b):
        """Button handler: step back one round, if not already at round 0."""
        if self.current_round > 0:
            self.current_round -= 1
            self.update_buttons()
            self.update_slider()
            self.update_plot()

    def next_round(self, b):
        """Button handler: step forward one round, if not at the last one."""
        if self.current_round < self.num_rounds:
            self.current_round += 1
            self.update_buttons()
            self.update_slider()
            self.update_plot()

    def slider_changed(self, change):
        """Slider observer: jump to the round selected on the slider."""
        if self._syncing_slider:
            # The change was caused by update_slider() itself (for example the
            # value being clamped when the maximum shrinks); it must not
            # override current_round.
            return
        new_round = change['new']
        if new_round != self.current_round:
            self.current_round = new_round
            self.update_buttons()
            self.update_plot()

    def update_buttons(self):
        """Enable/disable the navigation buttons and refresh the round label."""
        self.prev_button.disabled = self.current_round == 0
        self.next_button.disabled = self.current_round == self.num_rounds
        self.round_info.value = f"<b>Round:</b> {self.current_round}"

    def update_slider(self):
        """Make the slider span rounds 0..num_rounds and show current_round."""
        self._syncing_slider = True
        try:
            self.round_slider.max = self.num_rounds
            self.round_slider.min = 0
            self.round_slider.step = 1
            self.round_slider.value = self.current_round
        finally:
            self._syncing_slider = False
        self.round_slider.disabled = False

    def update_plot(self):
        """Redraw the output area and the text panels for current_round."""
        with self.output:
            clear_output(wait=True)
            if self.current_round == 0:
                print("Run the simulation and navigate through the rounds.")
                # Nothing has been pulled yet: reset both text panels
                self.arm_reward_info.value = "<b>Arm Pulled:</b> N/A<br><b>Reward Observed:</b> N/A"
                self.counts_info.value = "<b>Arm Counts:</b> N/A"
                return

            shown = self.current_round
            counts, avg_rewards = self._stats_up_to(shown)
            current_conf_radii = self.conf_radii[shown - 1]

            # Arm pulled and reward observed in the displayed round
            current_arm = self.selections[shown - 1]
            current_reward = self.rewards[shown - 1]
            self.arm_reward_info.value = f"<b>Round {self.current_round}:</b> Pulled <b>Arm {current_arm}</b>, Observed <b>Reward {current_reward}</b>"

            # Pull counts up to the displayed round
            self.counts_info.value = "<b>Arm Counts:</b><br>" + "".join(
                f"Arm {i}: {count} times<br>" for i, count in enumerate(counts)
            )

            # Two stacked subplots: 2 rows, 1 column
            fig, axs = plt.subplots(2, 1, figsize=(8, 8))

            # Font sizes
            title_font = 20
            label_font = 16
            tick_font = 14
            annotation_font = 12

            # 1. Bar chart of average rewards with confidence radii as error bars
            bar_positions = np.arange(self.num_arms)
            bars = axs[0].bar(bar_positions, avg_rewards, color='skyblue', yerr=current_conf_radii, capsize=5, alpha=0.7)
            axs[0].set_ylim(bottom=0)
            axs[0].set_xlabel('Arms', fontsize=label_font)
            axs[0].set_ylabel('Average Reward', fontsize=label_font)
            axs[0].set_title(f'Average Rewards up to Round {self.current_round}', fontsize=title_font)
            axs[0].set_xticks(bar_positions)
            axs[0].set_xticklabels([f'Arm {i}' for i in range(self.num_arms)], fontsize=tick_font)

            # Annotate each bar with its average reward
            for bar, avg in zip(bars, avg_rewards):
                height = bar.get_height()
                axs[0].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                            f'{avg:.2f}', ha='center', va='bottom', fontsize=annotation_font)

            # Annotate the upper confidence bound at the top of each error bar
            for bar, avg, conf in zip(bars, avg_rewards, current_conf_radii):
                y = avg + conf
                # Keep the label inside the current y-axis range
                y_top = axs[0].get_ylim()[1]
                if y > y_top:
                    y = y_top - 0.05
                axs[0].text(bar.get_x() + bar.get_width()/2., y + 0.02,
                            f'{avg+conf:.2f}', ha='center', va='bottom', color='black', fontsize=annotation_font)

            # 2. Cumulative regret up to the displayed round
            axs[1].plot(range(1, self.current_round + 1), self.cumulative_regret[:self.current_round], color='red', label='Cumulative Regret')
            axs[1].set_xlabel('Round', fontsize=label_font)
            axs[1].set_ylabel('Cumulative Regret', fontsize=label_font)
            axs[1].set_title(f'Cumulative Regret up to Round {self.current_round}', fontsize=title_font)
            axs[1].grid(True)
            axs[1].legend()

            axs[0].tick_params(axis='y', labelsize=tick_font)
            axs[1].tick_params(axis='both', labelsize=tick_font)

            plt.tight_layout()
            plt.show()
            # Release the figure explicitly so repeated redraws do not pile up
            # open figures on backends that keep them after show().
            plt.close(fig)

# Instantiate the visualizer; its constructor displays the widget UI itself.
# The trailing semicolon keeps the notebook from also echoing the object's repr.
UCB1Visualizer();