# AI Clinician with LD2Z Scheduler

This notebook demonstrates the implementation of Q-learning for medical decision-making (AI Clinician approach) using the LD2Z learning rate scheduler.

## Background

- **AI Clinician**: A reinforcement learning approach for optimal treatment strategies in intensive care (sepsis management)
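- **LD2Z Scheduler**: A polynomially decaying learning rate, $\alpha_t \propto 1/t^{2/3}$. Any exponent $\omega \in (1/2, 1]$ satisfies the Robbins–Monro step-size conditions ($\sum_t \alpha_t = \infty$, $\sum_t \alpha_t^2 < \infty$) used in Q-learning convergence proofs. Exponents $\omega \in (1/2, 1)$ have polynomial worst-case convergence bounds, whereas the linear $1/t$ schedule's bound is exponential in $1/(1-\gamma)$ (Even-Dar & Mansour, 2003).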
- **Goal**: Analyze the statistical properties and convergence behavior of Q-learning with the LD2Z scheduler

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from q_learning_agent import QLearningAgent
from environment import SimpleMedicalMDP
from statistical_analysis import StatisticalAnalyzer
from ld2z_scheduler import LD2ZScheduler

# Set random seed for reproducibility (seeds NumPy's global RNG)
np.random.seed(42)

# Configure plotting. matplotlib 3.6 renamed the seaborn styles to
# 'seaborn-v0_8-*' (the old names were later removed), so try the new name
# first and fall back to the pre-3.6 name; if neither exists, keep the
# default style instead of crashing the cell.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
plt.rcParams['figure.figsize'] = (12, 6)

## 1. LD2Z Scheduler Analysis

First, let's visualize the LD2Z learning rate schedule and compare it with other common schedules.

In [None]:
# Create schedulers, all from the same class so the three curves are
# directly comparable: LD2Z (exponent 2/3), a constant baseline (exponent 0)
# and the classic 1/t schedule (exponent 1).
ld2z = LD2ZScheduler(initial_lr=1.0, exponent=2/3)
constant = LD2ZScheduler(initial_lr=0.1, exponent=0)  # Constant
standard = LD2ZScheduler(initial_lr=1.0, exponent=1.0)  # 1/t schedule

# Generate learning rates for steps t = 1..1000
steps = np.arange(1, 1001)
lr_ld2z = [ld2z.get_learning_rate(t) for t in steps]
# Query the constant scheduler rather than hard-coding 0.1, so the baseline
# curve shows what the scheduler class actually returns.
lr_constant = [constant.get_learning_rate(t) for t in steps]
lr_standard = [standard.get_learning_rate(t) for t in steps]

schedules = [
    ('LD2Z (1/t^(2/3))', lr_ld2z),
    ('Standard (1/t)', lr_standard),
    ('Constant (0.1)', lr_constant),
]

# Plot the same curves on a linear and a log-log scale. Assuming
# lr = initial_lr / t**exponent (TODO confirm in LD2ZScheduler), each
# schedule appears as a straight line on the log-log panel.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, scale_name in zip(axes, ('Linear', 'Log')):
    for label, rates in schedules:
        ax.plot(steps, rates, label=label, linewidth=2)
    ax.set_xlabel('Step')
    ax.set_ylabel('Learning Rate')
    ax.set_title(f'Learning Rate Schedules ({scale_name} Scale)')
    if scale_name == 'Log':
        ax.set_xscale('log')
        ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# steps starts at t=1, so list index k holds the rate at t = k + 1
print("LD2Z Schedule Properties:")
print(f"  Initial learning rate (t=1): {lr_ld2z[0]:.4f}")
print(f"  Learning rate at t=100: {lr_ld2z[99]:.4f}")
print(f"  Learning rate at t=1000: {lr_ld2z[999]:.4f}")
print("\nDecay is slower than 1/t but faster than constant")

## 2. Create Medical Decision-Making Environment

We create a simplified MDP that simulates medical interventions.

In [None]:
# Build the simulated MDP with a fixed seed so the setup is reproducible
n_states, n_actions = 50, 5
env = SimpleMedicalMDP(n_states=n_states, n_actions=n_actions, seed=42)

print(f"Environment created with {n_states} states and {n_actions} actions")
print(f"Number of terminal states: {len(env.terminal_states)}")

# Left panel: reward table R(s, a) as a heatmap (states on the y-axis,
# actions on the x-axis). Right panel: histogram of every immediate reward.
fig, (ax_heat, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

heatmap = ax_heat.imshow(env.R, aspect='auto', cmap='RdYlGn')
ax_heat.set_xlabel('Action')
ax_heat.set_ylabel('State')
ax_heat.set_title('Reward Structure R(s, a)')
plt.colorbar(heatmap, ax=ax_heat)

ax_hist.hist(env.R.flatten(), bins=30, alpha=0.7, edgecolor='black')
ax_hist.set_xlabel('Reward')
ax_hist.set_ylabel('Frequency')
ax_hist.set_title('Distribution of Immediate Rewards')
ax_hist.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Train Q-Learning Agent with LD2Z Scheduler

Now we train the agent using the LD2Z scheduler.

In [None]:
# Q-learning agent whose step size comes from the LD2Z scheduler
# (use_ld2z=True); the loop below queries that scheduler with
# visit_counts[state, action].
agent_ld2z = QLearningAgent(
    n_states=n_states,
    n_actions=n_actions,
    gamma=0.99,
    initial_lr=1.0,
    use_ld2z=True,
)

# Receives Q snapshots, TD errors, rewards and learning rates at every step
analyzer_ld2z = StatisticalAnalyzer()

# Training hyperparameters (the constant-LR run reuses these)
n_episodes = 1000
max_steps_per_episode = 100
epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995

print("Training Q-Learning Agent with LD2Z Scheduler...")
print(f"Episodes: {n_episodes}, Max Steps: {max_steps_per_episode}")
print("-" * 60)

epsilon = epsilon_start
episode_rewards_ld2z = []

for episode_idx in range(n_episodes):
    state = env.reset()
    total_reward = 0

    for _ in range(max_steps_per_episode):
        # Epsilon-greedy action, then one environment transition
        action = agent_ld2z.select_action(state, epsilon)
        next_state, reward, done, _info = env.step(action)

        # Q-learning update; the returned TD error is logged below
        td_error = agent_ld2z.update(state, action, reward, next_state, done)

        # The learning rate is looked up from the visit count *after* update().
        # NOTE(review): confirm this is the step size update() actually used.
        visits = agent_ld2z.visit_counts[state, action]
        lr = agent_ld2z.scheduler.get_learning_rate(visits)
        analyzer_ld2z.record_step(agent_ld2z.Q, td_error, reward, lr)

        total_reward += reward
        state = next_state
        if done:
            break

    episode_rewards_ld2z.append(total_reward)
    # Multiplicative epsilon decay, floored at epsilon_end
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    finished = episode_idx + 1
    if finished % 100 == 0:
        avg_reward = np.mean(episode_rewards_ld2z[-100:])
        print(f"Episode {finished}: Avg Reward = {avg_reward:.2f}, Epsilon = {epsilon:.3f}")

print("-" * 60)
print("Training completed!")

## 4. Train Agent with Constant Learning Rate (for comparison)

In [None]:
# Rebuild the environment with the same seed so this run presumably starts
# from the same MDP as the LD2Z run.
env = SimpleMedicalMDP(n_states=n_states, n_actions=n_actions, seed=42)

# Baseline agent: fixed learning rate (use_ld2z=False)
agent_const = QLearningAgent(
    n_states=n_states,
    n_actions=n_actions,
    gamma=0.99,
    initial_lr=0.1,
    use_ld2z=False,
)

analyzer_const = StatisticalAnalyzer()

print("Training Q-Learning Agent with Constant Learning Rate...")
print("-" * 60)

epsilon = epsilon_start
episode_rewards_const = []

for episode_idx in range(n_episodes):
    state = env.reset()
    total_reward = 0

    for _ in range(max_steps_per_episode):
        action = agent_const.select_action(state, epsilon)
        next_state, reward, done, _info = env.step(action)
        td_error = agent_const.update(state, action, reward, next_state, done)

        # The constant agent logs its fixed rate at every step
        analyzer_const.record_step(agent_const.Q, td_error, reward, agent_const.constant_lr)

        total_reward += reward
        state = next_state
        if done:
            break

    episode_rewards_const.append(total_reward)
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    finished = episode_idx + 1
    if finished % 100 == 0:
        avg_reward = np.mean(episode_rewards_const[-100:])
        print(f"Episode {finished}: Avg Reward = {avg_reward:.2f}")

print("-" * 60)
print("Training completed!")

## 5. Statistical Analysis and Comparison

Now we analyze and compare the statistical properties of both approaches.

In [None]:
# Print each analyzer's text report under a shared banner
banner = "=" * 60
print("\n" + banner)
print("STATISTICAL ANALYSIS RESULTS")
print(banner)

for heading, run_analyzer in (
    ("\n1. LD2Z Scheduler:", analyzer_ld2z),
    ("\n2. Constant Learning Rate:", analyzer_const),
):
    print(heading)
    print(run_analyzer.generate_report())

### 5.1 Convergence Comparison

In [None]:
def _trailing_moving_average(values, window):
    """Return ``(x, averages)`` for a trailing moving average of *values*.

    ``x`` is the 1-based index of the last element in each window, so the
    smoothed curve lines up with the raw episode numbers. Plotting
    ``np.convolve(..., mode='valid')`` from x=0 would shift it left by
    ``window - 1``. The window is clipped to ``len(values)``, and empty
    input yields empty arrays.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return np.array([], dtype=int), np.array([], dtype=float)
    window = max(1, min(int(window), values.size))
    averages = np.convolve(values, np.ones(window) / window, mode='valid')
    return np.arange(window, values.size + 1), averages


# Plot convergence comparison
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# TD error comparison. update() may return signed TD errors, and a log axis
# silently drops zero/negative points, so plot the magnitude instead.
axes[0, 0].plot(np.abs(analyzer_ld2z.history['td_errors']), alpha=0.6, label='LD2Z')
axes[0, 0].plot(np.abs(analyzer_const.history['td_errors']), alpha=0.6, label='Constant LR')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('|TD Error|')
axes[0, 0].set_title('TD Error Convergence')
axes[0, 0].set_yscale('log')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Episode rewards comparison (trailing moving average, aligned to episodes)
window = 50
episodes_ld2z_ma, rewards_ld2z_ma = _trailing_moving_average(episode_rewards_ld2z, window)
episodes_const_ma, rewards_const_ma = _trailing_moving_average(episode_rewards_const, window)

axes[0, 1].plot(episodes_ld2z_ma, rewards_ld2z_ma, label='LD2Z', linewidth=2)
axes[0, 1].plot(episodes_const_ma, rewards_const_ma, label='Constant LR', linewidth=2)
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Reward (Moving Average)')
axes[0, 1].set_title('Training Rewards')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Value function changes, also on a log axis. abs() is a no-op if the
# analyzer already stores magnitudes; otherwise it keeps the points visible.
axes[1, 0].plot(np.abs(analyzer_ld2z.history['value_differences']), alpha=0.6, label='LD2Z')
axes[1, 0].plot(np.abs(analyzer_const.history['value_differences']), alpha=0.6, label='Constant LR')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('|Value Function Change|')
axes[1, 0].set_title('Value Function Convergence')
axes[1, 0].set_yscale('log')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Q-value distributions after training
axes[1, 1].hist(agent_ld2z.Q.flatten(), bins=30, alpha=0.6, label='LD2Z', edgecolor='black')
axes[1, 1].hist(agent_const.Q.flatten(), bins=30, alpha=0.6, label='Constant LR', edgecolor='black')
axes[1, 1].set_xlabel('Q-value')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title('Final Q-value Distribution')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 5.2 Detailed Convergence Plots

In [None]:
def _show_convergence(title, run_analyzer):
    """Print *title*, draw the analyzer's convergence figure, show it, and return it."""
    print(title)
    figure = run_analyzer.plot_convergence()
    plt.show()
    return figure


# Detailed convergence figures for each run, LD2Z first
fig_ld2z = _show_convergence("LD2Z Scheduler - Detailed Convergence Analysis:", analyzer_ld2z)
fig_const = _show_convergence("\nConstant Learning Rate - Detailed Convergence Analysis:", analyzer_const)

## 6. Policy Analysis

Compare the learned policies.

In [None]:
# Greedy policies: one action index per state. np.asarray makes the
# agreement check below element-wise even if get_policy() returns a list
# (list == list yields a single bool, not per-state matches).
policy_ld2z = np.asarray(agent_ld2z.get_policy())
policy_const = np.asarray(agent_const.get_policy())

# State-value estimates (presumably max_a Q(s, a) - confirm in QLearningAgent)
values_ld2z = agent_ld2z.get_value_function()
values_const = agent_const.get_value_function()

# Visualize policies (top row) and value functions (bottom row);
# column 0 is the LD2Z run, column 1 the constant-LR run.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
state_ids = np.arange(n_states)

runs = [
    ('LD2Z', policy_ld2z, values_ld2z, 'C0'),
    ('Constant LR', policy_const, values_const, 'orange'),
]
for col, (run_name, run_policy, run_values, run_color) in enumerate(runs):
    policy_ax, value_ax = axes[0, col], axes[1, col]

    # Actions are categorical labels, so draw one marker per state. A bar
    # chart anchored at 0 renders action 0 as an invisible zero-height bar.
    policy_ax.scatter(state_ids, run_policy, alpha=0.7, color=run_color)
    policy_ax.set_xlabel('State')
    policy_ax.set_ylabel('Action')
    policy_ax.set_title(f'{run_name} Learned Policy')
    policy_ax.set_yticks(range(n_actions))
    policy_ax.set_ylim([-0.5, n_actions - 0.5])
    policy_ax.grid(True, alpha=0.3)

    value_ax.plot(run_values, marker='o', alpha=0.7, color=run_color)
    value_ax.set_xlabel('State')
    value_ax.set_ylabel('Value')
    value_ax.set_title(f'{run_name} Value Function')
    value_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Fraction of states where both runs choose the same greedy action
policy_agreement = np.mean(policy_ld2z == policy_const)
print(f"\nPolicy Agreement: {policy_agreement * 100:.2f}%")
print(f"Mean value (LD2Z): {np.mean(values_ld2z):.4f}")
print(f"Mean value (Constant): {np.mean(values_const):.4f}")

## 7. Statistical Properties Summary

Key findings from the analysis:

In [None]:
# Scalar convergence metrics computed by each run's StatisticalAnalyzer
metrics_ld2z = analyzer_ld2z.compute_convergence_metrics()
metrics_const = analyzer_const.compute_convergence_metrics()


def _metric(metrics, key):
    """Return ``metrics[key]``, or NaN if the analyzer did not produce it.

    NaN formats as ``nan``, so a missing metric stays visible in the
    printout. The original default of 0 looked like a perfect score.
    """
    return metrics.get(key, float('nan'))


print("\n" + "=" * 60)
print("SUMMARY OF STATISTICAL PROPERTIES")
print("=" * 60)

print("\n1. Convergence Speed:")
print(f"   LD2Z: {_metric(metrics_ld2z, 'convergence_ratio'):.4f}")
print(f"   Constant: {_metric(metrics_const, 'convergence_ratio'):.4f}")
print("   (Lower is better - indicates faster convergence)")

print("\n2. Final TD Error:")
print(f"   LD2Z: {_metric(metrics_ld2z, 'mean_recent_td_error'):.6f}")
print(f"   Constant: {_metric(metrics_const, 'mean_recent_td_error'):.6f}")
print("   (Lower is better)")

print("\n3. Reward Performance:")
print(f"   LD2Z: {_metric(metrics_ld2z, 'mean_reward'):.4f} ± {_metric(metrics_ld2z, 'std_reward'):.4f}")
print(f"   Constant: {_metric(metrics_const, 'mean_reward'):.4f} ± {_metric(metrics_const, 'std_reward'):.4f}")

print("\n4. Value Function Stability:")
print(f"   LD2Z: {_metric(metrics_ld2z, 'mean_value_change'):.6f}")
print(f"   Constant: {_metric(metrics_const, 'mean_value_change'):.6f}")
print("   (Lower indicates more stable convergence)")

# These notes are printed unconditionally, so they state only what holds
# regardless of the numbers above. Exploration is controlled by epsilon,
# not by the learning rate.
print("\n" + "=" * 60)
print("\nKey Insights:")
print("- LD2Z scheduler provides adaptive learning rates")
print("- Learning rate decays as 1/t^(2/3): large early steps, then averaging of noisy TD targets")
print("- Exponent 2/3 lies in (1/2, 1), satisfying the Robbins-Monro step-size conditions for Q-learning")
print("- Benefits for sparse-reward medical tasks must be judged from the metrics above (single simulated run)")
print("=" * 60)

## 8. Conclusion

This notebook demonstrated:
1. Implementation of Q-learning with LD2Z scheduler for AI Clinician application
2. Comparison with constant learning rate baseline
3. Statistical analysis of convergence properties
4. Visualization of learning dynamics and policy quality

Based on this single run on a small simulated MDP, the LD2Z scheduler may be worth evaluating further for medical reinforcement learning applications where:
- Convergence guarantees (Robbins–Monro step-size conditions) are important
- State-action spaces can be large
- Sample efficiency is critical