# Practical Task: Implementing RRT for 2D Navigation
In this task, you'll implement parts of the Rapidly-Exploring Random Tree (RRT) algorithm to navigate a point robot from a start position to a goal position in a 2D environment with obstacles.

The task is divided into four steps. By the end, you will have a basic RRT planner.

## Problem Setup
- A 2D environment with circular obstacles
- A start position and a goal position
- Your task: find a path from start to goal using RRT


### Step 1: Define the Environment
The environment is a 100-by-100 grid containing four circular obstacles, each defined by a center and a radius. Run the cells below to define and visualize it.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [None]:
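# Environment setup
grid_size = 100  # the workspace spans [0, grid_size] on both axes
# Circular obstacles, each given as ((center_x, center_y), radius)
obstacles = [
    ((20, 20), 10),
    ((60, 60), 15),
    ((40, 80), 10),
    ((80, 30), 10)
]
start_pos = (5, 5)
goal_pos = (90, 90)
step_size = 5  # maximum distance the tree grows in one extension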

In [None]:
# Plot setup
fig, ax = plt.subplots(figsize=(6, 6))
plt.xlim(0, grid_size)
plt.ylim(0, grid_size)

# Plot obstacles
for (center, radius) in obstacles:
    circle = plt.Circle(center, radius, color='gray')
    ax.add_patch(circle)

# Plot start and goal positions
ax.plot(start_pos[0], start_pos[1], 'go', label='Start')
ax.plot(goal_pos[0], goal_pos[1], 'ro', label='Goal')
ax.legend()
plt.show()

### Step 2: Define Key Functions for RRT
1. **Sampling function** - Generates a random point within the grid.
2. **Distance function** - Computes the Euclidean distance between two points.
3. **Collision check** - Returns `True` if a point lies inside any obstacle. (Already implemented; it relies on your distance function.)


In [None]:
# Sampling function
def sample_point():
    # TODO: return a random point on the grid (integer coordinates)
    return 10, 10

# Distance function
def distance(p1, p2):
    # TODO: compute the Euclidean distance between two points
    return 0.

# Collision check
def is_collision(point):
    for (center, radius) in obstacles:
        if distance(point, center) < radius:
            return True
    return False
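If you want to check your work, here is one possible way to fill in the two stubs. It is a sketch that assumes uniform sampling of integer coordinates in the closed range `[0, grid_size]`; the task does not say whether the upper bound should be inclusive. It uses only the standard library (`random.randint`, `math.hypot`).

```python
import math
import random

grid_size = 100  # same value as in the environment cell

def sample_point():
    """Return a uniformly random integer point inside the grid."""
    # random.randint includes both endpoints, so each coordinate lies in [0, grid_size]
    return random.randint(0, grid_size), random.randint(0, grid_size)

def distance(p1, p2):
    """Euclidean distance between two 2D points."""
    return math.hypot(p2[0] - p1[0], p2[1] - p1[1])

print(distance((0, 0), (3, 4)))  # 5.0
```

Note that with the original `distance` stub, which always returns `0.`, `is_collision` reports every point as colliding, because `0 < radius` holds for every obstacle. The tree cannot grow until `distance` is implemented.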

### Step 3: Implement RRT Expansion
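The RRT algorithm grows a tree by sampling random points and extending the nearest node in the tree toward each sample by at most one step size. The tree is stored as a dictionary that maps each node to its parent; the root (the start position) maps to `None`.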

In this step, complete the `extend` function, which should:
- Find the node in the tree closest to the sample
- Create a new node by moving from that node toward the sample by at most `step_size` (5 units)
- Discard the new node (return `None`) if it lies inside an obstacle
- Otherwise, add it to `tree` with the nearest node as its parent, and return it


In [None]:
# Complete this function, which grows the tree by one node
def extend(tree, sample):
    # Find the nearest node in the tree
    nearest_node = min(tree.keys(), key=lambda node: distance(node, sample))
    # TODO: finish the extend function

    return None
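A minimal sketch of one way to complete `extend`, assuming nodes may have float coordinates and that `tree` is the child-to-parent dictionary created in Step 4. It repeats the environment values and helper functions so it runs on its own. The zero-distance guard and the duplicate check (`new_node in tree`) are additions not required by the task description; they keep a node from being re-added with a different parent.

```python
import math

# Values repeated from the environment cell
step_size = 5
obstacles = [((20, 20), 10), ((60, 60), 15), ((40, 80), 10), ((80, 30), 10)]

def distance(p1, p2):
    return math.hypot(p2[0] - p1[0], p2[1] - p1[1])

def is_collision(point):
    return any(distance(point, center) < radius for center, radius in obstacles)

def extend(tree, sample):
    """Grow `tree` (a child -> parent dict) one step toward `sample`.

    Returns the new node, or None if it collides or already exists.
    """
    nearest_node = min(tree.keys(), key=lambda node: distance(node, sample))
    d = distance(nearest_node, sample)
    if d == 0:
        return None  # the sample coincides with an existing node
    # Fraction of the way to the sample: a full step, or all the way if closer
    t = min(step_size, d) / d
    new_node = (nearest_node[0] + t * (sample[0] - nearest_node[0]),
                nearest_node[1] + t * (sample[1] - nearest_node[1]))
    if is_collision(new_node) or new_node in tree:
        return None
    tree[new_node] = nearest_node  # record the parent for later backtracking
    return new_node

tree = {(5, 5): None}
print(extend(tree, (5, 50)))  # one step of length 5 straight up from (5, 5)
```

Scaling by `min(step_size, d) / d` means a sample closer than `step_size` is reached exactly, while a farther sample is approached by one full step along the straight line toward it.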

### Step 4: Run the RRT
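In this step, we run RRT to find a path. Each animation frame draws one sample and tries to extend the tree. Once a new node lands within `10` units of the goal, the goal is connected to that node, and the path is traced back through the parent pointers. The animation runs for at most 200 frames, so the goal may not be reached on every run.

Run the cells below to animate the search and display the result.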

In [None]:
# Initialize the RRT tree with the starting position
tree = {start_pos: None}
goal_reached = False

path = []
print(f"Tree Nodes: {len(tree)}")  # Only for debugging

In [None]:
lines = []
def update(frame):
    global goal_reached, path
    
    if goal_reached:
        return lines  # Stop updating once the goal is reached

    # Extend tree and add a new node
    sample = sample_point()
    new_node = extend(tree, sample)

    if new_node is not None:
        parent = tree[new_node]
        line, = ax.plot([new_node[0], parent[0]], [new_node[1], parent[1]], 'b.-')
        lines.append(line)
        
        # Check if goal is reached
        if distance(new_node, goal_pos) < 10:
            if new_node != goal_pos:  # avoid making the goal its own parent
                tree[goal_pos] = new_node
            goal_reached = True
            
            # Backtrack to find the path
            node = goal_pos
            while node is not None:
                path.append(node)
                node = tree[node]
            path = path[::-1]  # Reverse the path for correct order
            
            # Plot the final path
            path_x, path_y = zip(*path)
            ax.plot(path_x, path_y, 'r-', linewidth=2, label='Path to Goal')
            ax.legend()

    return lines
    
def clear_plot():
    ax.clear()  # Clear previous frames (this also resets the axis limits)
    ax.set_xlim(0, grid_size)
    ax.set_ylim(0, grid_size)
    # Replot obstacles, start and goal
    for (center, radius) in obstacles:
        circle = plt.Circle(center, radius, color='gray')
        ax.add_patch(circle)
    ax.plot(start_pos[0], start_pos[1], 'go', label='Start')
    ax.plot(goal_pos[0], goal_pos[1], 'ro', label='Goal')
    return lines
    
# Create animation
ani = FuncAnimation(fig, update, init_func=clear_plot, frames=200, blit=True, repeat=False)

In [None]:
from IPython.display import HTML
lines = []
HTML(ani.to_jshtml())
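The animation makes it hard to test the planner in isolation. The sketch below runs the same loop without plotting: it samples, extends, stops once a new node lies within 10 units of the goal, and backtracks through the parent dictionary. The names `run_rrt`, `backtrack`, `goal_tolerance`, `max_iters`, and `seed` are hypothetical and introduced only here; the helper functions repeat the earlier sketches so the block is self-contained.

```python
import math
import random

# Environment repeated from the setup cell
grid_size = 100
obstacles = [((20, 20), 10), ((60, 60), 15), ((40, 80), 10), ((80, 30), 10)]
start_pos, goal_pos = (5, 5), (90, 90)
step_size = 5
goal_tolerance = 10  # hypothetical name for the "within 10 units" threshold

def distance(p1, p2):
    return math.hypot(p2[0] - p1[0], p2[1] - p1[1])

def is_collision(point):
    return any(distance(point, center) < radius for center, radius in obstacles)

def extend(tree, sample):
    nearest = min(tree, key=lambda node: distance(node, sample))
    d = distance(nearest, sample)
    if d == 0:
        return None
    t = min(step_size, d) / d
    new = (nearest[0] + t * (sample[0] - nearest[0]),
           nearest[1] + t * (sample[1] - nearest[1]))
    if is_collision(new) or new in tree:
        return None
    tree[new] = nearest
    return new

def backtrack(tree, node):
    """Follow parent pointers from `node` to the root, then reverse the list."""
    path = []
    while node is not None:
        path.append(node)
        node = tree[node]
    return path[::-1]

def run_rrt(max_iters=5000, seed=0):
    random.seed(seed)  # fixed seed so runs are reproducible
    tree = {start_pos: None}
    for _ in range(max_iters):
        sample = (random.randint(0, grid_size), random.randint(0, grid_size))
        new_node = extend(tree, sample)
        if new_node is not None and distance(new_node, goal_pos) < goal_tolerance:
            if new_node != goal_pos:  # avoid a self-loop if a node lands on the goal
                tree[goal_pos] = new_node
            return tree, backtrack(tree, goal_pos)
    return tree, None  # goal not reached within max_iters

tree, path = run_rrt()
if path is None:
    print("Goal not reached; try more iterations or another seed")
else:
    print(f"Path with {len(path)} nodes, tree size {len(tree)}")
```

Like the animation cell, this sketch checks only the nodes for collisions. Neither the segments between nodes nor the final connection to the goal are collision-checked.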