# Trajectory Class Features

This notebook demonstrates advanced features of the `Trajectory` class from the `trajgen` library, including handling dead times, export/import functionality, Voronoi density calculation, and direct data access. It closes with a short look at the predefined gyromagnetic ratios in `COMMON_NUCLEI_GAMMA_HZ_PER_T`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # For 3D plotting if needed
from trajgen import KSpaceTrajectoryGenerator, Trajectory, COMMON_NUCLEI_GAMMA_HZ_PER_T
import os # For file operations

# Ensure plots appear inline in the notebook
%matplotlib inline

## 1. Creating a Sample Trajectory (2D Spiral)

We'll start by creating a simple 2D spiral trajectory to work with.

In [None]:
# Generator for a 2D spiral split into four interleaves.
gen_spiral = KSpaceTrajectoryGenerator(
    traj_type='spiral',
    dim=2,
    fov=0.220,
    resolution=0.005,
    n_interleaves=4,
    turns=5,
    use_golden_angle=True,
)

kx_s, ky_s, gx_s, gy_s, t_s = gen_spiral.generate()

# Flatten every component to 1-D and stack the components as rows, giving
# one (2, n_points) array for k-space and one for the gradients.
kspace_spiral_flat = np.stack((kx_s.ravel(), ky_s.ravel()), axis=0)
gradients_spiral_flat = np.stack((gx_s.ravel(), gy_s.ravel()), axis=0)

traj_sample = Trajectory(
    name='Sample Spiral for Features Demo',
    kspace_points_rad_per_m=kspace_spiral_flat,
    gradient_waveforms_Tm=gradients_spiral_flat,
    dt_seconds=gen_spiral.dt,
    metadata={'gamma_Hz_per_T': gen_spiral.gamma},
)

traj_sample.summary()

# Show the sample locations in the Kx-Ky plane with equal axis scaling.
fig_sample, ax_sample = plt.subplots(figsize=(6, 6))
ax_sample.plot(
    traj_sample.kspace_points_rad_per_m[0, :],
    traj_sample.kspace_points_rad_per_m[1, :],
    '.',
)
ax_sample.set_title(traj_sample.name)
ax_sample.set_xlabel('Kx (rad/m)')
ax_sample.set_ylabel('Ky (rad/m)')
ax_sample.axis('equal')
plt.show()

## 2. Dead Time Indication

The `Trajectory` class can account for dead times at the start and end of the acquisition.

In [None]:
# Dead times, in seconds, passed to the constructor below.
deadtime_kwargs = {
    'dead_time_start_seconds': 1e-3,   # 1 ms at the start
    'dead_time_end_seconds': 0.5e-3,   # 0.5 ms at the end
}

# Same k-space samples, gradients, dt and gamma as `traj_sample`; only the
# name and the two dead-time arguments differ.
traj_with_deadtime = Trajectory(
    name='Spiral with Dead Time',
    kspace_points_rad_per_m=kspace_spiral_flat,
    gradient_waveforms_Tm=gradients_spiral_flat,
    dt_seconds=gen_spiral.dt,
    metadata={'gamma_Hz_per_T': gen_spiral.gamma},
    **deadtime_kwargs,
)

traj_with_deadtime.summary()

Observe the `summary()` output. It now includes:
- Dead Time (Start)
- Dead Time (End)
- Total Dead Time
- The main "Duration" field is now "Total Duration (incl. deadtime)"
- Metadata also contains `dead_time_start_samples` and `dead_time_end_samples` if `dt_seconds` is available.

## 3. Exporting and Importing Trajectories

Trajectories can be saved to and loaded from `.npz` files.

In [None]:
export_filename = 'sample_trajectory_export.npz'

# Run the round trip inside try/finally so the scratch .npz file is removed
# even when export(), import_from() or summary() raises; otherwise a failed
# run would leave the file behind in the working directory.
try:
    print(f"Exporting trajectory: {traj_with_deadtime.name} to {export_filename}")
    traj_with_deadtime.export(export_filename)

    print(f"\nImporting trajectory from {export_filename}")
    traj_imported = Trajectory.import_from(export_filename)

    print("\nSummary of the imported trajectory:")
    traj_imported.summary()
finally:
    # Clean up the created file (it may not exist if export() itself failed).
    if os.path.exists(export_filename):
        os.remove(export_filename)
        print(f"\nCleaned up {export_filename}")

The summary of the imported trajectory should match the one with dead times, including all metadata.

## 4. Voronoi Density Calculation

The `Trajectory` class can calculate Voronoi cell sizes for each k-space point, providing a measure of local sampling density.

In [None]:
# Run the Voronoi calculation on 'traj_sample' (the 4-interleaf 2D spiral).
print("Calculating Voronoi density for the sample 2D spiral...")
voronoi_cell_sizes = traj_sample.calculate_voronoi_density()

# Read the status once, after the calculation; both branches report it.
voronoi_status = traj_sample.metadata.get('voronoi_calculation_status')

if voronoi_cell_sizes is None:
    print(f"Voronoi calculation failed. Status: {voronoi_status}")
else:
    print(f"Calculation status: {voronoi_status}")
    print(f"Number of cell sizes calculated: {len(voronoi_cell_sizes)}")
    # The summary is printed again here; per the original note it now also
    # reports the Voronoi statistics.
    traj_sample.summary()

### Voronoi Diagram Visualization

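Two views of the Voronoi result for a 2D trajectory are shown below. The first cell draws a scatter plot of the k-space points colored by the base-10 logarithm of their Voronoi cell size; the second uses the `plot_voronoi()` method to draw the Voronoi cells themselves. Both help in understanding the local sampling density.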

In [None]:
# Plot the k-space points colored by log10 of their Voronoi cell size (2D only).
if voronoi_cell_sizes is not None and traj_sample.get_num_dimensions() == 2:
    # Log scale for the color mapping because cell sizes can vary widely.
    # The small epsilon avoids log10(0) for zero-size (degenerate) cells.
    log_cell_sizes = np.log10(voronoi_cell_sizes + 1e-9)

    # Guard on the values that are actually plotted: infinite cell sizes stay
    # infinite after the log and negative ones become NaN, so "some finite
    # cell sizes exist" does not guarantee a non-empty percentile input.
    finite_mask = np.isfinite(log_cell_sizes)
    finite_log_sizes = log_cell_sizes[finite_mask]

    if finite_log_sizes.size > 0:
        # Clip the color range to the 5th-95th percentile so a few extreme
        # cells do not wash out the rest of the colormap.
        vmin, vmax = np.percentile(finite_log_sizes, [5, 95])

        # Only points with a finite color value are drawn; filtering them
        # here keeps that independent of how scatter treats non-finite `c`.
        fig_density, ax_density = plt.subplots(figsize=(8, 6))
        scatter = ax_density.scatter(
            traj_sample.kspace_points_rad_per_m[0, finite_mask],
            traj_sample.kspace_points_rad_per_m[1, finite_mask],
            c=finite_log_sizes, cmap='viridis',
            s=5, vmin=vmin, vmax=vmax,
        )
        fig_density.colorbar(scatter, ax=ax_density, label='Log10(Voronoi Cell Area)')
        ax_density.set_title('K-Space Points Colored by Voronoi Cell Size')
        ax_density.set_xlabel('Kx (rad/m)')
        ax_density.set_ylabel('Ky (rad/m)')
        ax_density.axis('equal')
        ax_density.grid(True)
        plt.show()
    else:
        print("No finite Voronoi cell sizes to plot.")
elif traj_sample.get_num_dimensions() != 2:
    print("Voronoi plot is configured for 2D trajectories in this example.")
else:
    print("Voronoi data was not successfully calculated, skipping plot.")

In [None]:
# Draw the Voronoi diagram of traj_sample, but only when the earlier
# calculation recorded a "Success" status in the trajectory metadata.
if traj_sample.metadata.get('voronoi_calculation_status') == "Success":
    fig_voronoi, ax_voronoi = plt.subplots(figsize=(10, 8))
    traj_sample.plot_voronoi(
        ax=ax_voronoi,
        show_points=True,
        show_vertices=False,
        color_by_area=True,
        cmap='coolwarm',
        line_width=0.5,
        line_colors='k',
        point_size=2,  # presumably the marker size of the sample points; confirm against plot_voronoi
    )
    # Set the title on the target axes explicitly. plt.title() acts on the
    # *current* axes, which need not be ax_voronoi if plot_voronoi() adds
    # further axes (e.g. a colorbar) to the figure.
    ax_voronoi.set_title("Voronoi Diagram for Sample Spiral (Colored by Area)")
    plt.show()
else:
    print("Skipping Voronoi plot as calculation was not successful or data is missing.")

The plot above shows the Voronoi cells for each k-space point (specifically, for the unique points used in the Voronoi calculation). The cells are colored by their area, providing a visual representation of sampling density. `show_points=True` displays the k-space sample locations, while `show_vertices=False` hides the Voronoi vertices for clarity in this example. The `cmap` parameter sets the colormap.

## 5. Accessing K-space and Gradient Data

In [None]:
# K-space samples are exposed directly as an attribute of the trajectory.
print("Accessing k-space points directly:")
k_points = traj_sample.kspace_points_rad_per_m
k_points_preview = k_points[:, :5].T  # first five samples, one per row
print(f"Shape of k_points: {k_points.shape} (Dimensions, NumPoints)")
print(f"First 5 k-space points (transposed for readability):\n{k_points_preview}")

# Gradient waveforms come from an accessor, which may return None.
print("\nAccessing gradient waveforms (computed on demand if not provided):")
grad_wf = traj_sample.get_gradient_waveforms_Tm()
if grad_wf is None:
    print("Gradient waveforms are not available (e.g., dt_seconds might be missing).")
else:
    grad_wf_preview = grad_wf[:, :5].T  # first five samples, one per row
    print(f"Shape of grad_wf: {grad_wf.shape} (Dimensions, NumPoints)")
    print(f"First 5 gradient waveform points (transposed for readability):\n{grad_wf_preview}")

The sections above have covered several key features of the `Trajectory` class, demonstrating its utility in analyzing and managing k-space trajectory data. The next section turns to a convenience provided by the `trajgen` module itself.

## 6. Using Predefined Gyromagnetic Ratios

The `trajgen` module provides a dictionary `COMMON_NUCLEI_GAMMA_HZ_PER_T` for convenient access to gyromagnetic ratios of common nuclei.

In [None]:
# List every nucleus in the predefined table with its gamma value.
print("Available nuclei in COMMON_NUCLEI_GAMMA_HZ_PER_T:")
for nucleus, gamma_val in COMMON_NUCLEI_GAMMA_HZ_PER_T.items():
    print(f"- {nucleus}: {gamma_val:.3e} Hz/T")

# Imaging parameters shared by both example spirals.
fov_common = 0.200  # 20 cm
res_common = 0.002  # 2 mm
dt_common = 4e-6    # 4 us

# Everything except gamma is identical for the two nuclei, so the common
# constructor arguments are collected once and reused.
spiral_common_kwargs = dict(
    fov=fov_common,
    resolution=res_common,
    dt=dt_common,
    traj_type='spiral',
    dim=2,
    n_interleaves=1,
    turns=8,
)

# 2D spiral for 1H (proton).
gen_1H = KSpaceTrajectoryGenerator(
    gamma=COMMON_NUCLEI_GAMMA_HZ_PER_T['1H'],
    **spiral_common_kwargs,
)
print(f"\n1H Spiral: Calculated n_samples = {gen_1H.n_samples}, g_required = {gen_1H.g_required:.4f} T/m")

# 2D spiral for 13C with the same imaging parameters.
gen_13C = KSpaceTrajectoryGenerator(
    gamma=COMMON_NUCLEI_GAMMA_HZ_PER_T['13C'],
    **spiral_common_kwargs,
)
print(f"13C Spiral: Calculated n_samples = {gen_13C.n_samples}, g_required = {gen_13C.g_required:.4f} T/m")

print("\nNote the difference in n_samples and g_required due to different gamma values.")
print("A lower gamma (like 13C) typically requires stronger/longer gradients (higher g_required if g_max allows, or more samples) to achieve the same k_max.")

Using this dictionary helps ensure accuracy and convenience when working with different nuclei.