# Advanced Trajectory Tools Demo

This notebook demonstrates the usage of advanced trajectory generation, manipulation, and utility functions available in the `trajgen` module. We will cover:

1.  Generating spiral and radial trajectories.
2.  Constraining trajectories based on hardware limits.
3.  Computing density compensation weights using different methods.
4.  Visualizing trajectories and their properties (including Voronoi diagrams and a 3D stack-of-stars example).
5.  Performing a simple image reconstruction from non-Cartesian k-space data.

## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

# Make the parent directory importable so that `import trajgen` resolves.
# Layout assumption: this notebook sits in 'examples/' and trajgen.py one level up.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from trajgen import (
    Trajectory,
    normalize_density_weights, # Though not directly called, good to know it's there
    compute_density_compensation,
    create_periodic_points, 
    compute_voronoi_density,
    generate_spiral_trajectory,
    generate_radial_trajectory,
    constrain_trajectory,
    reconstruct_image,
    display_trajectory
)

%matplotlib inline 
# Render figures inline below each cell (already the default in modern Jupyter; kept for older setups).
# Use %matplotlib widget (needs the ipympl package) for interactive plots if your environment supports it

## Section 1: Trajectory Generation

### Spiral Trajectory

In [None]:
# Spiral acquisition parameters.
num_arms_spiral = 16
num_samples_per_arm_spiral = 1024
fov_spiral = 0.22  # meters
dt_spiral = 4e-6  # seconds

spiral_traj = generate_spiral_trajectory(
    num_arms=num_arms_spiral,
    num_samples_per_arm=num_samples_per_arm_spiral,
    fov_m=fov_spiral,
    dt_seconds=dt_spiral,
    name="DemoSpiral",
)

print("--- Spiral Trajectory Summary ---")
spiral_traj.summary()
print("\n--- Metadata --- ")
for key, value in spiral_traj.metadata.items():
    if key == 'generator_params':
        # The generator parameters are a nested mapping; print one entry per line.
        print("Generator Params:")
        for gp_key, gp_value in value.items():
            print(f"  {gp_key}: {gp_value}")
    else:
        print(f"{key}: {value}")

# Draw into an explicit Axes (display_trajectory accepts `ax=`), so the plot and
# its title are guaranteed to land on the same, single figure.
fig_spiral, ax_spiral = plt.subplots(figsize=(7, 7))
display_trajectory(spiral_traj, plot_type="2D", ax=ax_spiral, max_total_points=2000)  # limit points for clarity
ax_spiral.set_title("Generated Spiral Trajectory")
plt.show()

### Radial Trajectory

In [None]:
# Golden-angle radial acquisition parameters.
num_spokes_radial = 64
num_samples_per_spoke_radial = 256
fov_radial = 0.24  # meters
dt_radial = 4e-6  # seconds

radial_traj = generate_radial_trajectory(
    num_spokes=num_spokes_radial,
    num_samples_per_spoke=num_samples_per_spoke_radial,
    fov_m=fov_radial,
    dt_seconds=dt_radial,
    use_golden_angle=True,
    name="DemoRadialGA",
)

print("--- Radial Trajectory Summary ---")
radial_traj.summary()

# Draw into an explicit Axes (display_trajectory accepts `ax=`), so the plot and
# its title are guaranteed to land on the same, single figure.
fig_radial, ax_radial = plt.subplots(figsize=(7, 7))
display_trajectory(radial_traj, plot_type="2D", ax=ax_radial, max_total_points=5000, plot_style='-')  # show more points for radial
ax_radial.set_title("Generated Golden Angle Radial Trajectory")
plt.show()

## Section 2: Trajectory Constraints

In [None]:
# Hardware limits the trajectory must respect.
max_gradient = 0.04  # T/m
max_slew_rate = 150  # T/m/s

# A deliberately aggressive spiral (2 mm resolution, 20 revolutions in only
# 512 samples per arm) that may well exceed the limits above.
fast_spiral_traj = generate_spiral_trajectory(
    num_arms=8,
    num_samples_per_arm=512,
    fov_m=0.2,
    max_k_rad_per_m=np.pi / 0.002,  # k_max = pi / resolution, i.e. 2 mm
    num_revolutions=20,
    dt_seconds=4e-6,
    name="FastSpiral",
)

