# Comparing Segmentation and Perturbation Methods

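This notebook compares several trajectory segmentation methods (RDP, MDL, sliding window, random) and perturbation methods (Gaussian, scaling, rotation) used for trajectory explanation. It first visualises each method on a single sample trajectory, then runs the experiments against a trained LSTM model and summarises the results.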


In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pactus import Dataset
from pactus.models import LSTMModel

# Import from the traj-xai package
from traj_xai import (
    rdp_segmentation,
    mdl_segmentation,
    sliding_window_segmentation,
    random_segmentation,
    gaussian_perturbation,
    scaling_perturbation,
    rotation_perturbation,
    run_experiments,
)

## Load and Prepare Data


In [None]:
# Set a random seed for reproducibility; the same value is reused for the
# split, the sampled subset and the model built in the training cell
SEED = 0

# Load the UCI Movement Libras dataset (smaller dataset for demonstration)
dataset = Dataset.uci_movement_libras()
print(f"Dataset loaded: {len(dataset.trajs)} trajectories")

# Split data into train and test subsets; 0.8 is the ratio handed to pactus
# (presumably the train fraction — confirm against the sizes printed below)
train, test = dataset.split(0.8, random_state=SEED)
print(f"Train set: {len(train.trajs)} trajectories")
print(f"Test set: {len(test.trajs)} trajectories")

# Take a small subset (5 requested) for quicker demonstration; the
# visualisation and experiment cells further down operate on this subset
small_test = test.sample(5, random_state=SEED)
print(f"Small test set: {len(small_test.trajs)} trajectories")

## Train a Black Box Model


In [None]:
# Build the LSTM classifier that acts as the black box being explained
model = LSTMModel(random_state=SEED)

# Train the model on the train dataset (with fewer epochs for demo).
# The full `dataset` is passed alongside `train`; presumably pactus uses it
# for dataset-wide information such as the label set — confirm in pactus docs
model.train(train, dataset, epochs=5, batch_size=64)

# Evaluate on the whole test split (not only small_test) and show the report
evaluation = model.evaluate(test)
evaluation.show()

## Compare Different Segmentation Methods


In [None]:
# Work with the first trajectory (and its label) of the small test subset
sample_idx = 0
sample_traj, sample_label = small_test.trajs[sample_idx], small_test.labels[sample_idx]

# Split the trajectory points into separate x and y sequences; the plotting
# helpers below read these two module-level lists to draw the full trajectory
points = sample_traj.r
x = [pt[0] for pt in points]
y = [pt[1] for pt in points]

# Segment the same trajectory with each of the four methods.
# NOTE(review): random_segmentation receives no seed here, so its panel may
# differ from run to run — confirm whether it accepts one
rdp_segments = rdp_segmentation(points, epsilon=0.01)
mdl_segments = mdl_segmentation(points, epsilon=0.8)
sliding_segments = sliding_window_segmentation(points, step=5, percentage=10)
random_segments = random_segmentation(points, num_segments=5)

# 2x2 grid of panels, flattened so each one can be addressed by a single index
fig, axs_grid = plt.subplots(2, 2, figsize=(15, 10))
axs = axs_grid.flatten()


# Helper function to plot segments
def plot_segments(ax, segments, method_name):
    """Draw the sample trajectory and its segments on one axes.

    The full trajectory is taken from the module-level ``x`` and ``y`` lists
    and drawn as a faint black line; each segment is then drawn on top in its
    own rainbow colour.

    Args:
        ax: Matplotlib axes to draw on.
        segments: Sized sequence of segments; each segment is a sequence of
            points indexable as ``point[0]`` (x) and ``point[1]`` (y).
        method_name: Name shown in the panel title next to the segment count.
    """
    # Faint reference line for the whole trajectory
    ax.plot(x, y, "k-", alpha=0.3, label="Original")

    # One evenly spaced rainbow colour per segment
    palette = plt.cm.rainbow(np.linspace(0, 1, len(segments)))
    for seg_color, seg_points in zip(palette, segments):
        seg_xs = [pt[0] for pt in seg_points]
        seg_ys = [pt[1] for pt in seg_points]
        ax.plot(seg_xs, seg_ys, "-", color=seg_color, linewidth=2)

    ax.set_title(f"{method_name} (Segments: {len(segments)})")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.grid(True)


# One panel per segmentation method, in grid order
segmentation_panels = [
    (rdp_segments, "RDP Segmentation"),
    (mdl_segments, "MDL Segmentation"),
    (sliding_segments, "Sliding Window"),
    (random_segments, "Random Segmentation"),
]
for panel_ax, (panel_segments, panel_title) in zip(axs, segmentation_panels):
    plot_segments(panel_ax, panel_segments, panel_title)

plt.tight_layout()
plt.show()

## Compare Different Perturbation Methods


In [None]:
# Choose one segmentation method (the RDP result computed in the previous section)
segments = rdp_segments

# Choose one segment to perturb (the first one); plot_perturbation below
# reads segment_to_perturb as a module-level name
segment_to_perturb_idx = 0
segment_to_perturb = segments[segment_to_perturb_idx]

# Apply different perturbation methods, each with its default settings.
# NOTE(review): no seed is passed here; if any of these draw random numbers
# the resulting plots will differ between runs — confirm
gaussian_perturbed = gaussian_perturbation(segment_to_perturb)
scaling_perturbed = scaling_perturbation(segment_to_perturb)
rotation_perturbed = rotation_perturbation(segment_to_perturb)

# Prepare subplots for visualization: one row with a panel per perturbation.
# This rebinds fig and axs from the segmentation figure above
fig, axs = plt.subplots(1, 3, figsize=(18, 5))


# Helper function to plot perturbed segment
def plot_perturbation(ax, perturbed_segment, method_name):
    """Overlay the original and the perturbed segment on the sample trajectory.

    Reads the module-level ``x``/``y`` lists (full trajectory) and
    ``segment_to_perturb`` (the unperturbed segment).

    Args:
        ax: Matplotlib axes to draw on.
        perturbed_segment: Sequence of points indexable as ``point[0]`` (x)
            and ``point[1]`` (y); drawn in red.
        method_name: Text used as the panel title.
    """

    def _split_xy(pts):
        # Separate a sequence of points into its x list and its y list
        return [pt[0] for pt in pts], [pt[1] for pt in pts]

    # Faint full trajectory, then the original segment, then its perturbed version
    ax.plot(x, y, "k-", alpha=0.3, label="Original")
    ax.plot(*_split_xy(segment_to_perturb), "b-", linewidth=2, label="Original Segment")
    ax.plot(*_split_xy(perturbed_segment), "r-", linewidth=2, label="Perturbed")

    ax.set_title(method_name)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.legend()
    ax.grid(True)


# One panel per perturbation method, left to right
perturbation_panels = [
    (gaussian_perturbed, "Gaussian Perturbation"),
    (scaling_perturbed, "Scaling Perturbation"),
    (rotation_perturbed, "Rotation Perturbation"),
]
for panel_ax, (panel_perturbed, panel_title) in zip(axs, perturbation_panels):
    plot_perturbation(panel_ax, panel_perturbed, panel_title)

plt.tight_layout()
plt.show()

## Run Experiments with Different Methods


In [None]:
# Segmentation and perturbation functions to compare; both lists are handed to
# run_experiments together (presumably every pairing is evaluated — confirm)
segment_funcs = [rdp_segmentation, mdl_segmentation]
perturbation_funcs = [gaussian_perturbation, rotation_perturbation]

# Directory for the experiment output (os is imported in the first code cell).
# NOTE(review): the directory is reused if it already exists, so CSV files left
# over from an earlier run are also picked up by the analysis cell — clear it
# first if a clean comparison is needed
log_dir = "comparison_results"
os.makedirs(log_dir, exist_ok=True)

# Run experiments on a small subset for demonstration
print("Running experiments with different methods...")
run_experiments(small_test, segment_funcs, perturbation_funcs, model, log_dir=log_dir)

## Analyze Results


In [None]:
# Load and analyze results (glob, os and pandas are imported in the first code cell)

# Find all result files. glob returns matches in arbitrary order, so sort them
# to keep the table rows and the bar order identical between runs.
result_files = sorted(glob.glob(os.path.join(log_dir, "*.csv")))
print(f"Found {len(result_files)} result files")

# Load each file into a DataFrame, keyed by its file name with the
# "_results.csv" suffix removed
results_data = {}
for file_path in result_files:
    method_name = os.path.basename(file_path).replace("_results.csv", "")
    results_data[method_name] = pd.read_csv(file_path)

# One summary row per method: mean of the "precision_score" column, mean of the
# "change" column expressed as a percentage, and the number of rows
summary = [
    {
        "Method": method,
        "Avg Precision": df["precision_score"].mean(),
        "Change %": df["change"].mean() * 100,
        "Count": len(df),
    }
    for method, df in results_data.items()
]

# Display summary; the columns are pinned so the table keeps its header even
# when no result files were found
summary_df = pd.DataFrame(
    summary, columns=["Method", "Avg Precision", "Change %", "Count"]
)
display(summary_df)

# Plot results as grouped bars, one group per method.
# The positions get their own name on purpose: the plotting helpers defined
# earlier read the module-level `x` (trajectory x coordinates), so rebinding
# `x` here would break those cells when they are re-run.
bar_positions = np.arange(len(summary))
width = 0.35

# A single figure is created here; a separate plt.figure() call beforehand
# would only render an extra empty figure.
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(
    bar_positions - width / 2,
    [item["Avg Precision"] for item in summary],
    width,
    label="Avg Precision",
)
# The change value is divided by 100 again to share the axis with the
# precision bars, so the legend names it as a fraction rather than a percent.
ax.bar(
    bar_positions + width / 2,
    [item["Change %"] / 100 for item in summary],
    width,
    label="Change (fraction)",
)

ax.set_ylabel("Score")
ax.set_title("Comparison of Different Methods")
ax.set_xticks(bar_positions)
ax.set_xticklabels([item["Method"] for item in summary])
ax.legend()

plt.tight_layout()
plt.show()