# Traj-XAI Basic Example

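This notebook demonstrates the basic usage of the Traj-XAI package for explaining the predictions of a trajectory classifier. It trains a black-box LSTM on the UCI Movement Libras dataset, splits one test trajectory into segments with RDP, perturbs segments with Gaussian noise, and scores how much each segment contributes to the model's prediction.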


## Setup

First, install the required packages (if needed) and import the necessary libraries:


In [None]:
# Install the required third-party packages if they are not already available.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): traj_xai itself is not listed here; presumably installed from this repository — confirm.
# %pip install rdp fastdtw pactus

In [None]:
import numpy as np
from pactus import Dataset
from pactus.models import LSTMModel
import matplotlib.pyplot as plt

# Import from our traj-xai package
from traj_xai import rdp_segmentation, gaussian_perturbation, TrajectoryManipulator

## Load and Prepare Data


In [None]:
# Single seed for every source of randomness in this notebook, for reproducibility
SEED = 0

# The dataset split and the LSTM receive SEED via `random_state`. Also seed NumPy's global RNG
# so that NumPy-based randomness is repeatable across runs. This presumably covers the Gaussian
# perturbations used below, if they draw from np.random (confirm against traj_xai).
np.random.seed(SEED)

# Load the UCI Movement Libras dataset
dataset = Dataset.uci_movement_libras()
print(f"Dataset loaded: {len(dataset.trajs)} trajectories")

# Split data into train (80%) and test (20%) subsets
train, test = dataset.split(0.8, random_state=SEED)
print(f"Train set: {len(train.trajs)} trajectories")
print(f"Test set: {len(test.trajs)} trajectories")

## Train a Black-Box Model


In [None]:
# Keyword arguments forwarded to the LSTM training loop. Only a few epochs are used so the
# demo runs quickly; expect modest accuracy.
fit_options = {"epochs": 5, "batch_size": 64}

# The LSTM is the black-box classifier whose predictions we explain below.
# `dataset` is passed alongside `train` as pactus' `original_data` argument.
model = LSTMModel(random_state=SEED)
model.train(train, dataset, **fit_options)

# Report performance on the held-out test split
evaluation = model.evaluate(test)
evaluation.show()

## Visualize a Sample Trajectory


In [None]:
# Use the first test trajectory as the running example for the rest of the notebook
sample_idx = 0
sample_traj = test.trajs[sample_idx]
sample_label = test.labels[sample_idx]

# `points`, `x` and `y` are reused by the segmentation and explanation cells below
points = sample_traj.r
x, y = [pt[0] for pt in points], [pt[1] for pt in points]

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(x, y, "b-", label="Trajectory")
# Mark where the trajectory begins and ends
ax.scatter(x[0], y[0], c="green", s=100, label="Start")
ax.scatter(x[-1], y[-1], c="red", s=100, label="End")
ax.set_title(f"Trajectory (Label: {sample_label})")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
plt.show()

## Segment the Trajectory


In [None]:
# Split the sample trajectory into segments with the RDP-based segmenter
segments = rdp_segmentation(points, epsilon=0.01)
print(f"Number of segments: {len(segments)}")

fig, ax = plt.subplots(figsize=(12, 8))

# Faint copy of the full trajectory as a backdrop
ax.plot(x, y, "k-", alpha=0.3, label="Original")

# Give every segment its own colour, spread evenly across the rainbow colormap
segment_colors = plt.cm.rainbow(np.linspace(0, 1, len(segments)))
for seg_no, (seg, seg_color) in enumerate(zip(segments, segment_colors), start=1):
    ax.plot(
        [pt[0] for pt in seg],
        [pt[1] for pt in seg],
        "-",
        color=seg_color,
        linewidth=2,
        label=f"Segment {seg_no}",
    )

ax.set_title("Trajectory Segments using RDP")
ax.set_xlabel("X")
ax.set_ylabel("Y")
# Keep the (potentially long) legend outside the plotting area
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
ax.grid(True)
fig.tight_layout()
plt.show()

## Perturb One Segment


In [None]:
# Perturb a single segment with Gaussian noise to preview the kind of change the explainer makes.
# Prefer the second segment, but fall back to the only one when RDP returns a single segment
# (e.g. a nearly straight trajectory); indexing segments[1] directly would raise IndexError.
assert len(segments) > 0, "RDP segmentation returned no segments"
segment_to_perturb = min(1, len(segments) - 1)
perturbed_segment = gaussian_perturbation(segments[segment_to_perturb])

# Rebuild the full trajectory with the perturbed segment swapped in.
# A shallow copy is used so `segments` itself is left untouched.
modified_segments = list(segments)
modified_segments[segment_to_perturb] = perturbed_segment
modified_trajectory = [pt for seg in modified_segments for pt in seg]

modified_x = [pt[0] for pt in modified_trajectory]
modified_y = [pt[1] for pt in modified_trajectory]
perturbed_x = [pt[0] for pt in perturbed_segment]
perturbed_y = [pt[1] for pt in perturbed_segment]

fig, ax = plt.subplots(figsize=(12, 8))

# Original trajectory as a faint backdrop, then the modified one on top
ax.plot(x, y, "k-", alpha=0.3, label="Original")
ax.plot(modified_x, modified_y, "r-", label="Modified Trajectory")

# Highlight the perturbed segment
ax.plot(
    perturbed_x,
    perturbed_y,
    "g-",
    linewidth=2,
    label=f"Perturbed Segment {segment_to_perturb+1}",
)

ax.set_title("Trajectory with Perturbed Segment")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
plt.show()

## Generate Explanation


In [None]:
# Build an explainer from the sample trajectory's points, the segmentation and perturbation
# functions, and the black-box model.
trajectory_explainer = TrajectoryManipulator(
    points, rdp_segmentation, gaussian_perturbation, model
)

# Segments as computed by the explainer; these are the units that receive importance scores
segments = trajectory_explainer.get_segment()
print(f"Number of segments: {len(segments)}")

# Generate explanation
coef = trajectory_explainer.explain()
print(f"\nExplanation coefficients: {coef}")

if coef is not None:
    coef_arr = np.asarray(coef)
    if coef_arr.ndim > 1 and coef_arr.shape[0] > 1:
        # Multi-class: one row of coefficients per class; use the row of the predicted class.
        # The label is stored in `predicted_label`, not `y`, so the trajectory's y-coordinates
        # (plotted below and in earlier cells) are not overwritten.
        predicted_label = trajectory_explainer.get_Y()[0]
        class_index = np.where(
            np.asarray(trajectory_explainer.classes_) == predicted_label
        )[0][0]
        importances = np.abs(coef_arr[class_index])
    else:
        # A single row of coefficients, shaped either (n_segments,) or (1, n_segments).
        # `ravel` handles both; `coef[0]` would yield a scalar for the 1-D case.
        importances = np.abs(coef_arr.ravel())

    # Scale importances to [0, 1] for opacity/width. Guard the all-zero (or empty) case,
    # which would otherwise divide by zero and give NaN alpha values that matplotlib rejects.
    max_importance = importances.max() if importances.size else 0.0
    if max_importance > 0:
        normalized_importances = importances / max_importance
    else:
        normalized_importances = np.zeros_like(importances, dtype=float)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot the original trajectory as a faint backdrop
    ax.plot(x, y, "k-", alpha=0.3)

    for i, segment in enumerate(segments):
        segment_x = [p[0] for p in segment]
        segment_y = [p[1] for p in segment]

        # Opacity and width both grow with importance. Importances are absolute values, so
        # this shows magnitude only, not whether a segment pushed toward or away from the class.
        color_intensity = float(normalized_importances[i])
        ax.plot(
            segment_x,
            segment_y,
            "-",
            color=(1.0, 0.0, 0.0, color_intensity),
            linewidth=2 + 3 * color_intensity,
        )

        # Label the segment with its raw (unnormalized) importance at its midpoint
        mid_point = len(segment) // 2
        ax.text(
            segment_x[mid_point],
            segment_y[mid_point],
            f"{importances[i]:.2f}",
            fontweight="bold",
            ha="center",
            va="bottom",
        )

    ax.set_title("Segment Importance for Classification")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    # The ScalarMappable is not drawn on any Axes, so pass `ax` explicitly to tell matplotlib
    # where to place the colorbar. Omitting it is deprecated since Matplotlib 3.6 and an error
    # in newer releases. The norm matches the raw importance values printed on the segments.
    importance_mappable = plt.cm.ScalarMappable(
        norm=plt.Normalize(vmin=0.0, vmax=max_importance or 1.0), cmap="Reds"
    )
    fig.colorbar(importance_mappable, ax=ax, label="Importance")
    ax.grid(True)
    plt.show()

## Evaluate the Effect of Perturbations


In [None]:
# Predictions the explainer produced for its evaluation trajectories, ordered by segment importance
y_eval_sorted = trajectory_explainer.get_Y_eval_sorted()
print("Predictions for evaluation trajectories (sorted by importance):")
# NOTE(review): `i` is the position in importance order, which need not match the segment
# numbering in the plots above — confirm against TrajectoryManipulator.get_Y_eval_sorted.
for i, pred in enumerate(y_eval_sorted):
    print(f"Segment {i+1}: {pred}")

# get_Y() returns a sequence (the explanation cell takes element [0] as the predicted label).
# Compare against that label itself; comparing a label to the whole sequence would report
# a change for every entry.
original_pred = trajectory_explainer.get_Y()[0]
print(f"\nOriginal trajectory prediction: {original_pred}")

# Positions whose perturbation changed the model's classification
changed_segments = [i for i, pred in enumerate(y_eval_sorted) if pred != original_pred]

print(f"\nSegments that changed classification when perturbed: {changed_segments}")