print(f"Original trajectory max gradient: {fast_spiral_traj.get_max_grad_Tm():.4f} T/m")
print(f"Original trajectory max slew rate: {fast_spiral_traj.get_max_slew_Tm_per_s():.2f} T/m/s")

# Ask trajgen for a version of the trajectory that stays within both limits.
constrained_spiral_traj = constrain_trajectory(
    fast_spiral_traj,
    max_gradient_Tm_per_m=max_gradient,
    max_slew_rate_Tm_per_m_per_s=max_slew_rate,
)

print(f"\nConstrained trajectory max gradient: {constrained_spiral_traj.get_max_grad_Tm():.4f} T/m")
print(f"Constrained trajectory max slew rate: {constrained_spiral_traj.get_max_slew_Tm_per_s():.2f} T/m/s")

# Side-by-side comparison, one panel per trajectory.
comparison_panels = (
    (fast_spiral_traj, "Original Fast Spiral"),
    (constrained_spiral_traj, "Constrained Spiral"),
)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for panel_ax, (panel_traj, panel_title) in zip(axes, comparison_panels):
    display_trajectory(panel_traj, plot_type="2D", ax=panel_ax, title=panel_title, max_total_points=2000)
plt.tight_layout()
plt.show()

for summary_header, summary_traj in (
    ("\n--- Original Fast Spiral Summary ---", fast_spiral_traj),
    ("\n--- Constrained Spiral Summary ---", constrained_spiral_traj),
):
    print(summary_header)
    summary_traj.summary()

## Section 3: Density Compensation

### Voronoi Method (using `compute_density_compensation`)

In [None]:
# Voronoi density compensation for the golden-angle radial trajectory.
# The trajectory stores k-space as (D, N): row 0 is kx and row 1 is ky (the
# scatter below relies on this). compute_density_compensation is fed one point
# per row, (N, D), so transpose unconditionally. This stays correct for any
# dimensionality D, including 3D trajectories, not just the 2D case.
k_points_radial = radial_traj.kspace_points_rad_per_m  # (D, N)
k_points_radial_for_comp = k_points_radial.T  # (N, D)

weights_voronoi = compute_density_compensation(k_points_radial_for_comp, method="voronoi")

fig_dcf_voronoi, ax_dcf_voronoi = plt.subplots(figsize=(8, 7))
sc = ax_dcf_voronoi.scatter(k_points_radial[0, :], k_points_radial[1, :], c=weights_voronoi, cmap='viridis', s=5)
fig_dcf_voronoi.colorbar(sc, ax=ax_dcf_voronoi, label='Voronoi Weights')
ax_dcf_voronoi.set_title('Radial Trajectory with Voronoi Density Compensation')
ax_dcf_voronoi.set_xlabel('Kx (rad/m)')
ax_dcf_voronoi.set_ylabel('Ky (rad/m)')
ax_dcf_voronoi.axis('equal')
plt.show()

print(f"Voronoi weights: min={weights_voronoi.min():.2e}, max={weights_voronoi.max():.2e}, sum={weights_voronoi.sum():.2f}")

### Pipe Method (using `compute_density_compensation`)

In [None]:
# Pipe density compensation on the spiral trajectory from Section 1.
# k-space is stored as (D, N). compute_density_compensation also accepts a
# complex vector, so fold the kx/ky rows into kx + i*ky (length N).
k_points_spiral = spiral_traj.kspace_points_rad_per_m  # (D, N)
kx_spiral, ky_spiral = k_points_spiral[0, :], k_points_spiral[1, :]
k_complex_spiral = kx_spiral + 1j * ky_spiral

weights_pipe = compute_density_compensation(
    k_complex_spiral,
    method="pipe",
    dt_seconds=spiral_traj.dt_seconds,
)

