<a href="https://colab.research.google.com/github/noobylub/APCSA/blob/master/Copy_of_MAB_RL_challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import deque
import random

In [None]:
# @title
# Scaffolding
class MultiArmedBandit:
    """
    Multi-armed bandit with Gaussian payout distributions.

    Each arm has an independent Gaussian distribution with specified mean and variance.
    The bandit tracks the history of arm pulls and can provide statistics for learning algorithms.
    """

    def __init__(
        self,
        n_arms: int = 10,
        means: Optional[Union[List[float], np.ndarray, torch.Tensor]] = None,
        variances: Optional[Union[List[float], np.ndarray, torch.Tensor]] = None,
        random_seed: Optional[int] = None,
    ):
        """
        Initialize the multi-armed bandit.

        Args:
            n_arms: Number of arms in the bandit (default: 10)
            means: Mean payout for each arm. If None, defaults to 0 for all arms.
                  Can be list, numpy array, or torch tensor.
            variances: Variance of payout for each arm. If None, defaults to 1 for all arms.
                      Can be list, numpy array, or torch tensor.
            random_seed: Random seed for reproducible results (default: None).
                        When given, it seeds the *global* torch and numpy generators.

        Raises:
            ValueError: If means/variances are not 1-D vectors of length n_arms,
                or if any variance is not strictly positive.
        """
        self.n_arms = n_arms

        # Set random seed if provided
        if random_seed is not None:
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)

        self.means = self._as_arm_vector(means, "means", fill_value=0.0)
        self.variances = self._as_arm_vector(variances, "variances", fill_value=1.0)
        if torch.any(self.variances <= 0):
            raise ValueError("All variances must be positive")

        # Standard deviations are what torch.normal needs; compute them once.
        self.std_devs = torch.sqrt(self.variances)

        # Initialize tracking variables
        self.reset_history()

    def _as_arm_vector(self, values, name: str, fill_value: float) -> torch.Tensor:
        """Return ``values`` as an owned float32 vector of length ``n_arms``.

        ``None`` yields a vector filled with ``fill_value``. Tensors are copied
        with ``detach().clone()`` rather than ``torch.tensor(...)`` so that
        passing a tensor does not trigger a copy-construct warning and the
        bandit never aliases the caller's storage.
        """
        if values is None:
            return torch.full((self.n_arms,), fill_value)

        vector = torch.as_tensor(values, dtype=torch.float32).detach().clone()
        if vector.dim() != 1:
            raise ValueError(
                f"{name} must be one-dimensional, got shape {tuple(vector.shape)}"
            )
        if len(vector) != self.n_arms:
            raise ValueError(
                f"Length of {name} ({len(vector)}) must equal n_arms ({self.n_arms})"
            )
        return vector

    def _check_arm(self, arm: int) -> None:
        """Raise ValueError unless ``arm`` is a valid index in [0, n_arms)."""
        if not (0 <= arm < self.n_arms):
            raise ValueError(f"Arm index {arm} is out of range [0, {self.n_arms - 1}]")

    def reset_history(self):
        """Reset all tracking variables to initial state."""
        self.pull_counts = torch.zeros(self.n_arms, dtype=torch.long)
        self.total_rewards = torch.zeros(self.n_arms)
        self.sum_squared_rewards = torch.zeros(self.n_arms)
        self.pull_history = []  # List of (arm, reward) tuples
        self.total_pulls = 0

    def pull_arm(self, arm: int) -> float:
        """
        Pull a specific arm and get a reward sample.

        Args:
            arm: Index of the arm to pull (0 to n_arms-1)

        Returns:
            Reward sampled from the arm's Gaussian distribution

        Raises:
            ValueError: If arm index is invalid
        """
        self._check_arm(arm)

        # Sample reward from Gaussian distribution
        reward = torch.normal(self.means[arm], self.std_devs[arm]).item()

        # Update tracking variables
        self.pull_counts[arm] += 1
        self.total_rewards[arm] += reward
        self.sum_squared_rewards[arm] += reward**2
        self.pull_history.append((arm, reward))
        self.total_pulls += 1

        return reward

    def get_arm_statistics(self, arm: int) -> dict:
        """
        Get statistics for a specific arm.

        Args:
            arm: Index of the arm

        Returns:
            Dictionary with keys "count", "mean_reward", "std_reward" (sample
            standard deviation, 0.0 with fewer than two pulls), "true_mean"
            and "true_std".

        Raises:
            ValueError: If arm index is invalid
        """
        self._check_arm(arm)

        count = self.pull_counts[arm].item()
        mean_reward = 0.0
        std_reward = 0.0

        if count > 0:
            total_reward = self.total_rewards[arm].item()
            mean_reward = total_reward / count

            if count > 1:
                sum_squared = self.sum_squared_rewards[arm].item()
                variance_sample = (sum_squared - count * mean_reward**2) / (count - 1)
                # Rounding can push the estimate slightly below zero; clamp it.
                std_reward = float(np.sqrt(max(0, variance_sample)))

        return {
            "count": count,
            "mean_reward": mean_reward,
            "std_reward": std_reward,
            "true_mean": self.means[arm].item(),
            "true_std": self.std_devs[arm].item(),
        }

    def get_all_statistics(self) -> dict:
        """
        Get statistics for all arms.

        Returns:
            Dictionary with overall statistics and per-arm statistics
        """
        arm_stats = [self.get_arm_statistics(i) for i in range(self.n_arms)]

        return {
            "total_pulls": self.total_pulls,
            "arms": arm_stats,
            "overall_mean_reward": (
                self.total_rewards.sum().item() / max(1, self.total_pulls)
            ),
        }

    def get_current_state(self) -> torch.Tensor:
        """
        Get current state suitable for neural network input.

        Returns state in the format expected by MultiArmedBanditNet:
        [usage_count_0, avg_payout_0, std_0, usage_count_1, avg_payout_1, std_1, ...]

        Returns:
            Tensor of shape (3 * n_arms,) containing the current state
        """
        state = torch.zeros(3 * self.n_arms)

        for arm in range(self.n_arms):
            stats = self.get_arm_statistics(arm)
            state[3 * arm] = stats["count"]  # Usage count
            state[3 * arm + 1] = stats["mean_reward"]  # Average payout
            state[3 * arm + 2] = stats["std_reward"]  # Standard deviation

        return state

    def get_optimal_arm(self) -> int:
        """
        Get the arm with the highest true mean (optimal arm for this bandit).

        Returns:
            Index of the optimal arm
        """
        return torch.argmax(self.means).item()

    def get_regret(self, arm: int) -> float:
        """
        Calculate the instantaneous regret for pulling a specific arm.

        Regret compares the arm's true mean with the best possible true mean.

        Args:
            arm: Index of the pulled arm

        Returns:
            Regret (difference between optimal and chosen arm's true mean)

        Raises:
            ValueError: If arm index is invalid
        """
        # Without this check a negative index would silently wrap around.
        self._check_arm(arm)
        optimal_mean = self.means[self.get_optimal_arm()]
        chosen_mean = self.means[arm]
        return (optimal_mean - chosen_mean).item()

    def get_cumulative_regret(self) -> float:
        """
        Calculate cumulative regret over all pulls.

        Computed from the per-arm pull counts, so the cost depends on the
        number of arms rather than on the length of the pull history.

        Returns:
            Total regret accumulated over all arm pulls
        """
        if self.total_pulls == 0:
            return 0.0

        gaps = (self.means.max() - self.means).tolist()
        counts = self.pull_counts.tolist()
        return float(sum(count * gap for count, gap in zip(counts, gaps)))

    def simulate_batch_pulls(self, arms: List[int]) -> List[float]:
        """
        Pull multiple arms in sequence and return all rewards.

        Args:
            arms: List of arm indices to pull

        Returns:
            List of rewards from each pull
        """
        return [self.pull_arm(arm) for arm in arms]

    def __str__(self) -> str:
        """String representation of the bandit."""
        return (
            f"MultiArmedBandit(n_arms={self.n_arms}, "
            f"total_pulls={self.total_pulls})\n"
            f"True means: {self.means.tolist()}\n"
            f"True stds: {self.std_devs.tolist()}"
        )

    def __repr__(self) -> str:
        """Detailed representation of the bandit."""
        return self.__str__()

