# Train Reinforcement Learning Agent

In [None]:
# Stable-Baselines3 and Gymnasium Implementation
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO, DQN, A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import os

# Import our functions
from main import updated_fisher_information, reward_function

In [None]:
class QuantumNetworkEnv(gym.Env):
    """
    Gymnasium Environment for Quantum Network Protocol Selection

    The agent receives theta values and current Fisher information, then selects
    a (root, protocol) pair at every step to maximize cumulative reward over an
    episode of ``max_steps`` steps.

    Observation:
        ``[theta0, theta1, theta2, fisher0, fisher1, fisher2]`` as float32,
        clipped into ``observation_space``. Cumulative Fisher information above
        1000 is therefore shown to the agent as 1000; the unclipped values are
        available in ``info['fisher_info']``.

    Action:
        One integer in ``[0, N_ROOTS * N_PROTOCOLS)`` encoding
        ``action = root * N_PROTOCOLS + protocol``.

    Episode end:
        The observation carries no step counter, so reaching ``max_steps`` is a
        time limit and is reported as ``truncated=True`` (``terminated`` is
        always False). Callers should stop on ``terminated or truncated``.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': 4}

    # Action encoding: action = root * N_PROTOCOLS + protocol.
    N_ROOTS = 3
    N_PROTOCOLS = 6

    # theta is kept inside [THETA_MIN, THETA_MAX], which are also the observation
    # bounds, so the agent always observes the true theta. Initial values are
    # drawn uniformly from [THETA_MIN, THETA_INIT_MAX]; each step then applies a
    # uniform random walk of half-width THETA_NOISE.
    THETA_MIN = 0.1
    THETA_MAX = 0.9
    THETA_INIT_MAX = 0.45
    THETA_NOISE = 0.01

    # Reward used when computing the Fisher update or the reward raises
    # ZeroDivisionError.
    INVALID_STATE_PENALTY = -10.0

    def __init__(self, max_steps=10, render_mode=None):
        super().__init__()

        self.max_steps = max_steps
        self.render_mode = render_mode
        self.step_count = 0

        # Combined root/protocol selection: 3 roots x 6 protocols = 18 actions.
        self.action_space = spaces.Discrete(self.N_ROOTS * self.N_PROTOCOLS)

        # Observation: [theta0, theta1, theta2, fisher0, fisher1, fisher2].
        # Observation bounds for Fisher information are [0, 1000].
        self.observation_space = spaces.Box(
            low=np.array([self.THETA_MIN] * 3 + [0.0] * 3, dtype=np.float32),
            high=np.array([self.THETA_MAX] * 3 + [1000.0] * 3, dtype=np.float32),
            dtype=np.float32,
        )

        # State is populated by reset(); step() refuses to run before that.
        self.theta = None
        self.fisher_info = None
        self.total_reward = 0.0
        self.episode_history = []

    def reset(self, seed=None, options=None):
        """Reset the environment to a fresh episode.

        Randomness comes from ``self.np_random`` (seeded by ``super().reset``),
        so resets with the same seed are reproducible without touching the
        global NumPy RNG.

        Returns:
            ``(observation, info)`` as required by the Gymnasium API.
        """
        super().reset(seed=seed)

        self.step_count = 0
        self.total_reward = 0.0
        self.episode_history = []

        # Draw initial theta inside the observation bounds.
        self.theta = [
            float(t)
            for t in self.np_random.uniform(self.THETA_MIN, self.THETA_INIT_MAX, size=3)
        ]

        # Every episode starts with no accumulated Fisher information.
        self.fisher_info = [0.0, 0.0, 0.0]

        return self._get_observation(), self._get_info()

    def step(self, action):
        """Apply one (root, protocol) selection.

        Args:
            action: integer (or 0-d integer array) in ``action_space``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``info``
            includes the decoded root/protocol, this step's Fisher contribution
            and the cumulative Fisher information.

        Raises:
            RuntimeError: if called before ``reset()``.
            ValueError: if ``action`` is not in ``action_space``.
        """
        if self.theta is None or self.fisher_info is None:
            raise RuntimeError("Cannot call step() before reset().")
        if not self.action_space.contains(action):
            raise ValueError(f"Invalid action: {action}. Must be in {self.action_space}")

        # Policies (e.g. SB3 predict()) may return 0-d numpy arrays; normalise
        # to a plain int so root/protocol and the history hold Python ints.
        action = int(action)
        root, protocol = self.decode_action(action)

        old_fisher_info = self.fisher_info.copy()

        try:
            raw_contribution = updated_fisher_information(root, protocol, self.theta)
            reward = reward_function(old_fisher_info, raw_contribution)
            new_fisher_contribution = list(raw_contribution)
        except ZeroDivisionError:
            # Invalid state: penalise and leave the cumulative Fisher info
            # untouched, so it stays consistent with the recorded contribution.
            reward = self.INVALID_STATE_PENALTY
            new_fisher_contribution = [0.0, 0.0, 0.0]
        else:
            # Commit the update only once both calls above have succeeded.
            self.fisher_info = [
                old + new for old, new in zip(self.fisher_info, new_fisher_contribution)
            ]

        self.total_reward += reward
        self.step_count += 1

        # The step limit is a time limit, not a terminal state of the task.
        terminated = False
        truncated = self.step_count >= self.max_steps

        # Random-walk theta to simulate environment drift, kept inside bounds.
        noise = self.np_random.uniform(-self.THETA_NOISE, self.THETA_NOISE, size=len(self.theta))
        self.theta = [
            float(np.clip(t + n, self.THETA_MIN, self.THETA_MAX))
            for t, n in zip(self.theta, noise)
        ]

        # Store step information for analysis
        step_info = {
            'step': self.step_count,
            'action': action,
            'root': root,
            'protocol': protocol,
            'reward': reward,
            'fisher_contribution': list(new_fisher_contribution),
            'cumulative_fisher': self.fisher_info.copy(),
            'theta': self.theta.copy()
        }
        self.episode_history.append(step_info)

        observation = self._get_observation()
        info = self._get_info()
        info.update(step_info)

        return observation, reward, terminated, truncated, info

    def _get_observation(self):
        """Return ``[theta0, theta1, theta2, fisher0, fisher1, fisher2]`` clipped to the space bounds."""
        obs = np.array(self.theta + self.fisher_info, dtype=np.float32)
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    def _get_info(self):
        """Return unclipped bookkeeping about the current state."""
        return {
            'step_count': self.step_count,
            'total_reward': self.total_reward,
            'theta': self.theta.copy(),
            'fisher_info': self.fisher_info.copy()
        }

    def decode_action(self, action):
        """Decode a combined action into ``(root, protocol)`` as plain ints."""
        root, protocol = divmod(int(action), self.N_PROTOCOLS)
        return root, protocol

    def encode_action(self, root, protocol):
        """Encode ``root`` and ``protocol`` into a combined action."""
        return root * self.N_PROTOCOLS + protocol

    def render(self):
        """Print a text summary of the current state when ``render_mode == 'human'``."""
        if self.render_mode == 'human':
            print(f"Step: {self.step_count}/{self.max_steps}")
            print(f"Theta: {[f'{x:.3f}' for x in self.theta]}")
            print(f"Fisher Info: {[f'{x:.3f}' for x in self.fisher_info]}")
            print(f"Total Reward: {self.total_reward:.3f}")
            print("-" * 40)

    def close(self):
        """Clean up environment (no resources are held)."""
        pass

In [None]:
# Create and test the environment with updated action space
print("Creating Quantum Network Environment with Root + Protocol Selection...")
env = QuantumNetworkEnv(max_steps=10)

# Check that the environment follows the Gymnasium API. This is a top-level
# notebook boundary, so a failure is reported (and remembered) rather than raised.
print("Checking environment...")
env_check_passed = False
try:
    check_env(env, warn=True)
    env_check_passed = True
    print("✓ Environment check passed!")
except Exception as e:
    print(f"❌ Environment check failed: {e}")

# Test the environment
print("\nTesting environment reset...")
obs, info = env.reset(seed=42)
print(f"Initial observation shape: {obs.shape}")
print(f"Initial observation: {obs}")
print(f"Action space: {env.action_space} (18 total actions)")
print(f"Observation space: {env.observation_space}")

# Demonstrate action encoding and verify that decode_action inverts encode_action
print("\nAction Encoding Examples:")
print("Action = Root * 6 + Protocol")
for root in range(3):
    for protocol in range(6):
        encoded_action = env.encode_action(root, protocol)
        decoded_root, decoded_protocol = env.decode_action(encoded_action)
        if (decoded_root, decoded_protocol) != (root, protocol):
            raise AssertionError(
                f"Action round-trip failed: ({root}, {protocol}) -> {encoded_action} "
                f"-> ({decoded_root}, {decoded_protocol})"
            )
        print(f"Root {root}, Protocol {protocol} → Action {encoded_action}")
print()

# Test a few random steps with detailed output
print("Testing random steps with root and protocol details...")
for i in range(3):
    action = env.action_space.sample()
    root, protocol = env.decode_action(action)
    obs, reward, terminated, truncated, info = env.step(action)
    # An episode is over on either termination or truncation.
    print(f"Step {i+1}: Action={action} (Root={root}, Protocol={protocol}), Reward={reward:.3f}, Done={terminated or truncated}")

env.close()
if env_check_passed:
    print("Environment test completed successfully! ✓")
else:
    print("Environment test completed, but the Gymnasium API check failed. ❌")

In [None]:
# Train a PPO agent using Stable-Baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
import os

print("Setting up training environment...")

# Separate training and evaluation environments, each wrapped in Monitor so
# that episode statistics are recorded for logging.
train_env = Monitor(QuantumNetworkEnv(max_steps=20))
eval_env = Monitor(QuantumNetworkEnv(max_steps=20))

# Build the PPO learner: MLP policy on CPU with TensorBoard logging.
print("Creating PPO model...")
model = PPO(
    policy="MlpPolicy",
    env=train_env,
    # Rollout / optimisation schedule
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    # Return / advantage estimation
    gamma=0.99,
    gae_lambda=0.95,
    # Update clipping and entropy bonus (the latter encourages exploration)
    clip_range=0.2,
    ent_coef=0.01,
    # Runtime and logging
    device="cpu",
    verbose=1,
    tensorboard_log="./ppo_quantum_tensorboard/",
)
print("Model created successfully!")
print(f"Policy architecture: {model.policy}")

# Periodically evaluate the deterministic policy on eval_env, saving the best model.
eval_callback = EvalCallback(
    eval_env,
    eval_freq=1000,
    deterministic=True,
    best_model_save_path="./best_quantum_model/",
    log_path="./eval_logs/",
    render=False,
    verbose=1,
)

print("Training setup completed! Ready to train the agent.")

In [None]:
# Train the model
print("Starting training...")
print("Training for 50,000 timesteps (adjust as needed)")

model.learn(
    total_timesteps=50000,
    callback=eval_callback,
    tb_log_name="PPO_QuantumNetwork",
    progress_bar=True
)

print("Training completed!")

# Save the final model (in addition to the best checkpoint kept by eval_callback)
model.save("quantum_network_ppo_final")
print("Model saved as 'quantum_network_ppo_final'")

# Load and test the trained model
print("\nTesting the trained model...")
trained_model = PPO.load("quantum_network_ppo_final")

# Test on a few seeded episodes with detailed action analysis
n_test_episodes = 5
test_env = QuantumNetworkEnv(max_steps=10)
total_rewards = []

for episode in range(n_test_episodes):
    obs, info = test_env.reset(seed=episode)
    episode_reward = 0
    episode_actions = []
    episode_details = []

    # Bound the loop by the env's own episode length so the two cannot drift apart.
    for step in range(test_env.max_steps):
        action, _states = trained_model.predict(obs, deterministic=True)
        # predict() returns a 0-d ndarray; use a plain int so the printed
        # action list reads [5, 3, ...] rather than [array(5), array(3), ...].
        action = int(action)
        obs, reward, terminated, truncated, info = test_env.step(action)
        episode_reward += reward
        episode_actions.append(action)

        # Decode action for detailed analysis
        root, protocol = test_env.decode_action(action)
        episode_details.append(f"R{root}P{protocol}")

        if terminated or truncated:
            break

    total_rewards.append(episode_reward)
    print(f"Episode {episode + 1}: Reward = {episode_reward:.3f}")
    print(f"  Actions = {episode_actions}")
    print(f"  Details = {episode_details}")

average_reward = np.mean(total_rewards)
print(f"\nAverage reward over {n_test_episodes} test episodes: {average_reward:.3f}")
test_env.close()

In [None]:
# Analysis and Visualization of Agent Performance
import matplotlib.pyplot as plt
from collections import Counter

def _collect_agent_rollouts(model, num_episodes, max_steps):
    """Run ``num_episodes`` deterministic episodes of ``model`` and gather statistics.

    Episodes are seeded ``0 .. num_episodes - 1`` so repeated analyses of the
    same model see the same initial states.

    Returns:
        dict with per-step lists ``actions``, ``roots``, ``protocols`` and
        ``thetas`` (theta values observed before each action, flattened), plus
        per-episode ``rewards`` and ``fishers`` (end-of-episode cumulative
        Fisher information, flattened).
    """
    env = QuantumNetworkEnv(max_steps=max_steps)
    rollouts = {
        'actions': [], 'roots': [], 'protocols': [],
        'thetas': [], 'rewards': [], 'fishers': [],
    }
    try:
        for episode in range(num_episodes):
            obs, info = env.reset(seed=episode)
            episode_reward = 0

            for _ in range(max_steps):
                rollouts['thetas'].extend(obs[:3].tolist())
                action, _states = model.predict(obs, deterministic=True)
                # predict() may return a 0-d numpy array, which is unhashable
                # (Counter would raise TypeError); keep a plain int instead.
                action = int(action)
                obs, reward, terminated, truncated, info = env.step(action)

                root, protocol = env.decode_action(action)
                rollouts['actions'].append(action)
                rollouts['roots'].append(root)
                rollouts['protocols'].append(protocol)
                episode_reward += reward

                if terminated or truncated:
                    break

            rollouts['rewards'].append(episode_reward)
            # Fisher info is zero right after reset(), so the end-of-episode
            # (unclipped) values are the informative ones.
            rollouts['fishers'].extend(info['fisher_info'])
    finally:
        env.close()
    return rollouts


def _print_selection_frequencies(label, counts, n_options, total):
    """Print how often each of ``n_options`` IDs was chosen, as count and percentage."""
    print(f"{label} selection frequency:")
    for idx in range(n_options):
        count = counts.get(idx, 0)
        percentage = (count / total) * 100 if total else 0.0
        print(f"  {label} {idx}: {count} times ({percentage:.1f}%)")


def _plot_agent_performance(root_counts, protocol_counts, all_rewards, avg_reward,
                            theta_values, fisher_values):
    """Draw the 2x2 summary figure: selections, rewards, theta and Fisher information."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Root and protocol selection frequency on twin y-axes. Root bars sit
    #    left of each tick and protocol bars right of it, so IDs 0-2 don't overlap.
    ax1_twin = ax1.twinx()
    roots = range(3)
    protocols = range(6)
    root_bars = ax1.bar([r - 0.2 for r in roots], [root_counts.get(r, 0) for r in roots],
                        width=0.4, color='skyblue', alpha=0.7, label='Root Selection')
    protocol_bars = ax1_twin.bar([p + 0.2 for p in protocols],
                                 [protocol_counts.get(p, 0) for p in protocols],
                                 width=0.4, color='lightcoral', alpha=0.7,
                                 label='Protocol Selection')
    ax1.set_xlabel('Root / Protocol ID')
    ax1.set_ylabel('Root Selection Frequency', color='skyblue')
    ax1_twin.set_ylabel('Protocol Selection Frequency', color='lightcoral')
    ax1.set_title('Root and Protocol Selection Frequency')
    ax1.set_xticks(range(6))
    ax1.tick_params(axis='y', labelcolor='skyblue')
    ax1_twin.tick_params(axis='y', labelcolor='lightcoral')
    # The series live on two axes, so build a single legend from both handles.
    ax1.legend(handles=[root_bars, protocol_bars], loc='upper right')

    # 2. Reward distribution
    ax2.hist(all_rewards, bins=20, color='lightgreen', alpha=0.7, edgecolor='black')
    ax2.set_xlabel('Episode Reward')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Episode Reward Distribution')
    ax2.axvline(avg_reward, color='red', linestyle='--', label=f'Mean: {avg_reward:.3f}')
    ax2.legend()

    # 3. Theta values observed by the agent before each decision
    ax3.hist(theta_values, bins=20, color='orange', alpha=0.7, edgecolor='black')
    ax3.set_xlabel('Theta Values')
    ax3.set_ylabel('Frequency')
    ax3.set_title('Distribution of Theta Values Encountered')

    # 4. End-of-episode cumulative Fisher information (non-finite values dropped)
    fisher_finite = [f for f in fisher_values if np.isfinite(f)]
    if fisher_finite:
        ax4.hist(fisher_finite, bins=20, color='purple', alpha=0.7, edgecolor='black')
        ax4.set_xlabel('Cumulative Fisher Information (episode end)')
        ax4.set_ylabel('Frequency')
        ax4.set_title('Distribution of Final Cumulative Fisher Information')
        ax4.set_yscale('log')  # Log scale due to wide range

    plt.tight_layout()
    plt.show()