fig_dcf_pipe, ax_dcf_pipe = plt.subplots(figsize=(8, 7))
sc = ax_dcf_pipe.scatter(kx_spiral, ky_spiral, c=weights_pipe, cmap='magma', s=5)
fig_dcf_pipe.colorbar(sc, ax=ax_dcf_pipe, label='Pipe Weights')
ax_dcf_pipe.set_title('Spiral Trajectory with Pipe Density Compensation')
ax_dcf_pipe.set_xlabel('Kx (rad/m)')
ax_dcf_pipe.set_ylabel('Ky (rad/m)')
ax_dcf_pipe.axis('equal')
plt.show()

print(f"Pipe weights: min={weights_pipe.min():.2e}, max={weights_pipe.max():.2e}, sum={weights_pipe.sum():.2f}")

### Advanced Voronoi Density (using `compute_voronoi_density`)

In [None]:
# A handful of hand-placed 2D points (one point per row, shape (N, D)) so the
# Voronoi cells are easy to see.
simple_points = np.array([[0, 0], [0.5, 0.1], [-0.2, 0.4], [0.3, -0.3], [-0.4, -0.2]])  # (N,D)

# compute_voronoi_density's 'periodic' mode expects points inside [-0.5, 0.5].
# (It may also normalize internally; that is not verified here.)
# Min-max scale each axis independently. The per-axis minimum maps to exactly
# -0.5 and the maximum to exactly +0.5, so a single pass already centres the
# data and no second re-normalization is needed.
min_vals = np.min(simple_points, axis=0)
max_vals = np.max(simple_points, axis=0)
range_vals = max_vals - min_vals
# A constant axis (zero range, e.g. a single point) would divide by zero;
# with a range of 1 its coordinates all map to -0.5 instead.
range_vals[range_vals == 0] = 1
norm_traj_pts = (simple_points - min_vals) / range_vals - 0.5


# Compare the two boundary treatments offered by compute_voronoi_density.
weights_periodic = compute_voronoi_density(norm_traj_pts, boundary_type="periodic")
weights_clipped = compute_voronoi_density(norm_traj_pts, boundary_type="clipped")

