# Policy Iteration Algorithm Demo

Welcome to the interactive Policy Iteration demonstration! This tool allows you to explore how the Policy Iteration algorithm works in a GridWorld environment.

## How to Use This Demo

1. **Use the slider below** to control the number of policy iteration steps
2. **Watch the policy evolve** as you increase the number of iterations
3. **Observe convergence** - notice how the policy stabilizes after several iterations

## What You're Seeing

- **GridWorld**: A 4×3 grid where an agent can move up, down, left, or right
- **Terminal States**: A green cell (+1 reward) and a red cell (-1 reward) in the top-right area; entering either one ends the episode
- **Blocked State**: Gray cell that the agent cannot enter
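- **Policy**: Arrows in each cell show the action the current policy chooses in that state (this is the optimal action only once the algorithm has converged)
- **Convergence**: The algorithm stops when a policy improvement step leaves every action unchanged; for a finite GridWorld like this one, that happens after a finite number of iterations, and the resulting policy is optimal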
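The environment described above can be written down as a small transition model. The sketch below is illustrative only and is not the project's `GridWorld` class: the coordinate system, the position of the blocked cell at `(1, 1)`, deterministic moves, and zero reward outside the two terminal cells are all assumptions, loosely modelled on the classic 4×3 textbook grid.

```python
# Hypothetical 4x3 GridWorld model (assumed layout, not the project's GridWorld):
# x is the column 0-3, y is the row 0-2 with y=2 as the top row.
WIDTH, HEIGHT = 4, 3
BLOCKED = (1, 1)                          # assumed position of the gray cell
TERMINALS = {(3, 2): 1.0, (3, 1): -1.0}   # assumed green (+1) and red (-1) cells
ACTIONS = {"UP": (0, 1), "DOWN": (0, -1), "LEFT": (-1, 0), "RIGHT": (1, 0)}

# Every cell except the blocked one is a state
STATES = [(x, y) for x in range(WIDTH) for y in range(HEIGHT) if (x, y) != BLOCKED]


def step(state, action):
    """Deterministic transition (an assumption): returns (next_state, reward, done).

    Moving off the grid or into the blocked cell leaves the agent in place.
    Entering a terminal cell yields that cell's reward; every other move yields 0.
    """
    if state in TERMINALS:
        return state, 0.0, True
    dx, dy = ACTIONS[action]
    nx, ny = state[0] + dx, state[1] + dy
    if not (0 <= nx < WIDTH and 0 <= ny < HEIGHT) or (nx, ny) == BLOCKED:
        nx, ny = state
    next_state = (nx, ny)
    reward = TERMINALS.get(next_state, 0.0)
    return next_state, reward, next_state in TERMINALS


print(len(STATES))            # 11
print(step((2, 2), "RIGHT"))  # ((3, 2), 1.0, True)
print(step((0, 0), "LEFT"))   # ((0, 0), 0.0, False): bumped into the wall
print(step((1, 0), "UP"))     # ((1, 0), 0.0, False): blocked cell above
```

Bumping into a wall or the blocked cell is modelled as staying put, which keeps every action legal in every state and lets the policy be a plain mapping from state to action.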

## Key Concepts to Observe

- How the **policy changes** with each iteration
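- The **alternating steps** of policy evaluation (computing the value of every state under the current policy) and policy improvement (switching each state to the action that looks best under those values)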
- **Convergence behavior** - when does the policy stop changing?
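- **Difference from Value Iteration** - policy iteration fully evaluates a fixed policy before improving it, whereas value iteration applies Bellman optimality backups directly to the state values and extracts a policy only at the end; policy iteration typically needs fewer iterations, but each one costs more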
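The alternating evaluate/improve loop, and its contrast with value iteration, can be sketched in plain Python. This is a minimal sketch and not the project's `PolicyIteration` class: it assumes a deterministic 4×3 grid with the blocked cell at `(1, 1)`, terminals at `(3, 2)` (+1) and `(3, 1)` (-1), zero reward elsewhere, and a discount factor `GAMMA = 0.9`. Like the demo, it starts from a policy that moves LEFT everywhere.

```python
# Hypothetical deterministic 4x3 GridWorld (assumed layout, rewards and discount)
WIDTH, HEIGHT, GAMMA = 4, 3, 0.9
BLOCKED = (1, 1)
TERMINALS = {(3, 2): 1.0, (3, 1): -1.0}
ACTIONS = {"UP": (0, 1), "DOWN": (0, -1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
STATES = [(x, y) for x in range(WIDTH) for y in range(HEIGHT) if (x, y) != BLOCKED]


def step(s, a):
    """Deterministic move; walls and the blocked cell leave the agent in place."""
    nx, ny = s[0] + ACTIONS[a][0], s[1] + ACTIONS[a][1]
    if not (0 <= nx < WIDTH and 0 <= ny < HEIGHT) or (nx, ny) == BLOCKED:
        nx, ny = s
    return (nx, ny), TERMINALS.get((nx, ny), 0.0)


def q_value(V, s, a):
    """One-step lookahead: immediate reward plus discounted next-state value."""
    s2, r = step(s, a)
    return r + GAMMA * V[s2]


def evaluate(policy, theta=1e-8):
    """Policy evaluation: sweep V(s) <- Q(s, policy[s]) until values stop changing."""
    V = {s: 0.0 for s in STATES}  # terminal values stay 0 (reward is paid on entry)
    while True:
        delta = 0.0
        for s in STATES:
            if s in TERMINALS:
                continue
            new = q_value(V, s, policy[s])
            delta = max(delta, abs(new - V[s]))
            V[s] = new
        if delta < theta:
            return V


def greedy(V):
    """Policy improvement: in each state pick the action with the best lookahead."""
    return {s: max(ACTIONS, key=lambda a: q_value(V, s, a))
            for s in STATES if s not in TERMINALS}


def policy_iteration(max_iterations=100):
    # Start with LEFT everywhere, mirroring default_action=LEFT in the demo
    policy = {s: "LEFT" for s in STATES if s not in TERMINALS}
    for i in range(1, max_iterations + 1):
        new_policy = greedy(evaluate(policy))  # evaluate, then improve
        if new_policy == policy:               # no action changed: policy is stable
            return policy, i                   # i includes the confirming round
        policy = new_policy
    return policy, max_iterations


def value_iteration(theta=1e-8):
    V = {s: 0.0 for s in STATES}
    while True:
        delta = 0.0
        for s in STATES:
            if s in TERMINALS:
                continue
            new = max(q_value(V, s, a) for a in ACTIONS)  # Bellman optimality backup
            delta = max(delta, abs(new - V[s]))
            V[s] = new
        if delta < theta:
            return greedy(V)  # a policy is extracted only at the end


pi_policy, n = policy_iteration()
vi_policy = value_iteration()
print("policy iteration rounds:", n)
print(pi_policy == vi_policy)               # True
print(pi_policy[(0, 2)], pi_policy[(2, 1)])  # RIGHT UP
```

Both methods end at the same greedy policy here; the difference is that policy iteration only changes the policy after a full evaluation, while value iteration never stores an explicit policy until it stops.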

from gridworld import GridWorld
from policy_iteration import PolicyIteration
from tabular_policy import TabularPolicy
import ipywidgets as widgets
from IPython.display import display, clear_output


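# Create the GridWorld environment
gridworld = GridWorld()

# Create an interactive slider widget; continuous_update=False reruns the
# computation only when the slider is released, not on every drag step
iterations_slider = widgets.IntSlider(value=0, min=0, max=100, step=1,
                                      description='Iterations:', continuous_update=False)
display(iterations_slider)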

# Create an output widget for displaying the plot
output_widget = widgets.Output()
display(output_widget)

# Run policy iteration for the given number of steps and plot the resulting policy
def perform_policy_iteration_and_display(iterations):
    # Start from a fresh policy (LEFT in every state) so each plot shows at most
    # `iterations` steps of policy iteration from the same starting point
    policy = TabularPolicy(default_action=gridworld.LEFT)
    PolicyIteration(gridworld, policy).policy_iteration(max_iterations=iterations)

    # Clear the previous plot and display the new one
    with output_widget:
        clear_output(wait=True)
        gridworld.visualise_policy(policy, f"Policy after {iterations} iterations")

# Define a callback function to update the visualization based on the slider value
def slider_callback(change):
    iterations = change.new
    perform_policy_iteration_and_display(iterations)

# Connect the slider to the callback function
iterations_slider.observe(slider_callback, 'value')

# Initially perform policy iteration with 0 iterations
perform_policy_iteration_and_display(0)
