In [None]:
# Re-import edited modules before each cell runs, so changes to the local
# juart sources (added to sys.path below) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Interactive ipympl backend, so figures such as the 3-D trajectory scatter
# can be rotated/zoomed in the notebook.
%matplotlib widget

import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr

sys.path.insert(0, "../../../src")

from juart.conopt.functional.fourier import nonuniform_fourier_transform_forward
from juart.parim.analytic import cyclic_head_coil
from juart.phantoms.mni import BrainPhantom5D
from juart.sampling.spherical import spherical_trajectory_3d
from juart.vis.interactive import InteractiveFigure4D

In [None]:
# Numerical 5-D brain phantom with B0 shimming enabled.
# NOTE(review): B0=3 is presumably the field strength in tesla -- confirm.
phantom = BrainPhantom5D(B0_shimming=True, B0=3)

In [None]:
# Sampling grids for TE and TI (presumably echo and inversion times), each
# defined by (start, step, count). With a count of 1 only the start value is
# used: TE = [5], TI = [20].
TE0, dTE, nTE = 5, 5, 1
TI0, dTI, nTI = 20, 100, 1

TE = np.arange(nTE) * dTE + TE0
TI = np.arange(nTI) * dTI + TI0

# NOTE(review): TR = 1e6 presumably makes repetition-time effects negligible
# (full recovery); IE is passed unchanged to phantom.signal. Confirm units and
# meaning against BrainPhantom5D.signal.
TR = 1e6
IE = 1

In [None]:
# Simulate the phantom signal, scale it so the largest magnitude is exactly 1,
# and hand it to torch as complex64.
# Assumes phantom.signal returns a NumPy array (required by torch.from_numpy).
x_image = phantom.signal(TI, TE, TR, IE)
x_image = torch.from_numpy(x_image / np.abs(x_image).max()).to(torch.complex64)

In [None]:
# Spherical 3-D k-space trajectory; axis 0 holds the three k coordinates.
# NOTE(review): confirm which of the two arguments (256, 256*256) is the
# samples-per-readout and which the number of readouts.
k_unraveled = spherical_trajectory_3d(256, 256 * 256)

In [None]:
# Visual sanity check of the trajectory: scatter the k-space points at index -1
# along axis 1 of k_unraveled. Axis 0 holds the (k_x, k_y, k_z) coordinates.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

ax.scatter(
    k_unraveled[0, -1, :],
    k_unraveled[1, -1, :],
    k_unraveled[2, -1, :],
)

# Real axis labels instead of the matplotlib-gallery placeholders, so the
# figure is readable on its own.
ax.set_xlabel("$k_x$")
ax.set_ylabel("$k_y$")
ax.set_zlabel("$k_z$")
ax.set_title("Spherical trajectory (k_unraveled[:, -1, :])")

plt.show()

In [None]:
# Collapse every axis except the coordinate axis: k becomes (3, n_points).
k = k_unraveled.reshape(3, -1)

In [None]:
# Analytic coil sensitivity maps for 8 channels on a 256^3 grid. Axis 0 is the
# channel axis; it is broadcast against the single image when forming the
# coil images.
C = cyclic_head_coil((8,) + (256,) * 3)

In [None]:
# Weight the image by every coil sensitivity. The first entry of each of the
# two trailing axes is selected (presumably the single TI/TE contrast), and a
# leading singleton axis is added so it broadcasts over C's channel axis.
coil_images = C * x_image[..., 0, 0].unsqueeze(0)

In [None]:
# Browse the magnitude of the coil-weighted images slice by slice and channel
# by channel. axes=(1, 2, 3, 0) presumably moves the channel axis (axis 0 of
# coil_images) last, matching the ("Slice", "Channel") slider descriptions.
# NOTE(review): vmax=1 assumes |C| <= 1 (the image itself is normalized to a
# peak magnitude of 1); otherwise bright regions will clip.
InteractiveFigure4D(
    torch.abs(coil_images).numpy(),
    vmin=0,
    vmax=1,
    title="Coil images",
    axes=(1, 2, 3, 0),
    cmap="gray",
    description=("Slice", "Channel"),
).interactive

In [None]:
# Forward non-uniform Fourier transform: sample every coil image at the
# trajectory points k, which has shape (3, n_points).
d = nonuniform_fourier_transform_forward(k, coil_images)

In [None]:
# Destination of the generated dataset. The default is the path this notebook
# has always written to; set NUM_PHANTOM_SPH_TRAJ_PATH to write elsewhere
# without editing the notebook.
dataset_path = os.environ.get(
    "NUM_PHANTOM_SPH_TRAJ_PATH", "/home/jovyan/datasets/num_phantom_sph_traj"
)
store = zarr.storage.LocalStore(dataset_path)

In [None]:
# Root group of the store. overwrite=True replaces whatever already exists at
# the store root, so re-running the notebook starts from a clean dataset.
group = zarr.create_group(store=store, overwrite=True)

In [None]:
# Declare the on-disk arrays with explicit shapes and single-precision dtypes:
# complex64 for the coil maps C and the k-space data d, float32 for the
# trajectory k. One table plus a loop replaces three copy-pasted calls.
array_specs = {
    "C": (C.shape, np.complex64),
    "k": (k.shape, np.float32),
    "d": (d.shape, np.complex64),
}
for array_name, (array_shape, array_dtype) in array_specs.items():
    group.create_array(
        array_name,
        shape=array_shape,
        dtype=array_dtype,
        overwrite=True,
    )

In [None]:
# Write into the declared array. `group["C"] = ...` would instead re-create the
# array with default settings and the value's own dtype, discarding the
# declared shape/dtype. Slice assignment casts to the declared complex64.
group["C"][:] = C.numpy()

In [None]:
# Write into the declared array. `group["k"] = ...` would instead re-create the
# array with the tensor's own dtype, silently ignoring the declared float32.
# Slice assignment casts to the declared dtype.
group["k"][:] = k.numpy()

In [None]:
# Write into the declared array. `group["d"] = ...` would instead re-create the
# array with default settings and the value's own dtype, discarding the
# declared shape/dtype. Slice assignment casts to the declared complex64.
group["d"][:] = d.numpy()