def analyze_agent_performance(model, num_episodes=100, max_steps=10):
    """Analyze the trained agent's root and protocol selection patterns.

    Runs ``num_episodes`` deterministic, seeded episodes, prints selection
    frequencies and the mean episode reward, and shows a summary figure.

    Args:
        model: trained agent exposing ``predict(obs, deterministic=True)``.
        num_episodes: number of evaluation episodes (must be >= 1).
        max_steps: episode length of the evaluation environment.

    Returns:
        dict with ``avg_reward``, ``action_counts``, ``root_counts``,
        ``protocol_counts`` (``collections.Counter`` objects keyed by int) and
        ``all_rewards`` (per-episode total reward).

    Raises:
        ValueError: if ``num_episodes < 1``.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    print(f"Analyzing agent performance over {num_episodes} episodes...")
    rollouts = _collect_agent_rollouts(model, num_episodes, max_steps)

    action_counts = Counter(rollouts['actions'])
    root_counts = Counter(rollouts['roots'])
    protocol_counts = Counter(rollouts['protocols'])
    all_rewards = rollouts['rewards']
    avg_reward = np.mean(all_rewards)

    print("\nPerformance Analysis:")
    print(f"Average reward per episode: {avg_reward:.3f}")
    n_decisions = len(rollouts['actions'])
    _print_selection_frequencies("Root", root_counts, 3, n_decisions)
    _print_selection_frequencies("Protocol", protocol_counts, 6, n_decisions)

    _plot_agent_performance(root_counts, protocol_counts, all_rewards, avg_reward,
                            rollouts['thetas'], rollouts['fishers'])

    return {
        'avg_reward': avg_reward,
        'action_counts': action_counts,
        'root_counts': root_counts,
        'protocol_counts': protocol_counts,
        'all_rewards': all_rewards
    }

# Analyse the PPO agent trained above over 100 seeded evaluation episodes.
print("Analyzing the trained agent...")
results = analyze_agent_performance(model=trained_model, num_episodes=100)

In [None]:
# Demonstration: Understanding the New Action Space
# Static explanation of the encoding, emitted one item per line.
print(
    "🔧 Action Space Modification Complete!",
    "=" * 50,
    "\n📊 Action Encoding System:",
    "The agent now has full control over both ROOT and PROTOCOL selection!",
    "\nTotal action space: 18 actions (3 roots × 6 protocols)",
    "Encoding: action = root * 6 + protocol",
    "Where:",
    "  • Root ∈ {0, 1, 2} (which node to measure)",
    "  • Protocol ∈ {0, 1, 2, 3, 4, 5} (which protocol to use)",
    "\n📋 Action Mapping Table:",
    "Action | Root | Protocol | Description",
    "-------|------|----------|------------",
    sep="\n",
)

# One table row per combined action. Decoding is delegated to the environment
# so the table reflects the same mapping that step() uses.
env = QuantumNetworkEnv()
for action in range(env.action_space.n):
    root, protocol = env.decode_action(action)
    print(f"  {action:2d}   |  {root}   |    {protocol}     | Root {root}, Protocol {protocol}")

print(
    "\n🎯 Benefits of This Modification:",
    "✓ Agent learns optimal root selection strategy",
    "✓ Full control over both measurement node and protocol",
    "✓ Can discover correlations between theta values and optimal roots",
    "✓ More sophisticated decision-making capability",
    "\n🔄 Before vs After:",
    "Before: Agent selects protocol (6 options), root chosen randomly",
    "After:  Agent selects root AND protocol (18 total combinations)",
    sep="\n",
)

env.close()

In [None]:
# Advanced Analysis: Compare Different RL Algorithms
from stable_baselines3 import DQN, A2C

def _episode_rewards(model, n_episodes, max_steps):
    """Return the total reward of ``n_episodes`` deterministic episodes seeded 0..n-1."""
    test_env = QuantumNetworkEnv(max_steps=max_steps)
    episode_rewards = []
    try:
        for episode in range(n_episodes):
            obs, info = test_env.reset(seed=episode)
            episode_reward = 0

            # Bounded by the env's own episode length as a safety net.
            for _ in range(max_steps):
                action, _states = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = test_env.step(action)
                episode_reward += reward
                if terminated or truncated:
                    break

            episode_rewards.append(episode_reward)
    finally:
        test_env.close()
    return episode_rewards


def train_and_compare_algorithms(total_timesteps=20000, n_eval_episodes=20,
                                 eval_max_steps=10, seed=None):
    """Train multiple RL algorithms and compare their performance.

    PPO, DQN and A2C are each trained from scratch on a fresh Monitor-wrapped
    ``QuantumNetworkEnv(max_steps=20)`` and saved as
    ``quantum_network_<name lowercased>``. Each model is then evaluated
    deterministically on ``n_eval_episodes`` seeded episodes of a separate
    ``QuantumNetworkEnv(max_steps=eval_max_steps)``.

    Args:
        total_timesteps: training budget per algorithm.
        n_eval_episodes: evaluation episodes per algorithm (must be >= 1).
        eval_max_steps: episode length of the evaluation environment.
        seed: optional seed forwarded to each algorithm's constructor
            (``None`` keeps the library default).

    Returns:
        ``(results, trained_models)``: ``results`` maps algorithm name to a dict
        with ``avg_reward``, ``std_reward`` and ``episode_rewards``;
        ``trained_models`` maps algorithm name to the trained model.

    Raises:
        ValueError: if ``n_eval_episodes < 1``.
    """
    if n_eval_episodes < 1:
        raise ValueError(f"n_eval_episodes must be >= 1, got {n_eval_episodes}")

    # Settings shared by every algorithm; algorithm-specific ones follow.
    common_kwargs = {'verbose': 1, 'device': "cpu", 'seed': seed}
    algorithm_configs = {
        'PPO': (PPO, dict(
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=64,
            n_epochs=10,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.01,
        )),
        'DQN': (DQN, dict(
            learning_rate=1e-3,
            buffer_size=10000,
            learning_starts=1000,
            batch_size=32,
            tau=1.0,
            gamma=0.99,
            train_freq=4,
            gradient_steps=1,
            target_update_interval=1000,
            exploration_fraction=0.1,
            exploration_initial_eps=1.0,
            exploration_final_eps=0.05,
        )),
        'A2C': (A2C, dict(
            learning_rate=7e-4,
            n_steps=5,
            gamma=0.99,
            gae_lambda=1.0,
            ent_coef=0.01,
            vf_coef=0.25,
            max_grad_norm=0.5,
        )),
    }

    results = {}
    trained_models = {}

    for name, (algorithm_class, hyperparameters) in algorithm_configs.items():
        print(f"\nTraining {name} algorithm...")

        # Fresh, Monitor-wrapped training environment per algorithm.
        train_env = Monitor(QuantumNetworkEnv(max_steps=20))
        try:
            model = algorithm_class("MlpPolicy", train_env, **common_kwargs, **hyperparameters)
            model.learn(total_timesteps=total_timesteps, progress_bar=True)
        finally:
            train_env.close()

        model.save(f"quantum_network_{name.lower()}")
        trained_models[name] = model

        episode_rewards = _episode_rewards(model, n_eval_episodes, eval_max_steps)
        avg_reward = np.mean(episode_rewards)
        std_reward = np.std(episode_rewards)

        results[name] = {
            'avg_reward': avg_reward,
            'std_reward': std_reward,
            'episode_rewards': episode_rewards
        }

        print(f"{name} - Average Reward: {avg_reward:.3f} ± {std_reward:.3f}")

    return results, trained_models

def plot_algorithm_comparison(results):
    """Plot comparison of different algorithms.

    Args:
        results: mapping of algorithm name to a dict with ``avg_reward``,
            ``std_reward`` and ``episode_rewards`` (as returned by
            ``train_and_compare_algorithms``).

    Draws a bar chart of mean reward (with standard-deviation error bars) next
    to a box plot of per-episode rewards, then shows the figure.
    """
    algorithms = list(results)
    avg_rewards = [results[alg]['avg_reward'] for alg in algorithms]
    std_rewards = [results[alg]['std_reward'] for alg in algorithms]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Bar plot of average rewards (matplotlib cycles the colour list if needed)
    bars = ax1.bar(algorithms, avg_rewards, yerr=std_rewards,
                   capsize=5, color=['skyblue', 'lightgreen', 'orange'], alpha=0.7)
    ax1.set_ylabel('Average Reward')
    ax1.set_title('Algorithm Performance Comparison')
    ax1.grid(axis='y', alpha=0.3)

    # Value labels placed beyond the end of each error bar; negative bars get
    # their label below, so it never overlaps the bar or its error bar.
    for bar, avg, std in zip(bars, avg_rewards, std_rewards):
        x = bar.get_x() + bar.get_width() / 2.0
        if avg >= 0:
            ax1.text(x, avg + std, f'{avg:.3f}', ha='center', va='bottom')
        else:
            ax1.text(x, avg - std, f'{avg:.3f}', ha='center', va='top')

    # Box plot of reward distributions. Tick labels are set explicitly because
    # boxplot's `labels` argument was renamed to `tick_labels` in Matplotlib 3.9.
    reward_distributions = [results[alg]['episode_rewards'] for alg in algorithms]
    ax2.boxplot(reward_distributions)
    ax2.set_xticks(list(range(1, len(algorithms) + 1)))
    ax2.set_xticklabels(algorithms)
    ax2.set_ylabel('Episode Reward')
    ax2.set_title('Reward Distribution by Algorithm')
    ax2.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

# Train the three algorithms, plot them side by side, then report the winner.
print(
    "Starting algorithm comparison...",
    "This will train PPO, DQN, and A2C algorithms and compare their performance.",
    "Training 3 algorithms for 20,000 timesteps each...",
    sep="\n",
)

comparison_results, comparison_models = train_and_compare_algorithms()
plot_algorithm_comparison(comparison_results)

print("\nAlgorithm Comparison Summary:")
for alg, result in comparison_results.items():
    print(f"{alg}: {result['avg_reward']:.3f} ± {result['std_reward']:.3f}")

# The algorithm with the highest mean evaluation reward wins.
best_algorithm = max(
    comparison_results,
    key=lambda candidate: comparison_results[candidate]['avg_reward'],
)
print(f"\nBest performing algorithm: {best_algorithm}")
print(f"Best average reward: {comparison_results[best_algorithm]['avg_reward']:.3f}")

In [None]:
# Summary and Usage Instructions

print("🎉 Quantum Network RL Training Setup Complete!")
print("=" * 60)

print("\nWhat we've built:")
print("✓ Professional Gymnasium environment for quantum network protocol selection")
print("✓ Integration with Stable-Baselines3 for state-of-the-art RL algorithms")
print("✓ Training pipeline with evaluation callbacks and monitoring")
print("✓ Comprehensive analysis and visualization tools")
print("✓ Comparison framework for multiple RL algorithms (PPO, DQN, A2C)")

print("\n📋 How to use this setup:")
print("1. Run the environment test cell to verify everything works")
print("2. Train your preferred algorithm (PPO recommended)")
print("3. Analyze the trained agent's performance with visualization")
print("4. Compare different algorithms if needed")
print("5. Use the best model for your quantum network protocol selection!")

print("\n🔧 Key Features:")
print("• Input: 3 theta values + 3 Fisher information values")
# The environment's action space is Discrete(18): root * 6 + protocol.
print("• Output: Selection of 1 of 18 actions, i.e. a (root, protocol) pair (3 roots × 6 protocols)")
print("• Reward: Sum of Fisher information values for optimal quantum sensing")
print("• Algorithms: PPO, DQN, A2C (easily extensible)")

print("\n📁 Files created:")
print("• quantum_network_ppo_final.zip - Trained PPO model")
print("• quantum_network_dqn.zip - Trained DQN model") 
print("• quantum_network_a2c.zip - Trained A2C model")
print("• Tensorboard logs in ./ppo_quantum_tensorboard/")
print("• Best model checkpoints in ./best_quantum_model/")

print("\n🚀 Next Steps:")
print("• Experiment with different hyperparameters")
print("• Extend to multi-step protocol selection")
print("• Add more sophisticated reward functions")
print("• Integrate with real quantum hardware simulations")

print("\nHappy quantum computing! 🔬⚛️")