# Advanced Trajectory Generation Features

This notebook explores advanced capabilities of the `KSpaceTrajectoryGenerator`, including custom trajectory functions, time-varying parameters, generating 3D trajectories from 2D bases, and per-interleaf parameter overrides.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # For 3D plotting
from trajgen import KSpaceTrajectoryGenerator, Trajectory

# Ensure plots appear inline in the notebook
%matplotlib inline

## 1. Custom Trajectory Function

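Users can pass their own Python function as `custom_traj_func` to generate arbitrary k-space shapes. In the example below, the function:

- receives the interleaf index, a time vector, the number of points and extra keyword arguments;
- returns a `(kx, ky)` tuple together with the matching `(gx, gy)` gradient tuple.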

In [None]:
def lissajous_2d(interleaf_idx, t_vector, n_points, **kwargs):
    """Generate one 2D Lissajous k-space interleaf and its gradient waveforms.

    Parameters
    ----------
    interleaf_idx : int
        Index of the interleaf. It is added to the x-frequency, so each
        interleaf traces a different Lissajous figure.
    t_vector : array-like
        Time vector supplied by the caller. It is not used here: the curve is
        parameterized over ``[0, 2*pi]``, sampled at ``n_points`` points.
    n_points : int
        Number of samples in the interleaf.
    **kwargs
        Optional overrides:

        - ``k_max``: amplitude of both k-space axes (default ``1 / (2 * 0.004)``).
        - ``lissajous_a``: base x-frequency before ``interleaf_idx`` is added (default 3).
        - ``lissajous_b``: y-frequency (default 2).
        - ``lissajous_delta``: x phase offset in radians (default ``pi / 2``).
        - ``dt``: sample spacing in seconds used for the gradient derivative (default 4e-6).
        - ``gamma``: gyromagnetic ratio in Hz/T (default 42.576e6).

    Returns
    -------
    tuple
        ``((kx, ky), (gx, gy))``, each array of length ``n_points``.
        Gradients are the numerical derivative ``dk/dt`` divided by ``gamma``.
        With fewer than two samples no derivative can be formed, so the
        gradients are returned as zeros.
    """
    k_max = kwargs.get('k_max', 1.0 / (2 * 0.004))

    # Frequencies/phase of the Lissajous figure; the x-frequency shifts with the
    # interleaf index so different interleaves give different shapes.
    a = kwargs.get('lissajous_a', 3) + interleaf_idx
    b = kwargs.get('lissajous_b', 2)
    delta = kwargs.get('lissajous_delta', np.pi / 2)

    # Parameterize over [0, 2*pi] rather than physical time.
    t_norm = np.linspace(0, 2 * np.pi, n_points)

    kx_custom = k_max * np.sin(a * t_norm + delta)
    ky_custom = k_max * np.sin(b * t_norm)

    # NOTE(review): the defaults mirror the dt/gamma passed to the generator in the
    # next cell; presumably the generator forwards them as kwargs — confirm.
    dt_sec = kwargs.get('dt', 4e-6)
    gamma = kwargs.get('gamma', 42.576e6)

    if kx_custom.size < 2:
        # np.gradient raises ValueError on arrays with fewer than 2 samples.
        gx_custom = np.zeros_like(kx_custom)
        gy_custom = np.zeros_like(ky_custom)
    else:
        gx_custom = np.gradient(kx_custom, dt_sec) / gamma
        gy_custom = np.gradient(ky_custom, dt_sec) / gamma

    return (kx_custom, ky_custom), (gx_custom, gy_custom)

# Instantiate the generator with the custom function.
# NOTE(review): fov/resolution are still supplied; the original notes say they feed
# n_samples and k_max estimates unless the custom function overrides them — confirm.
gen_custom = KSpaceTrajectoryGenerator(
    custom_traj_func=lissajous_2d,
    dim=2,
    fov=0.256,
    resolution=0.004,
    n_interleaves=2,  # Generate 2 different Lissajous figures
    dt=4e-6,  # lissajous_2d reads 'dt' from its kwargs for the gradient derivative
    gamma=42.576e6  # lissajous_2d reads 'gamma' from its kwargs
)

# Generate trajectory waveforms (one row per interleaf).
kx_c, ky_c, gx_c, gy_c, t_c = gen_custom.generate()

# Flatten all interleaves into (2, total_samples) arrays for the Trajectory object.
kspace_custom_2d = np.stack([kx_c.ravel(), ky_c.ravel()])
gradients_custom_2d = np.stack([gx_c.ravel(), gy_c.ravel()])

traj_custom = Trajectory(
    name='Custom Lissajous 2D',
    kspace_points_rad_per_m=kspace_custom_2d,
    gradient_waveforms_Tm=gradients_custom_2d,
    dt_seconds=gen_custom.dt,
    # Store a shallow snapshot of the generator's attributes. Passing
    # gen_custom.__dict__ itself can alias the live object, so later changes to
    # gen_custom would silently change this trajectory's metadata.
    metadata={'gamma_Hz_per_T': gen_custom.gamma, 'generator_params': dict(gen_custom.__dict__)}
)

traj_custom.summary()

# Plot each custom interleaf in its own panel.
fig, axes = plt.subplots(1, gen_custom.n_interleaves, figsize=(10, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.plot(kx_c[i, :], ky_c[i, :])
    ax.set_title(f'Custom Interleaf {i+1}')
    ax.set_xlabel('Kx'); ax.set_ylabel('Ky'); ax.axis('equal')
fig.tight_layout()
plt.show()

## 2. Time-Varying Parameters

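Trajectory parameters such as FOV or resolution can be varied over time by passing a callback as `time_varying_params`. The callback receives the current time in seconds and must return a dictionary of parameter values (here, `{'fov': ...}`).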

In [None]:
def time_varying_fov(time_s, max_time=0.010, base_fov=0.200, max_fov=0.300):
    """Linearly ramp the FOV from ``base_fov`` to ``max_fov`` over ``max_time`` seconds.

    Parameters
    ----------
    time_s : float or numpy.ndarray
        Time since the start of the acquisition, in seconds.
    max_time : float, optional
        Duration of the ramp in seconds (default 10 ms).
    base_fov : float, optional
        FOV in meters at the start of the ramp (default 0.200).
    max_fov : float, optional
        FOV in meters at the end of the ramp (default 0.300).

    Returns
    -------
    dict
        ``{'fov': current_fov}``; the callback must return a dictionary.
        The ramp is clamped. The FOV stays at ``base_fov`` for ``time_s <= 0``
        and at ``max_fov`` for ``time_s >= max_time``, so it never extrapolates
        past ``max_fov`` when the readout is longer than ``max_time``.
    """
    # Fraction of the ramp completed, clamped to [0, 1] (works for scalars and arrays).
    ramp_fraction = np.clip(time_s / max_time, 0.0, 1.0)
    current_fov = base_fov + (max_fov - base_fov) * ramp_fraction
    return {'fov': current_fov}

# Single-interleaf spiral whose FOV is driven by time_varying_fov during the
# readout; fov=0.200 is the starting value and matches the callback at t = 0.
gen_tv_params = KSpaceTrajectoryGenerator(
    traj_type='spiral',
    dim=2,
    fov=0.200,
    resolution=0.005,
    n_interleaves=1,
    turns=10,
    time_varying_params=time_varying_fov
)

# n_samples is sized from the initial FOV only. Because the FOV grows during the
# readout, a more thorough setup would size it for the largest FOV reached.
print(f"Generator n_samples (based on initial FOV): {gen_tv_params.n_samples}")

kx_tv, ky_tv, gx_tv, gy_tv, t_tv = gen_tv_params.generate()

# Only one interleaf was requested, so work with row 0 of each waveform.
kx_shot, ky_shot = kx_tv[0], ky_tv[0]
gx_shot, gy_shot = gx_tv[0], gy_tv[0]

traj_tv = Trajectory(
    name='Time-Varying FOV Spiral',
    kspace_points_rad_per_m=np.stack([kx_shot, ky_shot]),
    gradient_waveforms_Tm=np.stack([gx_shot, gy_shot]),
    dt_seconds=gen_tv_params.dt,
    metadata={'gamma_Hz_per_T': gen_tv_params.gamma}
)
traj_tv.summary()

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(kx_shot, ky_shot, '-')
ax.set_title('Spiral with Time-Varying FOV')
ax.set_xlabel('Kx (rad/m)'); ax.set_ylabel('Ky (rad/m)'); ax.axis('equal')
plt.show()

## 3. Generating 3D Trajectories from 2D Bases

The `generate_3d_from_2d` method allows creating 3D trajectories by rotating a 2D base trajectory (e.g., spiral) around different axes in 3D space.

In [None]:
# Spiral generator whose settings define the 2D base shape; its
# generate_3d_from_2d method below produces the rotated 3D set.
base_gen_for_3d = KSpaceTrajectoryGenerator(
    fov=0.200,
    resolution=0.008,
    traj_type='spiral',  # 2D base shape
    turns=5,
    vd_method='power', vd_alpha=1.0  # variable-density settings for the 2D spiral base
)

n_3d_shots = 32  # how many orientations of the 2D spiral to place in 3D

rotated_waveforms = base_gen_for_3d.generate_3d_from_2d(
    n_3d_shots=n_3d_shots,
    traj2d_type='spiral'  # base 2D shape to rotate
    # fov_3d / resolution_3d may also be passed to override the base generator's settings
)
(kx_3d_from_2d, ky_3d_from_2d, kz_3d_from_2d,
 gx_3d_from_2d, gy_3d_from_2d, gz_3d_from_2d, t_3d_from_2d) = rotated_waveforms

# Flatten every shot into (3, total_samples) arrays.
kspace_3d_rot = np.stack([comp.ravel() for comp in (kx_3d_from_2d, ky_3d_from_2d, kz_3d_from_2d)])
gradients_3d_rot = np.stack([comp.ravel() for comp in (gx_3d_from_2d, gy_3d_from_2d, gz_3d_from_2d)])

traj_3d_rot = Trajectory(
    name='3D from Rotated 2D Spirals',
    kspace_points_rad_per_m=kspace_3d_rot,
    gradient_waveforms_Tm=gradients_3d_rot,
    dt_seconds=base_gen_for_3d.dt,
    metadata={'gamma_Hz_per_T': base_gen_for_3d.gamma}
)
traj_3d_rot.summary()

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
# Keep every `step`-th point so large trajectories still plot quickly.
step = max(1, kspace_3d_rot.shape[1] // 1000)
k_subset = kspace_3d_rot[:, ::step]
ax.plot(k_subset[0], k_subset[1], k_subset[2], '.', markersize=1)
ax.set_title('3D from Rotated 2D Spirals (Subset)')
ax.set_xlabel('Kx'); ax.set_ylabel('Ky'); ax.set_zlabel('Kz')
plt.show()

## 4. Per-Interleaf Parameters

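The `per_interleaf_params` argument overrides generator parameters for specific interleaves, keyed by interleaf index. Interleaves without an entry use the generator-level defaults.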

In [None]:
# Example: vary the number of spiral turns per interleaf, keyed by interleaf index.
per_interleaf_settings = {
    0: {'turns': 4},   # First interleaf: 4 turns
    1: {'turns': 8},   # Second interleaf: 8 turns
    2: {'turns': 12},  # Third interleaf: 12 turns
    # Other interleaves will use the default 'turns' from the generator
}

gen_per_interleaf = KSpaceTrajectoryGenerator(
    traj_type='spiral',
    dim=2,
    fov=0.256,
    resolution=0.004,
    n_interleaves=4,
    turns=6,  # Default turns for interleaves not listed above
    per_interleaf_params=per_interleaf_settings,
    use_golden_angle=False  # Fixed angles make the per-interleaf effect easier to see
)

kx_pi, ky_pi, gx_pi, gy_pi, t_pi = gen_per_interleaf.generate()

traj_per_interleaf = Trajectory(
    name='Per-Interleaf Spiral Turns',
    kspace_points_rad_per_m=np.stack([kx_pi.ravel(), ky_pi.ravel()]),
    gradient_waveforms_Tm=np.stack([gx_pi.ravel(), gy_pi.ravel()]),
    dt_seconds=gen_per_interleaf.dt,
    # Record the generator's gamma, as every other Trajectory in this notebook does.
    metadata={'gamma_Hz_per_T': gen_per_interleaf.gamma}
)
traj_per_interleaf.summary()

fig, axes = plt.subplots(1, gen_per_interleaf.n_interleaves, figsize=(12, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.plot(kx_pi[i, :], ky_pi[i, :])
    # Turns actually used: the per-interleaf override if present, else the generator default.
    actual_turns = per_interleaf_settings.get(i, {}).get('turns', gen_per_interleaf.turns)
    ax.set_title(f'Interleaf {i+1} (Turns: {actual_turns})')
    ax.set_xlabel('Kx'); ax.set_ylabel('Ky'); ax.axis('equal')
fig.tight_layout()
plt.show()

This notebook demonstrated several advanced features for fine-grained control over trajectory generation, enabling complex and customized k-space sampling patterns.