# Trajectory Optimization

This notebook demonstrates **trajectory optimization** using PyTorch's automatic differentiation and gradient-descent tooling.

Rather than training a neural network on thousands of samples, it uses an optimizer to **deform** an existing line (the street) until it satisfies our rules: avoid collisions, stay smooth, and stay close to the original.

## 1. The General Concept

Trajectory optimization refines an initial route (a candidate trajectory) so that it satisfies a set of constraints and objectives. This involves:
*   **Obstacle Avoidance:** The optimized trajectory must keep a safe distance from forbidden zones and obstacles in the environment, enforced through a buffer distance.
*   **Trajectory Smoothness:** Curvature and abrupt changes of direction are minimized, which makes the trajectory more efficient and easier to drive.
*   **Anchor Point Constraints:** The start and end points of the trajectory should stay fixed, preserving connectivity with the predefined origins and destinations. In the code below this is enforced with a penalty on the first two and last two points rather than as a hard constraint.

The algorithm works iteratively, adjusting the positions of the trajectory's points a little at each step. A gradient-based optimizer (Adam, in the code below) minimizes a cost function that penalizes constraint violations and rewards smoothness, until the trajectory settles into a low-cost configuration.
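
The iterative loop described above can be sketched in a few lines of plain Python. This is an illustration under simplifying assumptions, not the notebook's implementation: finite-difference gradients stand in for PyTorch autograd, plain gradient descent stands in for Adam, and the names `cost`, `numeric_grad`, and `descend`, the weights, and the single obstacle are all hypothetical choices made for the example.

```python
def cost(xs, ys, y_orig, obstacle, radius):
    """Toy cost: repulsion from one obstacle + smoothness + fidelity."""
    ox, oy = obstacle
    # Repulsion: penalize points that lie inside the safety radius.
    repulsion = sum(max(0.0, radius**2 - ((x - ox)**2 + (y - oy)**2))
                    for x, y in zip(xs, ys))
    # Smoothness: squared second differences (zero for a straight line).
    smooth = sum((ys[i - 1] - 2 * ys[i] + ys[i + 1])**2
                 for i in range(1, len(ys) - 1))
    # Fidelity: stay close to the original path.
    fidelity = sum((y - yo)**2 for y, yo in zip(ys, y_orig))
    return repulsion + 5.0 * smooth + 0.01 * fidelity  # assumed weights


def numeric_grad(f, ys, i, eps=1e-5):
    """Central finite difference of f with respect to ys[i]."""
    up = ys[:i] + [ys[i] + eps] + ys[i + 1:]
    down = ys[:i] + [ys[i] - eps] + ys[i + 1:]
    return (f(up) - f(down)) / (2 * eps)


def descend(xs, y_orig, obstacle, radius, lr=0.005, iterations=300):
    """Plain gradient descent on the interior Y coordinates only."""
    ys = list(y_orig)

    def f(candidate):
        return cost(xs, candidate, y_orig, obstacle, radius)

    interior = range(1, len(ys) - 1)  # endpoints are never updated
    for _ in range(iterations):
        grads = [numeric_grad(f, ys, i) for i in interior]
        for i, g in zip(interior, grads):
            ys[i] -= lr * g
    return ys


xs = [float(i) for i in range(21)]   # street along the X axis
y0 = [0.0] * 21                      # original path: a straight line
obstacle, radius = (10.0, 1.0), 3.0  # obstacle slightly above the street

y_fit = descend(xs, y0, obstacle, radius)
print("cost before:", cost(xs, y0, y0, obstacle, radius))
print("cost after: ", cost(xs, y_fit, y0, obstacle, radius))
print("lowest point of the fitted path:", min(y_fit))
```

The obstacle is placed slightly off the street on purpose: if it sat exactly on the line, the Y gradient of the repulsion term would vanish by symmetry and the path would have no preferred side to escape to.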

## 2. Data Preparation (PCA)

Before optimizing, the code simplifies the problem by rotating the map.
Streets can run in any direction, and diagonal lines are awkward to work with mathematically.
The code uses **PCA (Principal Component Analysis)** to find the street's dominant direction and rotates the whole scene so that the street lies along the X axis.
The optimizer then only has to modify the `Y` coordinate of each point, which halves the number of free variables.
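
As a rough illustration of this rotation step, the sketch below finds the dominant direction of a 2D point set and rotates the points into a frame where the street is horizontal. It is an assumption-laden stand-in: it uses the closed-form orientation of a 2x2 covariance matrix instead of scikit-learn's `PCA`, and the helper names `principal_angle` and `to_street_frame` are hypothetical. The closed form returns an angle in (-90°, 90°], whereas the sign of a PCA component is not fixed, so the two may differ by 180°.

```python
import math


def principal_angle(xs, ys):
    """Centroid and angle of the dominant direction of a 2D point set."""
    n = len(xs)
    cx, cy = sum(xs) / n, sum(ys) / n
    sxx = sum((x - cx)**2 for x in xs)
    syy = sum((y - cy)**2 for y in ys)
    sxy = sum((x - cx) * (y - cy) for x, y in zip(xs, ys))
    # Closed-form orientation of the first principal axis (2D only).
    angle = 0.5 * math.atan2(2 * sxy, sxx - syy)
    return cx, cy, angle


def to_street_frame(xs, ys):
    """Translate to the centroid, then rotate by -angle."""
    cx, cy, angle = principal_angle(xs, ys)
    c, s = math.cos(-angle), math.sin(-angle)
    out_x, out_y = [], []
    for x, y in zip(xs, ys):
        dx, dy = x - cx, y - cy
        out_x.append(c * dx - s * dy)
        out_y.append(s * dx + c * dy)
    return out_x, out_y, angle


# A street running at 30 degrees to the X axis.
t = [float(i) for i in range(11)]
xs = [ti * math.cos(math.radians(30)) for ti in t]
ys = [ti * math.sin(math.radians(30)) for ti in t]

rx, ry, angle = to_street_frame(xs, ys)
print("detected angle (deg):", math.degrees(angle))
print("largest |y| after rotation:", max(abs(v) for v in ry))
```

After the rotation every point has a Y value that is zero up to rounding, so any later displacement in Y is a displacement perpendicular to the street.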

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import math
import random
from sklearn.decomposition import PCA
import shapely
import geopandas as gpd
from shapely.geometry import LineString, Point, Polygon
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
import traceback
warnings.filterwarnings('ignore')

print(f'Shapely version: {shapely.__version__}')
print(f'Torch version: {torch.__version__}')

Shapely version: 2.0.0
Torch version: 2.4.1+cu121


## 3. Street Fitting Logic

The following code holds the core logic, organized into classes:

### A. `GeometryUtils`
Static utility class that handles all geometric transformations (rotation, translation, PCA, segmentation).

### B. `PathOptimizer`
Main class that drives the optimization.
*   **Initialization**: Receives all configuration parameters (weights, buffers, vehicle dimensions).
*   **`fit_all_streets`**: Main entry point; iterates over the streets and runs the optimization on each one.
*   **`my_loss`**: The cost function that guides the optimizer.

### C. `CurveModel`
The PyTorch model that holds the optimizable parameters (the Y coordinates).
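
To make the smoothness part of `my_loss` more concrete, here is a small pure-Python sketch of the second-difference expression that appears in the code below. The function name `curvature_sq` and the sample paths are assumptions made for illustration; the notebook computes the same quantity with tensor slicing.

```python
def curvature_sq(xs, ys):
    """Squared discrete curvature proxy at each interior point.

    Mirrors the expression used in `my_loss`:
    (y[i-1] + y[i+1] - 2*y[i])**2 / (x[i+1] - x[i-1])**2
    """
    return [
        (ys[i - 1] + ys[i + 1] - 2 * ys[i])**2 / (xs[i + 1] - xs[i - 1])**2
        for i in range(1, len(xs) - 1)
    ]


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
straight = [0.0, 0.5, 1.0, 1.5, 2.0]  # constant slope: no bending
kinked = [0.0, 0.0, 1.0, 0.0, 0.0]    # sharp bump at the middle point

print(curvature_sq(xs, straight))  # [0.0, 0.0, 0.0]
print(curvature_sq(xs, kinked))    # [0.25, 1.0, 0.25]
```

A sloped but straight path costs nothing, so this term penalizes bending rather than deviation from the X axis; staying close to the original street is handled separately by the fidelity term.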

In [2]:
class GeometryUtils:
    """
    Utility class for geometric transformations and operations.

    Methods:
        translate(xarr, yarr, tx, ty): Translates a series of points (xarr, yarr) by a given offset (tx, ty).
        rotate(xarr, yarr, angle): Rotates a series of points (xarr, yarr) by a given angle around the origin.
        get_pca_transform(xarr, yarr): Computes the principal component analysis (PCA) transform for a set of points, returning the centroid and the angle of the first principal component.
        segmentize_line(line_geom, max_segment_length): Segmentizes a LineString into smaller segments, ensuring no segment exceeds max_segment_length.
        sample_polygon(polygon, dist): Samples points along the boundary of a polygon at a specified distance.
    """
    
    @staticmethod
    def translate(xarr, yarr, tx, ty):
        xret = []
        yret = []
        for i in range(len(xarr)):
            x = xarr[i] + tx
            y = yarr[i] + ty
            xret.append(x)
            yret.append(y)
        return xret, yret

    @staticmethod
    def rotate(xarr, yarr, angle):
        xret = []
        yret = []
        c = math.cos(angle)
        s = math.sin(angle)
        for i in range(len(xarr)):
            x = c * xarr[i] - s * yarr[i]
            y = s * xarr[i] + c * yarr[i]
            xret.append(x)
            yret.append(y)
        return xret, yret

    @staticmethod
    def get_pca_transform(xarr, yarr):
        n = len(xarr)
        if n == 0: return 0, 0, 0
        
        X = []
        sx = 0
        sy = 0
        for i in range(n):
            X.append((xarr[i], yarr[i]))
            sx += xarr[i]
            sy += yarr[i]
            
        pca = PCA(n_components=2)
        pca.fit(X)
        angle = math.atan2(pca.components_[0, 1], pca.components_[0, 0])
        return sx / n, sy / n, angle

    @staticmethod
    def segmentize_line(line_geom, max_segment_length=1.0):
        length = line_geom.length
        num_points = int(math.ceil(length / max_segment_length)) + 1
        distances = np.linspace(0, length, num_points)
        points = [line_geom.interpolate(d) for d in distances]
        return shapely.LineString(points)

    @staticmethod
    def sample_polygon(polygon: shapely.Polygon, dist=1.0):
        boundary = polygon.boundary
        if boundary.is_empty: return [], []
        
        length = boundary.length
        points = shapely.line_interpolate_point(boundary, [x for x in np.arange(0, length, dist)])
        xret = []
        yret = []
        for xy in points:
            xret.append(xy.x)
            yret.append(xy.y)
        return xret, yret


class CurveModel(torch.nn.Module):
    def __init__(self, device, x, y0):
        super(CurveModel, self).__init__()
        self.yb = torch.clone(y0).to(device=device)
        self.y = torch.nn.Parameter(self.yb)

    def forward(self, x):
        return self.y


class PathOptimizer:
    """
    Path Optimizer for street fitting.

    Methods:
        gen_footprint_obstacle: Generates a footprint obstacle.
        gen_footprint_high_obstacle: Generates a high footprint obstacle.
        select_near_holes: Selects points near holes.
        select_near_geofence_points: Selects points near geofence.
        my_loss: Calculates the loss function.
        del_colliding: Deletes colliding points.
        fit_street: Fits a street.
        fit_all_streets: Fits all streets.

    Attributes:
        hole_buffer (float): Buffer size for holes.
        fp_front (float): Front footprint size.
        fp_back (float): Back footprint size.
        fp_side (float): Side footprint size.
        fp_buffer (float): Buffer size for footprint.
        repulsion_weight (float): Weight for repulsion.
        fidelity_weight (float): Weight for fidelity.
        safety_margin (float): Safety margin.
        soft_limit_range (float): Soft limit range.
        iterations (int): Number of iterations.
        device (torch.device): Device for computation.
    """

    # Default Parameters
    DEFAULT_HOLE_BUFFER = 7.0
    DEFAULT_FOOTPRINT_FRONT = 3.5
    DEFAULT_FOOTPRINT_BACK = -3.5
    DEFAULT_FOOTPRINT_SIDE = 1.8
    DEFAULT_FOOTPRINT_BUFFER = 1.0

    DEFAULT_REPULSION_WEIGHT = 500.0
    DEFAULT_FIDELITY_WEIGHT = 0.0001
    DEFAULT_SAFETY_MARGIN = 0.5
    DEFAULT_SOFT_LIMIT_RANGE = 5.0
    DEFAULT_ITERATIONS = 500

    def __init__(self, 
                 hole_buffer=DEFAULT_HOLE_BUFFER,
                 fp_front=DEFAULT_FOOTPRINT_FRONT,
                 fp_back=DEFAULT_FOOTPRINT_BACK,
                 fp_side=DEFAULT_FOOTPRINT_SIDE,
                 fp_buffer=DEFAULT_FOOTPRINT_BUFFER,
                 repulsion_weight=DEFAULT_REPULSION_WEIGHT,
                 fidelity_weight=DEFAULT_FIDELITY_WEIGHT,
                 safety_margin=DEFAULT_SAFETY_MARGIN,
                 soft_limit_range=DEFAULT_SOFT_LIMIT_RANGE,
                 iterations=DEFAULT_ITERATIONS):
        
        self.hole_buffer = hole_buffer
        self.fp_front = fp_front
        self.fp_back = fp_back
        self.fp_side = fp_side
        self.fp_buffer = fp_buffer
        
        self.repulsion_weight = repulsion_weight
        self.fidelity_weight = fidelity_weight
        self.safety_margin = safety_margin
        self.soft_limit_range = soft_limit_range
        self.iterations = iterations
        
        self.device = torch.device('cpu:0')

    def gen_footprint_obstacle(self, x, y, angle, buffer_val=None):
        back = self.fp_back
        front = self.fp_front
        left = self.fp_side
        right = -self.fp_side
        c = math.cos(angle)
        s = math.sin(angle)

        footprint = shapely.Polygon([
            [x+c*back -s*left,  y+s*back +c*left],
            [x+c*front-s*left,  y+s*front+c*left],
            [x+c*front-s*right, y+s*front+c*right],
            [x+c*back -s*right, y+s*back +c*right],
            [x+c*back -s*left,  y+s*back +c*left]
        ])
        
        buf = buffer_val if buffer_val is not None else self.fp_buffer
        if buf > 0:
            footprint = footprint.buffer(buf)
            
        return footprint

    def gen_footprint_high_obstacle(self, x, y, angle):
        back = -3.5
        front = 6.5
        left = 1.8 
        right = -1.8
        c = math.cos(angle)
        s = math.sin(angle)

        footprint = shapely.Polygon([
            [x+c*back -s*left,  y+s*back +c*left],
            [x+c*front-s*left,  y+s*front+c*left],
            [x+c*front-s*right, y+s*front+c*right],
            [x+c*back -s*right, y+s*back +c*right],
            [x+c*back -s*left,  y+s*back +c*left]
        ])    
        return footprint

    def select_near_holes(self, xarr, yarr, holesx, holesy, d_thr):
        curve_xy = [(xarr[i],yarr[i]) for i in range(len(xarr))]
        holes_xy = [(holesx[i],holesy[i]) for i in range(len(holesx))]

        curve = shapely.LineString(curve_xy)
        holes = shapely.MultiPoint(holes_xy)

        street = shapely.buffer(curve, d_thr)
        holes = street.intersection(holes)

        xret = []
        yret = []
        
        if not holes.is_empty:
            if isinstance(holes, shapely.MultiPoint):
                for xy in holes.geoms:
                    xret.append(xy.x)
                    yret.append(xy.y)
            else:
                xret.append(holes.x)
                yret.append(holes.y)

        return xret, yret

    def select_near_geofence_points(self, xarr, yarr, geofx, geofy, d_thr):
        return self.select_near_holes(xarr, yarr, geofx, geofy, d_thr)

    def my_loss(self, x, y, yo, xh, yh, xlow, ylow, xhigh, yhigh, safety_radius):
        loss = 0
        
        loss += self.fidelity_weight * torch.mean((y[1:-1]-yo[1:-1])**2) 
        
        curvature2 = (y[0:-2] + y[2:] - 2*y[1:-1] )**2 / (x[2:] - x[:-2])**2
        loss += 60*torch.mean(curvature2)

        loss += 200*torch.mean((curvature2-0.003)*torch.nn.ReLU()(curvature2-0.003))
        
        loss += 10 * (y[0]-yo[0])**2
        loss += 10 * (y[1]-yo[1])**2
        loss += 10 * (y[-2]-yo[-2])**2
        loss += 10 * (y[-1]-yo[-1])**2
        
        xh = torch.reshape(xh, (xh.shape[0], 1))
        yh = torch.reshape(yh, (yh.shape[0], 1))
        dx = x - xh
        dy = y - yh
        d2 = dx**2 + dy**2
        
        hard_limit_sq = safety_radius**2
        soft_limit_sq = (safety_radius + self.soft_limit_range)**2
        
        loss += (0.5 / d2 * torch.sigmoid(5*(soft_limit_sq - d2))).sum()
        
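        violation = torch.nn.ReLU()(hard_limit_sq - d2)
        # Guard: the mean of an empty tensor is NaN, which would poison the loss when no holes are nearby
        if violation.numel() > 0:
            loss += self.repulsion_weight * torch.mean(violation)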

        L = torch.sqrt( (x[2:] - x[:-2])**2 + (y[2:] - y[:-2])**2)
        C = (x[2:] - x[:-2]) / L
        S = (y[2:] - y[:-2]) / L

        xp = x[1:-1] + 3 * C
        yp = y[1:-1] + 3 * S

        xlow = torch.reshape(xlow, (xlow.shape[0], 1))
        ylow = torch.reshape(ylow, (ylow.shape[0], 1))
        dx = x - xlow
        dy = y - ylow
        d2 = dx**2 + dy**2
        loss += (20.0 / (d2+0.5) * torch.sigmoid(5*(2.1*2.1 - d2)) ).sum()

        xhigh = torch.reshape(xhigh, (xhigh.shape[0], 1))
        yhigh = torch.reshape(yhigh, (yhigh.shape[0], 1))
        dx = x - xhigh
        dy = y - yhigh
        d2 = dx**2 + dy**2
        loss += (20.0 / (d2+0.5) * torch.sigmoid(5*(2.1*2.1 - d2)) ).sum()

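        # Compare the look-ahead points with every high-obstacle point.
        # xhigh/yhigh are column vectors of obstacle samples, so they must not be
        # sliced with [1:-1]; that would drop the first and last obstacle point.
        dx = xp - xhigh
        dy = yp - yhigh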
        d2 = dx**2 + dy**2
        loss += (2.0 / (d2+0.5) * torch.sigmoid(5*(2.1*2.1 - d2)) ).sum()

        return loss

    def del_colliding(self, orig_x, orig_y, holes_x, holes_y, low_x, low_y, high_x, high_y):
        tx, ty, angle = GeometryUtils.get_pca_transform(orig_x, orig_y)

        orig_x, orig_y = GeometryUtils.translate(orig_x, orig_y, -tx, -ty)
        orig_x, orig_y = GeometryUtils.rotate(orig_x, orig_y, -angle)

        holes_x, holes_y = GeometryUtils.translate(holes_x, holes_y, -tx, -ty)
        holes_x, holes_y = GeometryUtils.rotate(holes_x, holes_y, -angle)    

        low_x, low_y = GeometryUtils.translate(low_x, low_y, -tx, -ty)
        low_x, low_y = GeometryUtils.rotate(low_x, low_y, -angle)    

        high_x, high_y = GeometryUtils.translate(high_x, high_y, -tx, -ty)
        high_x, high_y = GeometryUtils.rotate(high_x, high_y, -angle)    

        total_buffer = self.hole_buffer + self.fp_buffer
        
        holes = shapely.MultiPoint([[holes_x[i],holes_y[i]] for i in range(len(holes_x))]).buffer(total_buffer)
        low = shapely.MultiPoint([[low_x[i],low_y[i]] for i in range(len(low_x))])
        high = shapely.MultiPoint([[high_x[i],high_y[i]] for i in range(len(high_x))])
        
        curve_x = []
        curve_y = []

        for i in range(len(orig_x)):
            if i > 0:
                ori = math.atan2(orig_y[i]-orig_y[i-1], orig_x[i]-orig_x[i-1])
            else:
                ori = math.atan2(orig_y[i+1]-orig_y[i], orig_x[i+1]-orig_x[i])
                
            footprint = self.gen_footprint_obstacle(orig_x[i], orig_y[i], ori, buffer_val=0.0)
            footprint_high = self.gen_footprint_high_obstacle(orig_x[i], orig_y[i], ori)
            
            if not holes.intersects(footprint) and not low.intersects(footprint) and not high.intersects(footprint_high):
                curve_x.append(orig_x[i])
                curve_y.append(orig_y[i])

        curve_x, curve_y = GeometryUtils.rotate(curve_x, curve_y, angle)
        curve_x, curve_y = GeometryUtils.translate(curve_x, curve_y, tx, ty)

        return curve_x, curve_y

    def fit_street(self, orig_x, orig_y, holes_x, holes_y, low_x, low_y, high_x, high_y, safety_radius=2.1):
        if hasattr(orig_x, 'tolist'): orig_x = orig_x.tolist()
        if hasattr(orig_y, 'tolist'): orig_y = orig_y.tolist()
        
        tx, ty, angle = GeometryUtils.get_pca_transform(orig_x, orig_y)

        orig_x, orig_y = GeometryUtils.translate(orig_x, orig_y, -tx, -ty)
        orig_x, orig_y = GeometryUtils.rotate(orig_x, orig_y, -angle)

        holes_x, holes_y = GeometryUtils.translate(holes_x, holes_y, -tx, -ty)
        holes_x, holes_y = GeometryUtils.rotate(holes_x, holes_y, -angle)    

        low_x, low_y = GeometryUtils.translate(low_x, low_y, -tx, -ty)
        low_x, low_y = GeometryUtils.rotate(low_x, low_y, -angle)    

        high_x, high_y = GeometryUtils.translate(high_x, high_y, -tx, -ty)
        high_x, high_y = GeometryUtils.rotate(high_x, high_y, -angle)    

        xo = torch.Tensor(orig_x).to(self.device)
        yo = torch.Tensor(orig_y).to(self.device)

        xh = torch.Tensor(holes_x).to(self.device)
        yh = torch.Tensor(holes_y).to(self.device)

        xlow = torch.Tensor(low_x).to(self.device)
        ylow = torch.Tensor(low_y).to(self.device)

        xhigh = torch.Tensor(high_x).to(self.device)
        yhigh = torch.Tensor(high_y).to(self.device)

        model = CurveModel(self.device, xo, yo)
        optim = torch.optim.Adam(model.parameters(), lr=0.001)

        loss_history = [] 
        path_history = [] 
        
        n_iters = int(self.iterations) 
        for i in range(0, n_iters):
            if i == 100:
                for g in optim.param_groups: g['lr'] = 0.01
            if i == 200:
                for g in optim.param_groups: g['lr'] = 0.03

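            predictions = model(xo)  # invokes CurveModel.forward; returns the Y parameters
            loss = self.my_loss(xo, predictions, yo, xh, yh, xlow, ylow, xhigh, yhigh, safety_radius)
            loss.backward()    # autograd computes d(loss)/d(y)
            optim.step()       # Adam update of the Y coordinates
            optim.zero_grad()  # clear gradients before the next iteration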
            
            loss_history.append(loss.item()) 

            if i % 50 == 0:
                curr_y = predictions.detach().cpu().numpy()
                cx, cy = GeometryUtils.rotate(orig_x, curr_y, angle)
                cx, cy = GeometryUtils.translate(cx, cy, tx, ty)
                path_history.append((cx, cy))

            if i % 50 == 0:
                if torch.isnan(predictions).any():
                    print("NAN IN OPTIMIZATION")
        
        curve_x = orig_x
        curve_y = predictions.detach().cpu().numpy()

        curve_x, curve_y = GeometryUtils.rotate(curve_x, curve_y, angle)
        curve_x, curve_y = GeometryUtils.translate(curve_x, curve_y, tx, ty)

        return curve_x, curve_y, loss_history, path_history

    def fit_all_streets(self, streets : gpd.GeoDataFrame, holes : gpd.GeoDataFrame, geofence : gpd.GeoDataFrame, obstacles : gpd.GeoDataFrame, high_obstacles : gpd.GeoDataFrame, fit_twice : bool):
        all_holes_x = holes['x'].to_list()
        all_holes_y = holes['y'].to_list()

        all_geof_x, all_geof_y = GeometryUtils.sample_polygon(geofence.geometry.iloc[0])
        low_x = []
        low_y = []

        if obstacles is not None:
            for row in obstacles.geometry:
                cx, cy = GeometryUtils.sample_polygon(row)
                low_x += cx
                low_y += cy

        if high_obstacles is not None:
            for row in high_obstacles.geometry:
                cx, cy = GeometryUtils.sample_polygon(row)
                all_geof_x += cx
                all_geof_y += cy
                
        safety_radius = self.hole_buffer + self.fp_buffer + self.fp_side + self.safety_margin

        total_streets = len(streets)
        last_losses = [] 
        last_path_history = []
        
        for index, row in streets.iterrows():
            street = GeometryUtils.segmentize_line(row['geometry'], max_segment_length=1.0)
            orig_x = []
            orig_y = []
            for xy in list(street.coords):
                orig_x.append(xy[0])
                orig_y.append(xy[1])
            
            holes_x, holes_y = self.select_near_holes(orig_x, orig_y, all_holes_x, all_holes_y, safety_radius + 5.0)
            geof_x, geof_y = self.select_near_geofence_points(orig_x, orig_y, all_geof_x, all_geof_y, 11.0)

            curve_x, curve_y, losses, path_hist = self.fit_street(orig_x, orig_y, holes_x, holes_y, low_x, low_y, geof_x, geof_y, safety_radius=safety_radius)
            last_losses = losses
            last_path_history = path_hist

            if fit_twice:
                curve_x_filtered, curve_y_filtered = self.del_colliding(curve_x, curve_y, holes_x, holes_y, low_x, low_y, geof_x, geof_y)
                
                if len(curve_x_filtered) < 5:
                    print(f"Warning: del_colliding removed too many points ({len(curve_x_filtered)} left). Skipping second fit.")
                else:
                    curve_x, curve_y = curve_x_filtered, curve_y_filtered
                    curve_x, curve_y, losses, path_hist = self.fit_street(curve_x, curve_y, holes_x, holes_y, low_x, low_y, geof_x, geof_y, safety_radius=safety_radius)
                    last_losses = losses
                    last_path_history = path_hist

            xy = []
            for i in range(len(curve_x)):
                xy.append( (curve_x[i], curve_y[i]) )
            
            streets.loc[index, 'geometry'] = shapely.LineString(xy)
            streets.loc[index, 'geometry'] = shapely.simplify(streets.loc[index, 'geometry'], 0.1)

        return streets, last_losses, last_path_history

## 4. Creating the Environment

A simple scenario is built to test the algorithm:
1.  A straight street.
2.  A few obstacles (holes).
3.  A geofence (the boundary of the area).

In [3]:

# Create a straight street
street_coords = [(0, 0), (100, 0)]
street_geom = LineString(street_coords)
streets = gpd.GeoDataFrame({'geometry': [street_geom]})

# Create a geofence (large box)
geofence_geom = Polygon([(-10, -20), (110, -20), (110, 20), (-10, 20)])
geofence = gpd.GeoDataFrame({'geometry': [geofence_geom]})

# Helper to plot footprints using a temporary optimizer instance for geometry generation
def plot_footprints(street_geom, ax, optimizer, step=5):
    coords = list(street_geom.coords)
    for i in range(0, len(coords)-1, step):
        p1 = coords[i]
        p2 = coords[i+1]
        angle = math.atan2(p2[1] - p1[1], p2[0] - p1[0])
        
        # 1. Plot Real Robot (Buffer = 0)
        real_footprint = optimizer.gen_footprint_obstacle(p1[0], p1[1], angle, buffer_val=0.0)
        x, y = real_footprint.exterior.xy
        ax.plot(x, y, color='purple', alpha=0.1, linewidth=1.0) # Faint solid line
        
        # 2. Plot Safety Buffer (if exists)
        if optimizer.fp_buffer > 0:
            buffered_footprint = optimizer.gen_footprint_obstacle(p1[0], p1[1], angle) 
            x_buf, y_buf = buffered_footprint.exterior.xy
            ax.plot(x_buf, y_buf, color='purple', alpha=0.05, linestyle='--') # Very faint dashed line


## 5. Interactive Simulation

Use the sliders to experiment:

*   **Hole Positions**: Move the obstacles around.
*   **Hole Buffer**: Change the safety radius around the holes.
*   **Footprint Dimensions**: Adjust the size of the vehicle (Front, Back, Side).
*   **Footprint Buffer**: Extra safety margin around the vehicle.
*   **Path Position**: Move the "ghost robot" along the generated route to check whether it fits.
*   **Optimization Tuning**: Adjust the algorithm's weights (how much it cares about smoothness versus avoiding collisions).

**Click 'Optimize Route'** to run the algorithm with your configuration.
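
To see how the footprint and buffer sliders translate into geometry, the sketch below reproduces two small calculations from the classes above in plain Python: the corners of the vehicle rectangle at a given pose (as in `gen_footprint_obstacle`, without the Shapely buffer) and the hard-limit radius assembled in `fit_all_streets`. The function names `footprint_corners` and `safety_radius` are hypothetical; the default values are the notebook's `DEFAULT_*` constants.

```python
import math


def footprint_corners(x, y, angle, front=3.5, back=-3.5, side=1.8):
    """Corners of the vehicle rectangle at pose (x, y, angle)."""
    c, s = math.cos(angle), math.sin(angle)
    # Body-frame corners: back-left, front-left, front-right, back-right.
    local = [(back, side), (front, side), (front, -side), (back, -side)]
    # Rotate each corner by `angle`, then translate it to (x, y).
    return [(x + c * lx - s * ly, y + s * lx + c * ly) for lx, ly in local]


def safety_radius(hole_buffer=7.0, fp_buffer=1.0, fp_side=1.8, safety_margin=0.5):
    """Radius around each hole inside which the hard penalty applies."""
    return hole_buffer + fp_buffer + fp_side + safety_margin


corners = footprint_corners(10.0, 0.0, math.pi / 2)  # vehicle heading "up"
print([(round(px, 6), round(py, 6)) for px, py in corners])
print("safety radius with default sliders:", round(safety_radius(), 6))
```

With the default slider values the radius comes to about 10.3 units, which is why holes placed 5 units from the street already push the path noticeably.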

In [4]:

# Sliders for Holes
h1_slider = widgets.FloatSlider(min=-15, max=15, step=0.5, value=5, description='Hole 1 Y')
h2_slider = widgets.FloatSlider(min=-15, max=15, step=0.5, value=-5, description='Hole 2 Y')
h3_slider = widgets.FloatSlider(min=-15, max=15, step=0.5, value=5, description='Hole 3 Y')

# Sliders for Parameters
hole_buffer_slider = widgets.FloatSlider(min=1.0, max=15.0, step=0.5, value=7.0, description='Hole Buffer')
fp_front_slider = widgets.FloatSlider(min=1.0, max=10.0, step=0.5, value=3.5, description='FP Front')
fp_back_slider = widgets.FloatSlider(min=-10.0, max=-1.0, step=0.5, value=-3.5, description='FP Back')
fp_side_slider = widgets.FloatSlider(min=0.5, max=5.0, step=0.1, value=1.8, description='FP Side')
fp_buffer_slider = widgets.FloatSlider(min=0.0, max=2.0, step=0.1, value=1.0, description='FP Buffer')

# Slider for Path Traversal
path_slider = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.0, description='Path Pos')

# Sliders for Optimization Tuning
repulsion_slider = widgets.FloatSlider(min=10.0, max=1000.0, step=10.0, value=500.0, description='Repulsion W')
fidelity_slider = widgets.FloatLogSlider(value=0.0001, base=10, min=-5, max=-1, step=0.1, description='Fidelity W')
safety_margin_slider = widgets.FloatSlider(min=0.0, max=2.0, step=0.1, value=0.5, description='Safety Margin')
soft_limit_slider = widgets.FloatSlider(min=1.0, max=10.0, step=0.5, value=5.0, description='Soft Limit')
iterations_slider = widgets.IntSlider(min=100, max=1000, step=50, value=500, description='Iterations')

run_btn = widgets.Button(description="Optimize Route", button_style='success', icon='check')
out = widgets.Output()
debug_out = widgets.Output() # Separate debug output

# Global variable to store the last fitted street and losses
current_fitted_streets = None
current_losses = []
current_path_history = []
current_optimizer = None # Store the optimizer used for the last run

def get_holes(h1, h2, h3):
    holes_coords = [(30, h1), (50, h2), (70, h3)]
    holes_geoms = [Point(x, y) for x, y in holes_coords]
    return gpd.GeoDataFrame({'geometry': holes_geoms, 'x': [p[0] for p in holes_coords], 'y': [p[1] for p in holes_coords]})

def plot_current_state(change=None):
    global current_fitted_streets, current_losses, current_path_history, current_optimizer
    
    # Build a throwaway optimizer from the current slider values so the
    # footprint preview always reflects what the sliders show, whether or
    # not an optimization has already been run.
    temp_optimizer = PathOptimizer(
        hole_buffer=hole_buffer_slider.value,
        fp_front=fp_front_slider.value,
        fp_back=fp_back_slider.value,
        fp_side=fp_side_slider.value,
        fp_buffer=fp_buffer_slider.value
    )

    with out:
        clear_output(wait=True)
        # Get current hole positions
        holes = get_holes(h1_slider.value, h2_slider.value, h3_slider.value)
        
        # Create subplots: 3 rows, 1 column (Route, Evolution, Loss)
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 18), gridspec_kw={'height_ratios': [2, 2, 1]})
        
        # --- Plot 1: Interactive Route ---
        streets.plot(ax=ax1, color='blue', linestyle='--', label='Original Street')
        
        if current_fitted_streets is not None:
             current_fitted_streets.plot(ax=ax1, color='orange', linewidth=2, alpha=0.5, label='Last Fitted Street')
             plot_footprints(current_fitted_streets.geometry[0], ax1, temp_optimizer, step=5)
             
             # Dynamic Footprint
             line = current_fitted_streets.geometry[0]
             dist = line.length * path_slider.value
             point = line.interpolate(dist)
             delta = 0.1
             p_next = line.interpolate(min(dist + delta, line.length))
             p_prev = line.interpolate(max(dist - delta, 0))
             angle = math.atan2(p_next.y - p_prev.y, p_next.x - p_prev.x)
             
             real_fp = temp_optimizer.gen_footprint_obstacle(point.x, point.y, angle, buffer_val=0.0)
             ax1.plot(*real_fp.exterior.xy, color='black', linewidth=2, label='Test Robot')
             
             if temp_optimizer.fp_buffer > 0:
                 buf_fp = temp_optimizer.gen_footprint_obstacle(point.x, point.y, angle)
                 ax1.plot(*buf_fp.exterior.xy, color='black', linestyle='--', linewidth=1.5, label='Test Buffer')
        
        holes.plot(ax=ax1, color='red', markersize=100, label='Holes')
        holes.buffer(temp_optimizer.hole_buffer).plot(ax=ax1, color='red', alpha=0.2, label='Buffer')
        geofence.boundary.plot(ax=ax1, color='green', label='Geofence')
        ax1.set_title("Interactive Route Inspection")
        ax1.set_ylim(-20, 20)
        ax1.set_xlim(-10, 110)
        ax1.legend(loc='upper right')
        ax1.grid(True, alpha=0.3)
        
        # --- Plot 2: Evolution by Epochs ---
        streets.plot(ax=ax2, color='blue', linestyle='--', alpha=0.3)
        holes.plot(ax=ax2, color='red', markersize=100)
        holes.buffer(temp_optimizer.hole_buffer).plot(ax=ax2, color='red', alpha=0.1)
        geofence.boundary.plot(ax=ax2, color='green', alpha=0.3)
        
        if current_path_history:
            num_paths = len(current_path_history)
            for i, (px, py) in enumerate(current_path_history):
                # Gradient color: Light Orange -> Dark Orange
                alpha = 0.2 + 0.8 * (i / num_paths)
                color = (1.0, 0.5 * (1 - i/num_paths), 0.0) # Orange gradient
                ax2.plot(px, py, color=color, alpha=alpha, linewidth=1.5, label=f'Epoch {i*50}' if i == 0 or i == num_paths-1 else "")
            
            # Snapshots are taken every 50 iterations starting at 0, so the last one is at (num_paths - 1) * 50
            ax2.set_title(f"Path Evolution (Epochs 0 to {(num_paths - 1) * 50})")
        else:
            ax2.text(0.5, 0.5, "No history available", ha='center', va='center')
            ax2.set_title("Path Evolution")
            
        ax2.set_ylim(-20, 20)
        ax2.set_xlim(-10, 110)
        ax2.grid(True, alpha=0.3)

        # --- Plot 3: Loss ---
        if current_losses:
            ax3.plot(current_losses, color='blue', linewidth=1.5)
            ax3.set_title("Optimization Loss")
            ax3.set_xlabel("Iteration")
            ax3.set_ylabel("Loss")
            ax3.grid(True, alpha=0.3)
            ax3.set_yscale('log')
        else:
            ax3.text(0.5, 0.5, "No optimization run yet", ha='center', va='center')
        
        plt.tight_layout()
        plt.show()

def run_optimization(b):
    try:
        # Immediate debug print
        with debug_out:
            print("Button clicked! Starting optimization...")
    
        global current_fitted_streets, current_losses, current_path_history, current_optimizer
        
        # Instantiate Optimizer with current slider values
        optimizer = PathOptimizer(
            hole_buffer=hole_buffer_slider.value,
            fp_front=fp_front_slider.value,
            fp_back=fp_back_slider.value,
            fp_side=fp_side_slider.value,
            fp_buffer=fp_buffer_slider.value,
            repulsion_weight=repulsion_slider.value,
            fidelity_weight=fidelity_slider.value,
            safety_margin=safety_margin_slider.value,
            soft_limit_range=soft_limit_slider.value,
            iterations=iterations_slider.value
        )
        current_optimizer = optimizer
        
        with out:
            clear_output(wait=True)
            print("Optimizing... please wait.")
            
            holes = get_holes(h1_slider.value, h2_slider.value, h3_slider.value)
            
            # Run optimization using the class method
            fitted_streets, losses, path_hist = optimizer.fit_all_streets(streets.copy(), holes, geofence, None, None, fit_twice=True)
            current_fitted_streets = fitted_streets 
            current_losses = losses
            current_path_history = path_hist
            
            clear_output(wait=True) # Clear the "Optimizing..." message
            
            # Re-use plot_current_state to show results
            plot_current_state()
            
            with debug_out:
                print("Optimization finished successfully.")
            
    except Exception as e:
        with debug_out:
            print("Optimization failed!")
            print(traceback.format_exc())

# Link sliders to preview
h1_slider.observe(plot_current_state, names='value')
h2_slider.observe(plot_current_state, names='value')
h3_slider.observe(plot_current_state, names='value')
hole_buffer_slider.observe(plot_current_state, names='value')
fp_front_slider.observe(plot_current_state, names='value')
fp_back_slider.observe(plot_current_state, names='value')
fp_side_slider.observe(plot_current_state, names='value')
fp_buffer_slider.observe(plot_current_state, names='value')
path_slider.observe(plot_current_state, names='value')

# Layout
hole_controls = widgets.VBox([h1_slider, h2_slider, h3_slider])
param_controls = widgets.VBox([hole_buffer_slider, fp_front_slider, fp_back_slider, fp_side_slider, fp_buffer_slider])
tuning_controls = widgets.VBox([repulsion_slider, fidelity_slider, safety_margin_slider, soft_limit_slider, iterations_slider])
path_control = widgets.VBox([path_slider])

ui = widgets.HBox([hole_controls, param_controls, tuning_controls, path_control])

# Initial display
run_btn.on_click(run_optimization)
display(widgets.VBox([ui, run_btn, out, debug_out]))

# Run initial optimization so the user sees something immediately
run_optimization(None)

