In [None]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import loadmat

from juart.conopt.functional.fourier import (
    nonuniform_fourier_transform_adjoint,
    nonuniform_fourier_transform_forward,
)
from juart.conopt.linops.identity import IdentityOperator
from juart.conopt.linops.tf import TransferFunctionOperator
from juart.conopt.proxops.linear import conjugate_gradient
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.ellipsoid_phantoms.ellipsoids import SheppLogan
from juart.recon.ncgrappa import NonCartGrappa

# Seed torch's global RNG so that later torch random draws (e.g. torch.randint
# in the patch-visualisation cell) are reproducible. Note the drawn values also
# depend on how many draws earlier cells have consumed.
torch.manual_seed(42)

In [2]:
def unique(x, dim=None):
    """Return the unique elements of ``x`` and the index of each one's first occurrence.

    Adapted from
    https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    dim : int, optional
        Dimension along which unique slices are taken. If ``None`` (default),
        ``x`` is flattened and unique scalar values are returned; the returned
        indices then refer to positions in ``x.flatten()``.

    Returns
    -------
    uniq : torch.Tensor
        Sorted unique elements (or slices along ``dim``), exactly as returned by
        ``torch.unique(x, sorted=True, dim=dim)``.
    first_idx : torch.Tensor
        1-D int64 tensor; ``first_idx[i]`` is the smallest index (along ``dim``,
        or into the flattened ``x`` when ``dim`` is ``None``) at which
        ``uniq``'s i-th element/slice occurs.

    Example
    -------
    >>> unique(torch.tensor([[1, 2, 3], [1, 2, 4], [1, 2, 3], [1, 2, 5]]), dim=0)
    (tensor([[1, 2, 3],
             [1, 2, 4],
             [1, 2, 5]]), tensor([0, 1, 3]))
    """
    uniq, inverse = torch.unique(x, sorted=True, return_inverse=True, dim=dim)
    # With dim=None, ``inverse`` has the shape of ``x``; flatten it so that the
    # positions below index into ``x.flatten()``.
    inverse = inverse.reshape(-1)
    positions = torch.arange(
        inverse.numel(), dtype=inverse.dtype, device=inverse.device
    )

    # Number of unique elements/slices: they are laid out along ``dim``, not
    # necessarily along axis 0.
    n_unique = uniq.numel() if dim is None else uniq.size(dim)

    # ``scatter_`` with repeated indices is non-deterministic (which write wins
    # is unspecified), so take the minimum position per group instead.
    first_idx = inverse.new_full((n_unique,), inverse.numel())
    first_idx.scatter_reduce_(
        0, inverse, positions, reduce="amin", include_self=False
    )
    return uniq, first_idx

In [3]:
# Load the spiral GRAPPA dataset (MATLAB .mat file)
data = loadmat("spiral_data.mat")

# Size parameters stored in the .mat file
kernel_size = data["pSize"].squeeze()
matrix_size = data["imSize"].squeeze()

# Bring the arrays into juart's layout and convert them to torch tensors:
# the last axis of the k-space arrays moves to the front, while the first axis
# of the trajectory moves to the back.
ksp_ref = torch.tensor(np.moveaxis(data["kcalib"], -1, 0), dtype=torch.complex64)
ksp_us = torch.tensor(np.moveaxis(data["ukPS"], -1, 0), dtype=torch.complex64)
ktraj_us_comb = torch.tensor(
    np.moveaxis(data["usTraj"], 0, -1), dtype=torch.float32
)

# Cyclically shift the leading axis by 2. Afterwards rows 0-1 are used as the
# (kx, ky) coordinates and row 2 as the sampling mask (non-zero = sampled).
ktraj_us_comb = torch.roll(ktraj_us_comb, shifts=2, dims=0)

# Split trajectory and k-space into sampled / unsampled locations
sample_mask = ktraj_us_comb[2] != 0
ktraj_coords = ktraj_us_comb[:2]
ktraj_sampled = ktraj_coords[:, sample_mask]
ktraj_unsampled = ktraj_coords[:, ~sample_mask]
ksp_sampled = ksp_us[:, sample_mask]

# Non-Cartesian GRAPPA reconstruction

In [4]:
# from juart.recon.ncgrappa import NonCartGrappa

# ncgrappa = NonCartGrappa(
#     ktraj = ktraj_us_comb,
#     calib_signal = ksp_ref,
#     img_size=[200, 200],
#     kernel_size=7,
#     do_sift=True,

# )

In [None]:
# Find, for every unsampled location, the sampled locations inside the GRAPPA
# kernel footprint using KD-trees.
from scipy.spatial import KDTree