In [None]:
# Neural network
class MultiArmedBanditNet(nn.Module):
    """
    Neural network for multi-armed bandit action selection with parallel arm processing.

    Input: 3 * n_arms features (for each arm: usage count, average payout, standard deviation)
    Output: n_arms + 1 logits (one for each arm + stopping action)
    """

    def __init__(
        self,
        n_arms,
        max_actions,
        dim_arm=64,
        hidden_layers=None,
        dropout_rate=0.1,
        activation="relu",
    ):
        """
        Build the shared per-arm embedder, the per-arm logit head and the stop head.

        Args:
            n_arms: Number of bandit arms.
            max_actions: Scale used to normalise usage counts.
            dim_arm: Width of the per-arm embedding.
            hidden_layers: Hidden widths of the arm-logit MLP (default [64, 32]).
            dropout_rate: Dropout probability used throughout.
            activation: One of "relu", "tanh", "leaky_relu".

        Raises:
            ValueError: If ``activation`` is not supported.
        """
        super().__init__()

        self.n_arms = n_arms
        self.max_actions = max_actions
        self.dim_arm = dim_arm

        if hidden_layers is None:
            hidden_layers = [64, 32]

        activation_factories = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "leaky_relu": lambda: nn.LeakyReLU(negative_slope=0.01),
        }
        if activation not in activation_factories:
            raise ValueError("Unsupported activation function: {}".format(activation))
        make_activation = activation_factories[activation]

        arm_input_dim = 3
        self.arm_embedder = nn.Sequential(
            nn.Linear(arm_input_dim, dim_arm),
            make_activation(),
            nn.BatchNorm1d(dim_arm),
            nn.Dropout(dropout_rate),
            nn.Linear(dim_arm, dim_arm),
            make_activation(),
            nn.BatchNorm1d(dim_arm),
            nn.Dropout(dropout_rate),
        )

        main_layers = []
        prev_dim = dim_arm
        for hidden_dim in hidden_layers:
            main_layers.extend(
                [
                    nn.Linear(prev_dim, hidden_dim),
                    make_activation(),
                    nn.BatchNorm1d(hidden_dim),
                    nn.Dropout(dropout_rate),
                ]
            )
            prev_dim = hidden_dim
        main_layers.append(nn.Linear(prev_dim, 1))
        self.arm_logit_mlp = nn.Sequential(*main_layers)

        context_dim = 3
        self.stop_logit_mlp = nn.Sequential(
            nn.Linear(context_dim, 32),
            make_activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(32, 16),
            make_activation(),
            nn.Dropout(dropout_rate),
            nn.Linear(16, 1),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialise every linear layer and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    # ------------------------------------------------------------------
    # Input handling
    # ------------------------------------------------------------------
    def _count_scale(self):
        """Divisor for usage counts; never below 1 so it cannot divide by zero."""
        return max(1.0, float(self.max_actions))

    def _arm_features(self, x):
        """Validate ``x`` and reshape it to (batch_size, n_arms, 3).

        A single un-batched state of shape (3 * n_arms,) is treated as a batch of one.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        expected_input_dim = 3 * self.n_arms
        if x.shape[-1] != expected_input_dim:
            raise ValueError(
                "Expected input dimension {}, got {}".format(
                    expected_input_dim, x.shape[-1]
                )
            )
        return x.reshape(x.shape[0], self.n_arms, 3)

    def _normalize_counts(self, arm_features):
        """Return a copy of ``arm_features`` with the usage-count column rescaled."""
        normalized = arm_features.clone()
        normalized[:, :, 0] = normalized[:, :, 0] / self._count_scale()
        return normalized

    # ------------------------------------------------------------------
    # Sub-networks
    # ------------------------------------------------------------------
    def _embed_arms(self, arm_features):
        """Embed every arm independently: (batch, n_arms, 3) -> (batch, n_arms, dim_arm)."""
        batch_size, n_arms, feature_dim = arm_features.shape
        flattened_arms = arm_features.reshape(-1, feature_dim)
        arm_embeddings = self.arm_embedder(flattened_arms)
        return arm_embeddings.reshape(batch_size, n_arms, self.dim_arm)

    def _compute_arm_logits(self, arm_embeddings):
        """Map arm embeddings to one logit per arm: (batch, n_arms, dim_arm) -> (batch, n_arms)."""
        batch_size, n_arms, dim_arm = arm_embeddings.shape
        flattened_embeddings = arm_embeddings.reshape(-1, dim_arm)
        arm_logits_flat = self.arm_logit_mlp(flattened_embeddings)
        return arm_logits_flat.reshape(batch_size, n_arms)

    def _compute_stop_logit(self, arm_features):
        """Compute the stop logit from three summary features of the raw state.

        The context is: total pulls (scaled by max_actions), the pull-weighted
        average payout, and the fraction of arms pulled at least once.
        Returns a tensor of shape (batch, 1).
        """
        usage_counts = arm_features[:, :, 0]
        avg_payouts = arm_features[:, :, 1]

        total_usage = torch.sum(usage_counts, dim=1, keepdim=True)
        total_actions = total_usage / self._count_scale()
        # Clamp the denominator so an untouched bandit yields 0 rather than NaN.
        weighted_avg_reward = torch.sum(
            usage_counts * avg_payouts, dim=1, keepdim=True
        ) / torch.clamp(total_usage, min=1e-8)
        arms_used_ratio = torch.sum(
            (usage_counts > 0).float(), dim=1, keepdim=True
        ) / float(self.n_arms)

        context = torch.cat(
            [total_actions, weighted_avg_reward, arms_used_ratio], dim=1
        )
        return self.stop_logit_mlp(context)

    # ------------------------------------------------------------------
    # Public inference API
    # ------------------------------------------------------------------
    def forward(self, x):
        """Return logits of shape (batch, n_arms + 1); the last column is the stop action.

        Raises:
            ValueError: If the last dimension of ``x`` is not 3 * n_arms.
        """
        arm_logits, stop_logit = self.get_arm_logits_separate(x)
        return torch.cat([arm_logits, stop_logit], dim=1)

    def get_action_probabilities(self, x, temperature=1.0):
        """Softmax over the logits, sharpened or flattened by ``temperature``.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        logits = self.forward(x)
        return F.softmax(logits / float(temperature), dim=-1)

    def select_action(self, x, temperature=1.0, deterministic=False):
        """Pick one action per batch row, greedily or by sampling, without gradients."""
        with torch.no_grad():
            probabilities = self.get_action_probabilities(x, temperature)
            if deterministic:
                return torch.argmax(probabilities, dim=-1)
            # multinomial expects probabilities to sum to 1 across last dimension
            return torch.multinomial(probabilities, num_samples=1).squeeze(-1)

    def get_arm_embeddings(self, x):
        """Return the per-arm embeddings for ``x``: (batch, n_arms, dim_arm)."""
        arm_features = self._arm_features(x)
        return self._embed_arms(self._normalize_counts(arm_features))

    def get_arm_logits_separate(self, x):
        """Return ``(arm_logits, stop_logit)`` with shapes (batch, n_arms) and (batch, 1)."""
        arm_features = self._arm_features(x)
        arm_embeddings = self._embed_arms(self._normalize_counts(arm_features))
        arm_logits = self._compute_arm_logits(arm_embeddings)
        # The stop head works on the raw (un-normalised) features.
        stop_logit = self._compute_stop_logit(arm_features)
        return arm_logits, stop_logit

    def _evaluate_logits(self, state):
        """Forward pass in eval mode without gradients; restores the previous mode.

        Used for action selection and logging so that dropout noise and
        single-sample batch-norm statistics do not distort the values.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(state)
        finally:
            self.train(was_training)

    # ------------------------------------------------------------------
    # Q-learning
    # ------------------------------------------------------------------
    def _q_learning_update(self, replay_buffer, target_net, optimizer, batch_size, gamma):
        """Run one gradient step on a uniformly sampled minibatch; return the loss value."""
        batch = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states = zip(*batch)

        state_batch = torch.stack(states)
        next_state_batch = torch.stack(next_states)
        action_batch = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        reward_batch = torch.tensor(rewards, dtype=torch.float32)

        predicted_q = self.forward(state_batch).gather(1, action_batch).squeeze(1)
        with torch.no_grad():
            # Only arm actions are ever taken during training, so the bootstrap
            # maximum is restricted to the arm columns as well.
            next_q = target_net(next_state_batch)[:, : self.n_arms].max(dim=1).values
            target_q = reward_batch + gamma * next_q

        loss = F.smooth_l1_loss(predicted_q, target_q)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def _print_q_learning_progress(
        self, bandit, state, step, metrics, total_reward, current_epsilon
    ):
        """Print loss, reward, regret, Q-values and arm statistics for one step."""
        current_loss = metrics["losses"][-1] if metrics["losses"] else 0.0
        q_vals = self._evaluate_logits(state)
        q_vals_np = q_vals.cpu().numpy().flatten()

        print(
            f"Step {step}: Loss={current_loss:.4f}, Reward={total_reward:.3f}, "
            f"Regret={bandit.get_cumulative_regret():.3f}, "
            f"Epsilon={current_epsilon:.3f}"
        )
        print(f"  Q-values: [{', '.join([f'{q:.3f}' for q in q_vals_np])}]")

        # Show arm statistics
        stats = bandit.get_all_statistics()
        arm_counts = [arm["count"] for arm in stats["arms"]]
        arm_means = [arm["mean_reward"] for arm in stats["arms"]]
        print(f"  Arm pulls: {arm_counts}")
        print(f"  Arm means: [{', '.join([f'{r:.3f}' for r in arm_means])}]")

        # Show best action
        best_action = torch.argmax(q_vals).item()
        if best_action < self.n_arms:
            print(f"  Best action: Arm {best_action}")
        else:
            print(f"  Best action: Stop")
        print()

    def train_q_learning(
        self,
        bandit,
        optimizer,
        num_steps=1000,
        epsilon=0.1,
        epsilon_decay=0.995,
        min_epsilon=0.01,
        buffer_size=10000,
        batch_size=32,
        target_update_freq=100,
        gamma=0.99,
    ):
        """
        Train the network as a Q-function with epsilon-greedy exploration.

        Each step pulls one arm (the stop action is never chosen), stores the
        transition in a replay buffer, and - once the buffer holds at least
        ``batch_size`` transitions - performs one minibatch update against a
        periodically refreshed target network. The bandit history is not reset.

        Args:
            bandit: Environment providing ``get_current_state``, ``pull_arm``,
                ``get_cumulative_regret`` and ``get_all_statistics``.
            optimizer: Optimizer over this network's parameters.
            num_steps: Number of arm pulls to perform.
            epsilon: Initial exploration probability.
            epsilon_decay: Multiplicative decay applied once per step.
            min_epsilon: Lower bound for the exploration probability.
            buffer_size: Maximum number of stored transitions.
            batch_size: Minibatch size for each update.
            target_update_freq: Steps between target-network refreshes
                (values <= 0 disable refreshing).
            gamma: Discount factor for the bootstrap target.

        Returns:
            Dict with "losses" (one per update), "rewards", "cumulative_regret",
            "epsilon_values" and "q_values" (sampled every 100 steps), plus
            "final_q_values", "final_state" and "final_best_action".

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        import copy  # local import: only needed here to clone the target network

        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        # Experience replay buffer of (state, action, reward, next_state)
        replay_buffer = deque(maxlen=buffer_size)

        # A deep copy keeps the target network's architecture identical to
        # this network even when non-default constructor arguments were used.
        target_net = copy.deepcopy(self)
        target_net.eval()
        for param in target_net.parameters():
            param.requires_grad_(False)

        metrics = {
            "losses": [],
            "rewards": [],
            "cumulative_regret": [],
            "epsilon_values": [],
            "q_values": [],
        }

        current_epsilon = epsilon
        total_reward = 0.0

        for step in range(num_steps):
            state_vector = bandit.get_current_state()
            state = state_vector.unsqueeze(0)

            # Epsilon-greedy over arm actions only; stopping is not allowed.
            if random.random() > current_epsilon:
                q_now = self._evaluate_logits(state)
                arm_index = int(torch.argmax(q_now[0, : self.n_arms]).item())
            else:
                arm_index = random.randrange(self.n_arms)

            reward = bandit.pull_arm(arm_index)
            total_reward += reward
            replay_buffer.append(
                (state_vector, arm_index, reward, bandit.get_current_state())
            )

            if len(replay_buffer) >= batch_size:
                metrics["losses"].append(
                    self._q_learning_update(
                        replay_buffer, target_net, optimizer, batch_size, gamma
                    )
                )

            if target_update_freq > 0 and (step + 1) % target_update_freq == 0:
                target_net.load_state_dict(self.state_dict())

            current_epsilon = max(min_epsilon, current_epsilon * epsilon_decay)

            if step % 100 == 0:
                q_vals = self._evaluate_logits(state)
                metrics["q_values"].append(q_vals.mean().item())
                metrics["rewards"].append(total_reward)
                metrics["cumulative_regret"].append(bandit.get_cumulative_regret())
                metrics["epsilon_values"].append(current_epsilon)

            # Debug output every 200 steps
            if step % 200 == 0:
                self._print_q_learning_progress(
                    bandit, state, step, metrics, total_reward, current_epsilon
                )

        # Final debug output
        print(f"\nFinal Training Summary:")
        final_state = bandit.get_current_state().unsqueeze(0)
        final_q_vals = self._evaluate_logits(final_state)
        final_q_vals_np = final_q_vals.cpu().numpy().flatten()
        best_final_action = torch.argmax(final_q_vals).item()

        print(
            f"Final Q-values (at end of training): [{', '.join([f'{q:.3f}' for q in final_q_vals_np])}]"
        )
        if best_final_action < self.n_arms:
            print(f"Final preferred action: Arm {best_final_action}")
        else:
            print(f"Final preferred action: Stop")

        stats = bandit.get_all_statistics()
        arm_counts = [arm["count"] for arm in stats["arms"]]
        arm_means = [arm["mean_reward"] for arm in stats["arms"]]
        print(f"Final arm pulls: {arm_counts}")
        print(f"Final arm means: [{', '.join([f'{r:.3f}' for r in arm_means])}]")
        print(f"Total training steps: {num_steps}")
        print()

        # Store final values for consistent reporting
        metrics["final_q_values"] = final_q_vals_np
        metrics["final_state"] = final_state
        metrics["final_best_action"] = best_final_action

        return metrics

    # ------------------------------------------------------------------
    # REINFORCE
    # ------------------------------------------------------------------
    def train_reinforce(
        self,
        bandit,
        optimizer,
        num_episodes=500,
        max_steps_per_episode=20,
        temperature=1.0,
        baseline_lr=0.01,
    ):
        """
        Train the network as a stochastic policy with REINFORCE and a learned baseline.

        In every episode actions are sampled from the softmax policy until the
        stop action is drawn or ``max_steps_per_episode`` is reached. The stop
        action is recorded as a zero-reward step so that it receives a gradient.
        The bandit history is not reset between episodes.

        Args:
            bandit: Environment providing ``get_current_state``, ``pull_arm``
                and ``get_cumulative_regret``.
            optimizer: Optimizer over this network's parameters.
            num_episodes: Number of episodes to run.
            max_steps_per_episode: Maximum number of actions per episode.
            temperature: Softmax temperature applied to the logits.
            baseline_lr: Learning rate of the baseline network's Adam optimizer.

        Returns:
            Dict with per-episode "episode_returns", "episode_lengths" (arm
            pulls only, the stop action is not counted), "policy_losses",
            "baseline_losses" and "cumulative_regret".

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Baseline network for variance reduction
        baseline_net = torch.nn.Sequential(
            torch.nn.Linear(3 * self.n_arms, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )
        baseline_optimizer = optim.Adam(baseline_net.parameters(), lr=baseline_lr)

        metrics = {
            "episode_returns": [],
            "episode_lengths": [],
            "policy_losses": [],
            "baseline_losses": [],
            "cumulative_regret": [],
        }

        arm_selection_counts = [0] * self.n_arms
        stop_count = 0
        gamma = 0.99

        for episode in range(num_episodes):
            states = []
            rewards = []
            log_probs = []
            pulls_this_episode = 0

            for _ in range(max_steps_per_episode):
                state = bandit.get_current_state().unsqueeze(0)
                logits = self.forward(state)
                policy = torch.distributions.Categorical(
                    logits=logits / float(temperature)
                )
                action = policy.sample()  # shape (1,)

                states.append(state)
                log_probs.append(policy.log_prob(action))

                action_index = int(action.item())
                if action_index >= self.n_arms:
                    # Stop action: ends the episode without a payout.
                    stop_count += 1
                    rewards.append(0.0)
                    break

                rewards.append(bandit.pull_arm(action_index))
                arm_selection_counts[action_index] += 1
                pulls_this_episode += 1

            if not rewards:
                # no steps taken
                continue

            # Compute discounted returns with gamma=0.99
            returns = torch.zeros(len(rewards), dtype=torch.float32)
            g_acc = 0.0
            for t in range(len(rewards) - 1, -1, -1):
                g_acc = rewards[t] + gamma * g_acc
                returns[t] = g_acc

            # Each log-prob has shape (1,) and each state (1, 3 * n_arms).
            log_prob_tensor = torch.stack(log_probs).squeeze(-1)
            states_tensor = torch.stack(states).squeeze(1)

            # State-dependent baseline
            baseline_values = baseline_net(states_tensor).squeeze(-1)
            advantages = returns - baseline_values.detach()

            # Normalize advantages for stability
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (
                    advantages.std() + 1e-8
                )

            # Policy loss: raise the log-probability of better-than-baseline actions.
            policy_loss = -(log_prob_tensor * advantages).mean()

            # Baseline loss: regress the baseline onto the observed returns.
            baseline_loss = F.mse_loss(baseline_values, returns)

            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

            baseline_optimizer.zero_grad()
            baseline_loss.backward()
            torch.nn.utils.clip_grad_norm_(baseline_net.parameters(), max_norm=0.5)
            baseline_optimizer.step()

            metrics["episode_returns"].append(float(returns[0].item()))
            metrics["episode_lengths"].append(pulls_this_episode)
            metrics["policy_losses"].append(float(policy_loss.item()))
            metrics["baseline_losses"].append(float(baseline_loss.item()))
            metrics["cumulative_regret"].append(bandit.get_cumulative_regret())

            # Progress reporting
            if episode % 50 == 0:
                avg_return = np.mean(metrics["episode_returns"][-50:])
                avg_length = np.mean(metrics["episode_lengths"][-50:])
                selections_total = sum(arm_selection_counts) + stop_count

                print(
                    f"Episode {episode:3d}: Return={avg_return:6.3f}, Length={avg_length:4.1f}, "
                    f"Loss={metrics['policy_losses'][-1]:6.4f}, Regret={bandit.get_cumulative_regret():6.1f}"
                )

                if selections_total > 0:
                    arm_prefs = [
                        f"{100 * count / selections_total:4.1f}%"
                        for count in arm_selection_counts
                    ]
                    stop_pref = f"{100 * stop_count / selections_total:4.1f}%"
                    print(f"         Arms: [{', '.join(arm_prefs)}], Stop: {stop_pref}")

        # Final analysis
        print(f"\n📊 REINFORCE Training Summary:")
        print(f"   Total episodes: {num_episodes}")
        total_selections = sum(arm_selection_counts) + stop_count
        if total_selections > 0:
            print(f"   Action distribution:")
            for i, count in enumerate(arm_selection_counts):
                pct = 100 * count / total_selections
                print(f"     Arm {i}: {count:4d} selections ({pct:5.1f}%)")
            stop_pct = 100 * stop_count / total_selections
            print(f"     Stop:   {stop_count:4d} selections ({stop_pct:5.1f}%)")
        print()

        return metrics

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    @staticmethod
    def _snapshot_bandit(bandit):
        """Copy the bandit's tracking state so it can be restored later."""
        return (
            bandit.pull_counts.clone(),
            bandit.total_rewards.clone(),
            bandit.sum_squared_rewards.clone(),
            bandit.pull_history.copy(),
            bandit.total_pulls,
        )

    @staticmethod
    def _restore_bandit(bandit, snapshot):
        """Reset the bandit's tracking state to fresh copies of ``snapshot``."""
        pull_counts, total_rewards, sum_squared, history, total_pulls = snapshot
        bandit.pull_counts = pull_counts.clone()
        bandit.total_rewards = total_rewards.clone()
        bandit.sum_squared_rewards = sum_squared.clone()
        bandit.pull_history = history.copy()
        bandit.total_pulls = total_pulls

    def evaluate_policy(
        self, bandit, num_episodes=100, max_steps_per_episode=20, deterministic=True
    ):
        """
        Evaluate the policy over episodes that all start from the bandit's current state.

        Each episode runs until the stop action is chosen or
        ``max_steps_per_episode`` arms have been pulled. The bandit's tracking
        state and the network's train/eval mode are restored afterwards, even
        if an exception is raised.

        Returns:
            Dict with the mean and standard deviation of the per-episode
            reward, length (arm pulls) and regret.
        """
        was_training = self.training
        self.eval()

        total_rewards = []
        episode_lengths = []
        regrets = []
        snapshot = self._snapshot_bandit(bandit)

        try:
            with torch.no_grad():
                for _ in range(num_episodes):
                    # Reset to original state for each episode
                    self._restore_bandit(bandit, snapshot)

                    episode_reward = 0.0
                    steps_taken = 0
                    initial_regret = bandit.get_cumulative_regret()

                    for _ in range(max_steps_per_episode):
                        state = bandit.get_current_state().unsqueeze(0)
                        action = self.select_action(
                            state, deterministic=deterministic
                        ).item()
                        if action >= self.n_arms:
                            break
                        episode_reward += bandit.pull_arm(action)
                        steps_taken += 1

                    total_rewards.append(episode_reward)
                    episode_lengths.append(steps_taken)
                    regrets.append(bandit.get_cumulative_regret() - initial_regret)
        finally:
            # Restore original bandit state and the network's previous mode
            self._restore_bandit(bandit, snapshot)
            self.train(was_training)

        return {
            "mean_reward": float(np.mean(total_rewards)),
            "std_reward": float(np.std(total_rewards)),
            "mean_length": float(np.mean(episode_lengths)),
            "std_length": float(np.std(episode_lengths)),
            "mean_regret": float(np.mean(regrets)),
            "std_regret": float(np.std(regrets)),
        }

    def evaluate_policy_simple(self, bandit, num_actions=100, deterministic=True):
        """
        Simple evaluation that doesn't reset bandit history.

        Takes exactly ``num_actions`` arm pulls. Whenever the policy chooses
        the stop action a uniformly random arm is pulled instead, so the
        evaluation never terminates early. The network's train/eval mode is
        restored afterwards.

        Returns:
            Dict with "total_reward", "mean_reward", "actions_taken" and
            "evaluation_regret" (regret added during this evaluation).
        """
        was_training = self.training
        self.eval()
        initial_regret = bandit.get_cumulative_regret()

        actions_taken = 0
        episode_reward = 0.0

        try:
            with torch.no_grad():
                for _ in range(num_actions):
                    state = bandit.get_current_state().unsqueeze(0)

                    if deterministic:
                        action = torch.argmax(self.forward(state)).item()
                    else:
                        action = self.select_action(
                            state, deterministic=deterministic
                        ).item()

                    if action >= self.n_arms:
                        # Don't break on stop action during evaluation - force exploration
                        action = torch.randint(0, self.n_arms, (1,)).item()

                    episode_reward += bandit.pull_arm(action)
                    actions_taken += 1
        finally:
            self.train(was_training)

        evaluation_regret = bandit.get_cumulative_regret() - initial_regret

        return {
            "total_reward": episode_reward,
            "mean_reward": episode_reward / max(1, actions_taken),
            "actions_taken": actions_taken,
            "evaluation_regret": evaluation_regret,
        }

In [None]:
def create_bandit_state(
    n_arms, usage_counts, average_payouts, std_deviations, device="cpu"
):
    """Pack per-arm statistics into a single network input of shape (1, 3 * n_arms).

    The layout is interleaved per arm:
    [count_0, payout_0, std_0, count_1, payout_1, std_1, ...].

    Raises:
        ValueError: If any of the three sequences does not have length n_arms.
    """
    per_arm_inputs = (usage_counts, average_payouts, std_deviations)
    if any(len(values) != n_arms for values in per_arm_inputs):
        raise ValueError(
            "Usage counts, average payouts, and standard deviations must have length n_arms"
        )

    state = torch.zeros(1, 3 * n_arms, device=device)
    # View the flat state as one row per arm; writes go through to `state`.
    rows = state.view(n_arms, 3)
    for arm in range(n_arms):
        rows[arm, 0] = usage_counts[arm]
        rows[arm, 1] = average_payouts[arm]
        rows[arm, 2] = std_deviations[arm]
    return state

In [None]:
def train_and_evaluate_q_learning(
    n_arms=5,
    bandit_means=None,
    bandit_variances=None,
    num_steps=2000,
    learning_rate=0.01,
):
    """
    Train a MultiArmedBanditNet with Q-learning on a seeded bandit and print a report.

    Builds a MultiArmedBandit (random_seed=42) and a MultiArmedBanditNet
    (max_actions=50), runs ``net.train_q_learning``, evaluates the result with
    both ``evaluate_policy_simple`` and ``evaluate_policy``, and prints
    training, evaluation, Q-value, and performance summaries to stdout.

    Args:
        n_arms: Number of bandit arms (default: 5).
        bandit_means: True mean payout per arm. If None, the 5-arm default
            [0.1, 0.5, 0.3, 0.8, 0.2] is used when n_arms == 5; for any other
            n_arms the means are spaced evenly between 0.1 and 0.8 so the
            default always matches n_arms.
        bandit_variances: Payout variance per arm. If None, defaults to 0.1
            for all arms.
        num_steps: Number of Q-learning training steps (default: 2000).
        learning_rate: Learning rate for the Adam optimizer (default: 0.01).

    Returns:
        Tuple ``(net, metrics, eval_metrics)``: the trained network, the
        metrics returned by ``train_q_learning``, and the metrics returned by
        ``evaluate_policy``.
    """
    if bandit_means is None:
        default_means = [0.1, 0.5, 0.3, 0.8, 0.2]
        if n_arms == len(default_means):
            bandit_means = default_means
        else:
            # The bandit rejects a means list whose length differs from n_arms,
            # so derive a default of the right size instead of failing.
            bandit_means = np.linspace(
                min(default_means), max(default_means), n_arms
            ).tolist()
    if bandit_variances is None:
        bandit_variances = [0.1] * n_arms

    # The bandit seeds the RNGs, so it is created before the network.
    bandit = MultiArmedBandit(
        n_arms=n_arms, means=bandit_means, variances=bandit_variances, random_seed=42
    )

    net = MultiArmedBanditNet(n_arms, max_actions=50)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    print("╔" + "═" * 58 + "╗")
    print("║" + " Q-LEARNING TRAINING ".center(58) + "║")
    print("╚" + "═" * 58 + "╝")
    print("🎯 Problem Setup:")
    print(f"   Arms: {n_arms}")
    print(f"   True means: {bandit_means}")
    print(
        f"   Optimal arm: {bandit.get_optimal_arm()} (mean = {max(bandit_means):.3f})"
    )
    print(f"   Training steps: {num_steps}")
    print()

    # Store initial regret so the regret accrued during training can be reported
    initial_regret = bandit.get_cumulative_regret()

    metrics = net.train_q_learning(
        bandit,
        optimizer,
        num_steps=num_steps,
        epsilon=0.5,
        min_epsilon=0.1,
        epsilon_decay=0.995,
    )

    # Simple evaluation runs first, directly on the bandit as training left it
    eval_metrics_simple = net.evaluate_policy_simple(bandit, num_actions=100)

    # Episode-based evaluation for comparison; this is the one returned to callers
    eval_metrics = net.evaluate_policy(bandit, num_episodes=100)

    if metrics.get("losses"):
        print(f"Final loss: {metrics['losses'][-1]:.4f}")
    if metrics.get("cumulative_regret"):
        final_regret = metrics["cumulative_regret"][-1]
        training_regret = final_regret - initial_regret
        print(f"Training regret: {training_regret:.3f}")
        print(f"Final cumulative regret: {final_regret:.3f}")

    # Show both evaluation results
    print("\n--- Simple Evaluation (preserves bandit state) ---")
    print(f"Mean reward: {eval_metrics_simple['mean_reward']:.3f}")
    print(f"Actions taken: {eval_metrics_simple['actions_taken']}")
    print(f"Evaluation regret: {eval_metrics_simple['evaluation_regret']:.3f}")

    print("\n--- Traditional Evaluation (resets each episode) ---")
    print(
        f"Mean reward: {eval_metrics['mean_reward']:.3f} ± {eval_metrics['std_reward']:.3f}"
    )
    print(f"Mean regret: {eval_metrics['mean_regret']:.3f}")

    # Comprehensive learning analysis
    print("╔" + "═" * 58 + "╗")
    print("║" + " Q-LEARNING RESULTS ANALYSIS ".center(58) + "║")
    print("╚" + "═" * 58 + "╝")

    # Prefer the Q-values recorded by training so the report matches training end
    if "final_q_values" in metrics:
        final_q_vals_np = metrics["final_q_values"]
        best_action = metrics["final_best_action"]
    else:
        # Fallback: recompute from the bandit's current state
        with torch.no_grad():
            final_state = bandit.get_current_state().unsqueeze(0)
            final_q_vals = net.forward(final_state)
            final_q_vals_np = final_q_vals.cpu().numpy().flatten()
            best_action = torch.argmax(final_q_vals).item()

    print("🧠 Learned Q-values (consistent with training end):")
    for i in range(len(bandit_means)):
        print(
            f"   Arm {i}: Q={final_q_vals_np[i]:7.3f} (true mean={bandit_means[i]:.3f})"
        )
    # The last Q-value is reported as the stop action
    print(f"   Stop:  Q={final_q_vals_np[-1]:7.3f}")
    print()

    print("🎯 Decision Analysis:")
    optimal_arm = bandit.get_optimal_arm()
    if best_action < len(bandit_means):
        print(
            f"   Agent prefers: Arm {best_action} (mean = {bandit_means[best_action]:.3f})"
        )
        if best_action == optimal_arm:
            print("   ✅ SUCCESS: Learned to prefer optimal arm!")
        else:
            print(f"   ⚠️  SUBOPTIMAL: Optimal is arm {optimal_arm}")
    else:
        print("   ❌ PROBLEM: Agent prefers to stop")

    # Performance assessment
    optimal_reward = max(bandit_means)
    achieved_reward = eval_metrics_simple["mean_reward"]
    # A best mean of exactly 0 would otherwise raise ZeroDivisionError after the
    # whole training run; report 0 in that case. The ratio is only meaningful
    # when the optimal mean is positive.
    performance_ratio = (
        achieved_reward / optimal_reward if optimal_reward != 0 else 0.0
    )
    print("\n📈 Performance Metrics:")
    print(f"   Achieved: {achieved_reward:.3f}")
    print(f"   Optimal:  {optimal_reward:.3f}")
    print(f"   Ratio:    {performance_ratio:.1%}")

    if performance_ratio >= 0.95:
        print("   🏆 EXCELLENT: >95% of optimal!")
    elif performance_ratio >= 0.8:
        print("   ✅ GOOD: >80% of optimal")
    elif performance_ratio >= 0.6:
        print("   ⚠️  FAIR: >60% of optimal")
    else:
        print("   ❌ POOR: <60% of optimal")

    print()

    return net, metrics, eval_metrics

In [None]:
def train_and_evaluate_reinforce(
    n_arms=5,
    bandit_means=None,
    bandit_variances=None,
    num_episodes=1000,
    learning_rate=0.01,
):
    """
    Train a MultiArmedBanditNet with REINFORCE on a seeded bandit and print a report.

    Builds a MultiArmedBandit (random_seed=42) and a MultiArmedBanditNet
    (max_actions=50), pulls every arm three times to seed the bandit's
    statistics, runs ``net.train_reinforce``, evaluates with
    ``net.evaluate_policy``, and prints training, evaluation, and efficiency
    summaries to stdout.

    Args:
        n_arms: Number of bandit arms (default: 5).
        bandit_means: True mean payout per arm. If None, the 5-arm default
            [0.1, 0.5, 0.3, 0.8, 0.2] is used when n_arms == 5; for any other
            n_arms the means are spaced evenly between 0.1 and 0.8 so the
            default always matches n_arms.
        bandit_variances: Payout variance per arm. If None, defaults to 0.1
            for all arms.
        num_episodes: Number of REINFORCE training episodes (default: 1000).
        learning_rate: Base learning rate; the Adam optimizer is created with
            ten times this value (default: 0.01).

    Returns:
        Tuple ``(net, metrics, eval_metrics)``: the trained network, the
        metrics returned by ``train_reinforce``, and the metrics returned by
        ``evaluate_policy``.
    """
    if bandit_means is None:
        default_means = [0.1, 0.5, 0.3, 0.8, 0.2]
        if n_arms == len(default_means):
            bandit_means = default_means
        else:
            # The bandit rejects a means list whose length differs from n_arms,
            # so derive a default of the right size instead of failing.
            bandit_means = np.linspace(
                min(default_means), max(default_means), n_arms
            ).tolist()
    if bandit_variances is None:
        bandit_variances = [0.1] * n_arms

    # The bandit seeds the RNGs, so it is created before the network.
    bandit = MultiArmedBandit(
        n_arms=n_arms, means=bandit_means, variances=bandit_variances, random_seed=42
    )

    net = MultiArmedBanditNet(n_arms, max_actions=50)
    optimizer = optim.Adam(
        net.parameters(), lr=learning_rate * 10
    )  # Higher learning rate for REINFORCE

    print("╔" + "═" * 58 + "╗")
    print("║" + " REINFORCE TRAINING ".center(58) + "║")
    print("╚" + "═" * 58 + "╝")
    print("🎯 Problem Setup:")
    print(f"   Arms: {n_arms}")
    print(f"   True means: {bandit_means}")
    print(
        f"   Optimal arm: {bandit.get_optimal_arm()} (mean = {max(bandit_means):.3f})"
    )
    print(f"   Episodes: {num_episodes}")
    print()

    # Pre-populate bandit with some initial experience to help learning
    print("🔄 Pre-populating bandit with initial experience...")
    for arm in range(n_arms):
        for _ in range(3):  # Give each arm 3 initial pulls
            bandit.pull_arm(arm)

    metrics = net.train_reinforce(
        bandit, optimizer, num_episodes=num_episodes, temperature=1.5, baseline_lr=0.001
    )
    eval_metrics = net.evaluate_policy(bandit, num_episodes=100)

    print("╔" + "═" * 58 + "╗")
    print("║" + " REINFORCE RESULTS ANALYSIS ".center(58) + "║")
    print("╚" + "═" * 58 + "╝")

    print("📊 Training Metrics:")
    if metrics.get("policy_losses"):
        print(f"   Final policy loss: {metrics['policy_losses'][-1]:.4f}")
    if metrics.get("cumulative_regret"):
        print(f"   Final cumulative regret: {metrics['cumulative_regret'][-1]:.3f}")

    # Learning curve analysis: compare the first and last 50 episodes.
    # Read with .get like the other metrics so a missing key skips the section.
    episode_returns = metrics.get("episode_returns", [])
    if len(episode_returns) >= 100:
        early_returns = np.mean(episode_returns[:50])
        late_returns = np.mean(episode_returns[-50:])
        improvement = late_returns - early_returns
        print(f"   Early episodes (0-49): {early_returns:.3f} avg return")
        print(f"   Late episodes (-50:-1): {late_returns:.3f} avg return")
        print(f"   Improvement: {improvement:+.3f}")

    mean_reward = eval_metrics["mean_reward"]
    mean_length = eval_metrics["mean_length"]

    print("\n📈 Evaluation Results:")
    print(f"   Mean reward: {mean_reward:.3f} ± {eval_metrics['std_reward']:.3f}")
    print(
        f"   Mean episode length: {mean_length:.1f} ± {eval_metrics['std_length']:.1f}"
    )
    print(f"   Mean regret: {eval_metrics['mean_regret']:.3f}")

    # Performance assessment.
    # NOTE(review): this treats mean_reward as a per-episode total and divides
    # by mean_length to get a per-step figure — confirm against evaluate_policy.
    optimal_reward = max(bandit_means)
    # Use the same denominator for the per-step figure and the efficiency so the
    # two printed numbers always agree (also when 0 < mean_length < 1).
    per_step_reward = mean_reward / mean_length if mean_length > 0 else 0.0
    # Guard both factors: a best mean of exactly 0 would otherwise raise
    # ZeroDivisionError after the whole training run.
    performance_ratio = (
        mean_reward / (mean_length * optimal_reward)
        if mean_length > 0 and optimal_reward != 0
        else 0
    )
    print("\n🎯 Performance Analysis:")
    print(f"   Per-step reward: {per_step_reward:.3f}")
    print(f"   Optimal per-step: {optimal_reward:.3f}")
    print(f"   Efficiency: {performance_ratio:.1%}")

    if performance_ratio >= 0.8:
        print("   🏆 EXCELLENT: >80% efficient!")
    elif performance_ratio >= 0.6:
        print("   ✅ GOOD: >60% efficient")
    elif performance_ratio >= 0.4:
        print("   ⚠️  FAIR: >40% efficient")
    else:
        print("   ❌ POOR: <40% efficient")

    print()

    return net, metrics, eval_metrics

In [None]:
# Demo cell: build a network, smoke-test it on an empty state, train it with
# Q-learning and REINFORCE, then compare untrained and trained preferences.
import traceback

print("╔" + "═" * 78 + "╗")
print("║" + " MULTI-ARMED BANDIT NEURAL NETWORK TRAINING DEMO ".center(78) + "║")
print("╚" + "═" * 78 + "╝")
print()

n_arms = 5
max_actions = 100

bandit_net = MultiArmedBanditNet(
    n_arms, max_actions, dim_arm=64, hidden_layers=[64, 32], dropout_rate=0.1
)

print("🔧 Network Architecture:")
print(f"   Model: {type(bandit_net).__name__}")
print(f"   Arms: {n_arms}")
print(f"   Max actions: {max_actions}")
print(f"   Parameters: {sum(p.numel() for p in bandit_net.parameters()):,}")
print()

# Test with empty state (no pulls, zero averages, zero deviations)
state = create_bandit_state(n_arms, [0] * n_arms, [0.0] * n_arms, [0.0] * n_arms)
logits = bandit_net(state)
probabilities = bandit_net.get_action_probabilities(state)

print("🧪 Network Testing (empty state):")
print(f"   Output shape: {logits.shape}")
print(
    f"   Action probabilities: uniform = {probabilities.squeeze().detach().numpy()}"
)
print()

print("┌" + "─" * 76 + "┐")
print("│" + " TRAINING DEMONSTRATIONS ".center(76) + "│")
print("└" + "─" * 76 + "┘")
print()

# One shared problem definition for both training runs and the comparison
# below, so the comparison bandit always has the same number of arms as the
# bandits the networks were trained on.
demo_means = [0.3, 0.7, 0.1]
demo_n_arms = len(demo_means)

# Reset before training: in a notebook these names can survive from an earlier
# execution of this cell, which would make a failed run look like it had
# produced a trained model.
q_net = None
r_net = None

# Q-learning demonstration
try:
    print("🎯 Starting Q-Learning demonstration...")
    q_net, q_metrics, q_eval = train_and_evaluate_q_learning(
        n_arms=demo_n_arms, bandit_means=demo_means, num_steps=1000
    )
    print("✅ Q-learning demonstration completed successfully!")
except Exception as e:
    # Demo boundary: report the failure and continue with the next section.
    print(f"❌ Q-learning demo failed: {e}")
    traceback.print_exc()

print("\n" + "─" * 80 + "\n")

# REINFORCE demonstration
try:
    print("🎯 Starting REINFORCE demonstration...")
    r_net, r_metrics, r_eval = train_and_evaluate_reinforce(
        n_arms=demo_n_arms,
        bandit_means=demo_means,
        num_episodes=400,
        learning_rate=0.005,
    )
    print("✅ REINFORCE demonstration completed successfully!")
except Exception as e:
    # Demo boundary: report the failure and continue with the next section.
    print(f"❌ REINFORCE demo failed: {e}")
    traceback.print_exc()

print("\n" + "─" * 80 + "\n")

# Educational comparison: Untrained vs Trained models
print("🧪 Learning Effect Demonstration:")

try:
    # Create a realistic test scenario
    demo_bandit = MultiArmedBandit(
        n_arms=demo_n_arms,
        means=demo_means,
        variances=[0.1] * demo_n_arms,
        random_seed=456,
    )

    # Simulate some exploration experience
    print("   Creating realistic bandit state...")
    print("   - Pulling each arm several times to build statistics")

    for arm in range(demo_n_arms):
        for _ in range(8):  # Each arm gets 8 pulls
            demo_bandit.pull_arm(arm)

    demo_state = demo_bandit.get_current_state().unsqueeze(0)

    # Show the bandit statistics
    stats = demo_bandit.get_all_statistics()
    print("\n   📊 Bandit Statistics After Exploration:")
    for i, arm_stat in enumerate(stats["arms"]):
        print(
            f"     Arm {i}: {arm_stat['count']:2d} pulls, "
            f"{arm_stat['mean_reward']:.3f} observed mean, "
            f"{arm_stat['true_mean']:.3f} true mean"
        )

    print(f"   🎯 Optimal choice should be: Arm {demo_bandit.get_optimal_arm()}")

    # Compare untrained vs trained (if available)
    print("\n   🆚 Model Comparison:")

    # Untrained network
    untrained_net = MultiArmedBanditNet(demo_n_arms, max_actions=50)
    with torch.no_grad():
        untrained_prefs = untrained_net.get_action_probabilities(demo_state)
        untrained_choice = torch.argmax(untrained_prefs).item()

    print("   📍 Untrained network:")
    print(
        f"     Preferences: {[f'{p:.3f}' for p in untrained_prefs.squeeze().detach().numpy()]}"
    )
    # Any index at or past the arm count is reported as the stop action
    print(
        f"     Choice: {'Arm ' + str(untrained_choice) if untrained_choice < demo_n_arms else 'Stop'}"
    )

    # Trained networks (only those whose training succeeded in this run)
    if q_net is not None:
        q_net.eval()
        with torch.no_grad():
            q_prefs = q_net.get_action_probabilities(demo_state)
            q_choice = torch.argmax(q_prefs).item()

        print("   🎯 Q-learning (trained):")
        print(
            f"     Preferences: {[f'{p:.3f}' for p in q_prefs.squeeze().detach().numpy()]}"
        )
        print(
            f"     Choice: {'Arm ' + str(q_choice) if q_choice < demo_n_arms else 'Stop'}"
        )

        if q_choice == demo_bandit.get_optimal_arm():
            print("     ✅ Correctly identifies optimal arm!")
        else:
            print("     ❌ Suboptimal choice")

    if r_net is not None:
        r_net.eval()
        with torch.no_grad():
            r_prefs = r_net.get_action_probabilities(demo_state)
            r_choice = torch.argmax(r_prefs).item()

        print("   🎯 REINFORCE (trained):")
        print(
            f"     Preferences: {[f'{p:.3f}' for p in r_prefs.squeeze().detach().numpy()]}"
        )
        print(
            f"     Choice: {'Arm ' + str(r_choice) if r_choice < demo_n_arms else 'Stop'}"
        )

        if r_choice == demo_bandit.get_optimal_arm():
            print("     ✅ Correctly identifies optimal arm!")
        else:
            print("     ❌ Suboptimal choice")

    if q_net is None and r_net is None:
        print("   (No trained models available for comparison)")

except Exception as e:
    # Demo boundary: report the failure but still print the closing banner.
    print(f"   Demonstration failed: {e}")
    traceback.print_exc()

print("\n" + "╔" + "═" * 78 + "╗")
print("║" + " DEMO COMPLETE ".center(78) + "║")
print("╚" + "═" * 78 + "╝")

╔══════════════════════════════════════════════════════════════════════════════╗
║               MULTI-ARMED BANDIT NEURAL NETWORK TRAINING DEMO                ║
╚══════════════════════════════════════════════════════════════════════════════╝

🔧 Network Architecture:
   Model: MultiArmedBanditNet
   Arms: 5
   Max actions: 100
   Parameters: 11,810

🧪 Network Testing (empty state):
   Output shape: torch.Size([1, 6])
   Action probabilities: uniform = [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]

┌────────────────────────────────────────────────────────────────────────────┐
│                          TRAINING DEMONSTRATIONS                           │
└────────────────────────────────────────────────────────────────────────────┘

🎯 Starting Q-Learning demonstration...
╔══════════════════════════════════════════════════════════╗
║                   Q-LEARNING TRAINING                    ║
╚══════════════════════════════════════════════════════════╝
🎯 Problem Setup