# One panel per boundary type: (weights, colorbar label, panel title).
boundary_panels = (
    (weights_periodic, 'Periodic Weights', 'Periodic Voronoi Density'),
    (weights_clipped, 'Clipped Weights', 'Clipped Voronoi Density'),
)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for panel_ax, (panel_weights, cbar_label, panel_title) in zip(axes, boundary_panels):
    panel_sc = panel_ax.scatter(norm_traj_pts[:, 0], norm_traj_pts[:, 1], c=panel_weights, cmap='viridis', s=50)
    fig.colorbar(panel_sc, ax=panel_ax, label=cbar_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Kx (normalized)')
    panel_ax.set_ylabel('Ky (normalized)')
    panel_ax.axis('equal')

plt.tight_layout()
plt.show()

print(f"Periodic weights: {weights_periodic}")
print(f"Clipped weights: {weights_clipped}")
print("Note: For 'clipped', points on the convex hull often get very small or zero cell areas before median replacement, leading to more uniform weights after fallback.")
print("For 'periodic', all points are treated as internal to a tiled space, yielding more varied cell sizes.")

### Using `Trajectory.calculate_voronoi_density()` and `plot_voronoi()`

In [None]:
# Voronoi diagram of the radial trajectory. plot_voronoi runs the density
# calculation itself (per trajgen), so no separate call is needed here.
fig_vor, ax_vor = plt.subplots(figsize=(8, 7))
radial_traj.plot_voronoi(
    ax=ax_vor,
    title="Voronoi Diagram for Radial Trajectory",
    show_points=True,
    line_width=0.5,
)
plt.show()

## Section 4: Image Reconstruction

In [None]:
# Simulate k-space data on the radial trajectory, then grid it back to an image.
k_coords_radial = radial_traj.kspace_points_rad_per_m  # (D, N)
num_pts_radial = radial_traj.get_num_points()
# The simulated signal is built per k-space sample, so the two counts must agree.
assert k_coords_radial.shape[1] == num_pts_radial, "get_num_points() does not match the k-space array"

# Fixed seed so the simulated noise, and therefore the reconstructed image, is
# reproducible on every Restart & Run All.
rng = np.random.default_rng(seed=0)

# Signal: exponential decay with |k| (strongest at the centre) plus complex Gaussian noise.
k_radii = np.hypot(k_coords_radial[0, :], k_coords_radial[1, :])
max_k_radius = np.max(k_radii)
if not max_k_radius > 0:  # degenerate trajectory at k = 0: avoid dividing by zero
    max_k_radius = 1.0
noise = 0.05 * (rng.standard_normal(k_radii.shape) + 1j * rng.standard_normal(k_radii.shape))
sim_kspace_data = np.exp(-2 * k_radii / max_k_radius) + noise
sim_kspace_data[k_radii < 0.1 * max_k_radius] *= 2  # extra boost for the innermost 10% of k-space

grid_size = (64, 64)

print(f"Reconstructing image with grid size: {grid_size}")
image_recon = reconstruct_image(
    sim_kspace_data,
    radial_traj,
    grid_size,
    density_comp_method="voronoi",  # or 'pipe' or None
    verbose=True,
)

# Display the magnitude. imshow cannot render complex arrays, and whether
# reconstruct_image returns complex values is not visible here. np.abs is a
# no-op on an image that is already a non-negative magnitude.
fig_recon, ax_recon = plt.subplots(figsize=(7, 7))
im_recon = ax_recon.imshow(np.abs(image_recon), cmap='gray')
ax_recon.set_title(f'Reconstructed Image ({grid_size[0]}x{grid_size[1]}) from Radial Traj.')
fig_recon.colorbar(im_recon, ax=ax_recon)
plt.show()

## Section 5: 3D Trajectory Visualization (Example)

In [None]:
# Build a simple 3D stack-of-stars: one 2D radial "star" repeated on several kz
# planes. kz is sampled like a Cartesian phase-encode axis with the convention
# k_max = pi / resolution (rad/m). The planes are therefore spaced 2*pi/FOV apart
# and, centred about kz = 0, span the full (-k_max_3d, +k_max_3d) range.
num_spokes_3d = 16                           # total spokes, split equally across kz planes
num_samples_per_spoke_3d = 32
num_slices_3d = 8                            # number of kz planes
fov_3d = 0.2                                 # meters
slice_thickness_3d = fov_3d / num_slices_3d
k_max_3d = np.pi / slice_thickness_3d        # Kz max for this slice thickness (rad/m)
kz_step_3d = 2 * k_max_3d / num_slices_3d    # equals 2*pi / fov_3d
kz_offsets_3d = -k_max_3d + (np.arange(num_slices_3d) + 0.5) * kz_step_3d

# The in-plane star does not depend on kz, so generate it once rather than per
# plane (assumes generate_radial_trajectory is deterministic for fixed arguments,
# which should hold with use_golden_angle=False).
radial_2d_part = generate_radial_trajectory(
    num_spokes=num_spokes_3d // num_slices_3d,  # spokes per kz plane
    num_samples_per_spoke=num_samples_per_spoke_3d,
    fov_m=fov_3d,
    use_golden_angle=False,  # uniform angles for simplicity
)
kxy_points = radial_2d_part.kspace_points_rad_per_m  # (2, N_slice_pts)
num_star_pts = kxy_points.shape[1]

all_k_points_3d_list = [
    np.vstack((kxy_points, np.full(num_star_pts, kz_offset)))  # (3, N_slice_pts)
    for kz_offset in kz_offsets_3d
]
k_points_3d_combined = np.concatenate(all_k_points_3d_list, axis=1)
traj_3d_example = Trajectory("Demo3DStackOfStars", k_points_3d_combined, dt_seconds=4e-6)

print(f"Generated 3D trajectory with {traj_3d_example.get_num_points()} points.")

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
display_trajectory(traj_3d_example, plot_type="3D", ax=ax, max_total_points=2000)  # limit points
ax.set_title("3D Stack-of-Stars Trajectory Example")
plt.show()

## Conclusion

This notebook demonstrated key functionalities for generating, constraining, and utilizing k-space trajectories, including density compensation and basic image reconstruction. These tools provide a foundation for exploring various non-Cartesian MRI sequences and reconstruction techniques.