# Divide each axis by half the kernel size: with the Chebyshev norm (p=inf) and
# radius 1 the query region becomes an axis-aligned box of size kernel_size.
norm_factor = kernel_size / 2.0
kdtree_sampled = KDTree((ktraj_sampled / norm_factor[:, None]).cpu().numpy().T)
kdtree_unsampled = KDTree((ktraj_unsampled / norm_factor[:, None]).cpu().numpy().T)

radius = 1.0 + 1e-6  # unit radius plus a small buffer for floating-point round-off

# neighbors[i]: indices into ktraj_sampled lying in the box around unsampled point i
neighbors = kdtree_unsampled.query_ball_tree(kdtree_sampled, r=radius, p=np.inf)

# inds_c[i] = [i, *neighbors[i]]: the unsampled target index followed by its sources
inds_c = [[i] + n for i, n in enumerate(neighbors)]

# Report counts instead of dumping the whole neighbour list; every entry of
# inds_c contains at least the target index, so count from `neighbors`.
num_with_neighbors = sum(1 for n in neighbors if n)
num_neighbors = sum(len(n) for n in neighbors)
print(
    f"Found {num_neighbors} neighbors for "
    f"{num_with_neighbors}/{len(neighbors)} unsampled locations"
)

In [None]:
# Visualise the GRAPPA neighbourhood of a randomly chosen unsampled location:
# its sampled neighbours, their integer (grid) shifts relative to the target,
# and the subset kept after removing duplicate shifts ("sift").
num_locations = 5
# inds_c has one entry per *unsampled* location, so draw indices from that
# range (not from the number of sampled points).
patch_indices = torch.randint(0, len(inds_c), (num_locations,))

colors = plt.cm.viridis(np.linspace(0, 1, num_locations + 1))

# Only one of the drawn locations is plotted to keep the output small.
for _enum, patch_idx in enumerate(patch_indices[2:3]):
    neighbor_ids = inds_c[patch_idx][1:]
    if not neighbor_ids:
        print(f"Unsampled location {int(patch_idx)} has no sampled neighbors; skipped.")
        continue

    patch_neigh = ktraj_sampled[:, neighbor_ids]
    patch_cent = ktraj_unsampled[:, inds_c[patch_idx][0]]

    # Sift: round each neighbour's offset to the nearest grid shift and keep
    # the first neighbour for every distinct shift.
    int_shifts = torch.round(patch_neigh - patch_cent[:, None])
    unq_dist, shift_group, counts = torch.unique(
        int_shifts, dim=1, sorted=True, return_inverse=True, return_counts=True
    )
    # A stable sort groups the columns by shift while preserving their original
    # order, so the start of each group is that shift's first occurrence.
    _, ind_sorted = torch.sort(shift_group, stable=True)
    group_starts = torch.cat((counts.new_zeros(1), counts.cumsum(0)[:-1]))
    unique_indices = ind_sorted[group_starts]
    patch_neigh_unq = patch_neigh[:, unique_indices]

    # Use ``color=`` rather than ``c=``: an RGBA row passed as ``c`` is
    # colour-mapped instead whenever exactly four points are drawn.
    color = colors[_enum]
    fig, (ax_shift, ax_patch) = plt.subplots(1, 2, figsize=(9, 4))

    ax_shift.scatter(
        unq_dist[0].numpy(),
        unq_dist[1].numpy(),
        marker="o",
        color=color,
        label="unique shifts",
    )
    ax_shift.set_title("Unique integer shifts")
    ax_shift.set_xlabel("shift kx")
    ax_shift.set_ylabel("shift ky")
    ax_shift.legend()

    # Background: all sampled trajectory points
    ax_patch.scatter(
        ktraj_us_comb[0, ktraj_us_comb[-1] == 1].numpy(),
        ktraj_us_comb[1, ktraj_us_comb[-1] == 1].numpy(),
        s=1,
        marker=".",
        c="gray",
        alpha=0.5,
    )
    ax_patch.scatter(
        patch_neigh[0].numpy(),
        patch_neigh[1].numpy(),
        marker="|",
        color=color,
        label="sampled location",
    )
    ax_patch.scatter(
        patch_cent[0].numpy(),
        patch_cent[1].numpy(),
        marker="v",
        color=color,
        label="unsampled location",
    )
    ax_patch.scatter(
        patch_neigh_unq[0].numpy(),
        patch_neigh_unq[1].numpy(),
        marker="_",
        color=color,
        label="unique sampled locations",
    )
    ax_patch.set_title("Trajectory Patch")
    ax_patch.set_xlabel("kx")
    ax_patch.set_ylabel("ky")
    ax_patch.legend()
    fig.tight_layout()

plt.show()

In [None]:
# Divide kx/ky by the matrix size to bring the trajectory to [-0.5, 0.5]
# (assuming kx/ky span +-matrix_size/2). The third row, the sampling mask,
# is divided by 1 and therefore left unchanged.
traj_scale = torch.tensor([*matrix_size[:2], 1.0])[:, None]

ncgrappa_operator = NonCartGrappa(
    ktraj=ktraj_us_comb / traj_scale,
    calib_signal=ksp_ref.squeeze(),
    kernel_size=7,
    img_size=matrix_size[:2],
    verbose=5,
    shift_tol=1e-6,
    do_sift=True,
    tik=0,
)

In [None]:
# Alternative implementation from juart.recon.ncgrappa_2. Unlike the
# NonCartGrappa cell, the trajectory is passed here *unscaled* (original units,
# mask row included). NOTE(review): confirm NonCartesianGrappa expects that.
# In the visible notebook, ksp_filled is computed with ncgrappa_operator, not
# with this operator.
from juart.recon.ncgrappa_2 import NonCartesianGrappa

ncgrappa_operator_2 = NonCartesianGrappa(
    ktraj=ktraj_us_comb,
    calib_signal=ksp_ref[..., 0],  # first entry of the trailing axis (NonCartGrappa uses .squeeze())
    kernel_size=torch.tensor([7, 7]),
    verbose=5,
    do_sift=True,
    tik=0,
)

In [None]:
# Apply the calibrated GRAPPA operator to the undersampled k-space, presumably
# filling the unsampled locations. ``clone()`` protects ``ksp_us`` in case
# ``apply`` modifies its input in place (NOTE(review): confirm).
ksp_filled = ncgrappa_operator.apply(ksp_us.clone())

In [None]:
# Save the filled k-space together with the (unscaled) trajectory, mask row included.
arrays_to_save = {
    "ksp_filled": ksp_filled.cpu().numpy(),
    "ktraj_us_comb": ktraj_us_comb.cpu().numpy(),
}
np.savez("kspace_filled_and_traj.npz", **arrays_to_save)

In [None]:
# Compare the undersampled (zero-filled) k-space with the GRAPPA-filled k-space
# for index 0 of the leading axis.
# A constant offset keeps zero entries positive so the logarithmic difference
# panel is defined; the colour limits are taken from the offset values that are
# actually plotted so the colours do not saturate.
mag_offset = 1000
mag_us = torch.abs(ksp_us[0]) + mag_offset
mag_filled = torch.abs(ksp_filled[0]) + mag_offset
mag_diff = torch.abs(ksp_filled[0] - ksp_us[0]) + mag_offset

vmin = float(min(mag_us.min(), mag_filled.min()))
vmax = float(max(mag_us.max(), mag_filled.max()))

# (values, title, colour normalisation) per panel
panels = (
    (mag_us, "Undersampled k-space", None),
    (mag_filled, "Filled k-space", None),
    (mag_diff, "Difference", "log"),
)

fig, axes = plt.subplots(1, 3, figsize=(8, 2))

for ax, (values, title, norm) in zip(axes, panels):
    ax.scatter(
        ktraj_us_comb[0],
        ktraj_us_comb[1],
        c=values,
        vmin=vmin,
        vmax=vmax,
        s=1,
        norm=norm,
    )
    ax.set_title(title)
    ax.set_xlabel("kx")
    ax.set_ylabel("ky")
    ax.axis("equal")

plt.tight_layout()
plt.show()

## Reconstruction

In [None]:
# Drop the mask row and divide kx/ky by the matrix size so k spans [-0.5, 0.5]
# (same normalisation as the NonCartGrappa trajectory) instead of a hard-coded 200.
k = ktraj_us_comb[:-1] / torch.tensor(matrix_size[:2], dtype=torch.float32)[:, None]

In [None]:
# NUFFT grid size taken from the loaded matrix size rather than hard-coded,
# so it stays consistent with the trajectory normalisation.
n_modes = tuple(int(n) for n in matrix_size[:2])

# Transfer function for the (1, nx, ny, 1) grid; presumably the Toeplitz /
# transfer-function form of the NUFFT normal operator.
transfer_function = nonuniform_transfer_function(k, data_shape=(1, *n_modes, 1))

# Adjoint NUFFT (gridding) of the filled k-space: the right-hand side for CG
regridded_data = nonuniform_fourier_transform_adjoint(
    k,
    ksp_filled,
    n_modes=n_modes,
    modeord=0,
)

# axes=(1, 2) are presumably the spatial grid axes of regridded_data
transfer_function_operator = TransferFunctionOperator(
    transfer_function, regridded_data.shape, axes=(1, 2)
)

ident_operator = IdentityOperator(
    regridded_data.shape,
)

In [None]:
# CG-NUFFT solution with Tikhonov regularisation:
#   (T + reg_param * I) x = regridded_data,
# where T is the transfer-function operator (presumably the NUFFT normal operator).
# Complex data are viewed as interleaved float32 so CG operates on a real vector.
reg_param = 0.002
d_vec = regridded_data.view(torch.float32).ravel()
# Start from zero: a random start made the result depend on the global RNG state,
# i.e. on which cells had been executed before this one.
init_guess = torch.zeros_like(d_vec)
ATA = transfer_function_operator + reg_param * ident_operator

img, _ = conjugate_gradient(A=ATA, b=d_vec, x=init_guess, maxiter=50, residual=[])

img = img.view(torch.complex64).reshape(regridded_data.shape)

# Root-sum-of-squares combination over the leading axis (presumably coils).
img_rss = torch.sqrt(torch.sum(torch.abs(img) ** 2, dim=0))

In [None]:
# Display the squared RSS image (index 0 of the trailing axis).
fig, ax = plt.subplots()
ax.imshow(img_rss[:, :, 0] ** 2)
ax.set_title("CG-NUFFT reconstruction (RSS$^2$)")
plt.show()

# Non-Cartesian GRAPPA on Spiral Waveforms

In [None]:
def archimedean_spiral(num_points, FOV, matrix_size, N=1):
    """Build an N-interleave Archimedean spiral trajectory in k-space.

    Interleave ``n`` is the first interleave rotated by ``2*pi*n/N``. The angle
    runs linearly from 0 to ``theta_max`` and the radius grows linearly with
    it, ``r = N * theta / (2 * pi * FOV)``, reaching
    ``k_max = matrix_size / FOV / 2`` at the last sample.

    Parameters
    ----------
    num_points : int
        Number of samples per interleave.
    FOV : float
        Field of view in meters.
    matrix_size : int
        Matrix size along one axis (e.g. 200 for a 200x200 matrix).
    N : int, optional
        Number of interleaves (default 1).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(2, num_points, N)``; index 0 of the first
        axis holds kx and index 1 holds ky (in 1/m when FOV is in meters).
    """
    k_max = matrix_size / FOV / 2
    arm_spacing = 2 * torch.pi / N

    # Angle sweep shared by all interleaves and the corresponding radius.
    theta = torch.linspace(0, (2 * torch.pi / N) * k_max * FOV, num_points)
    radius = (N * theta) / (2 * torch.pi * FOV)

    # Rotate the sweep for each interleave and stack interleaves on the last axis.
    phases = [theta + (arm_spacing * arm) for arm in range(N)]
    kx = torch.stack([radius * torch.cos(phase) for phase in phases], dim=-1)
    ky = torch.stack([radius * torch.sin(phase) for phase in phases], dim=-1)
    return torch.stack((kx, ky)).to(torch.float32)

# Create Shepp-Logan Phantom


In [None]:
# Shepp-Logan phantom (0.2 m x 0.2 m FOV, 200 x 200 matrix) with coil
# sensitivities added.
sl_phantom = SheppLogan(fov=[0.2, 0.2], matrix=[200, 200])
sl_phantom.add_coil()

# Image-space object; index 0 of the trailing (echo) dimension is kept.
img = sl_phantom.get_object()[..., 0]

# Tile the per-coil magnitude images into one wide panel.
# NOTE(review): assumes img is (coils, nx, ny) with nx == ny — confirm.
coil_montage = img.reshape(-1, img.shape[1]).abs().numpy().T

plt.figure(figsize=(8, 2))
plt.title("Shepp-Logan Phantom")
plt.imshow(coil_montage, origin="lower")
plt.tight_layout()

# Create spiral trajectory data and ACS region

In [None]:
# 4-interleave Archimedean spiral for a 0.2 m FOV and a 200 x 200 matrix
ktraj = archimedean_spiral(num_points=6000, FOV=0.2, matrix_size=200, N=4)

# Scale so the largest k-space radius maps to 0.5, i.e. k lies in [-0.5, 0.5].
# The maximum radius is k_max regardless of where the interleaves end; the
# maximum Cartesian component only equals k_max when an arm ends on an axis.
k_radius_max = torch.linalg.vector_norm(ktraj, dim=0).max()
ktraj = ktraj / (2 * k_radius_max)

# Forward NUFFT on the flattened (2, num_points * N) trajectory ...
ksp_spiral = nonuniform_fourier_transform_forward(
    k=ktraj.reshape(ktraj.shape[0], -1), x=img
)
# ... then restore the (leading axis of img, num_points, N) layout.
ksp_spiral = ksp_spiral.reshape(img.shape[0], -1, ktraj.shape[-1])