In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "/home/jovyan/juart/src")

import math

import matplotlib.pyplot as plt
import numpy as np
import torch

from juart.conopt.functional.fourier import (
    nonuniform_fourier_transform_adjoint,
)
from juart.conopt.linops.identity import IdentityOperator
from juart.conopt.linops.tf import TransferFunctionOperator
from juart.conopt.proxops.linear import conjugate_gradient
from juart.conopt.tfs.fourier import nonuniform_transfer_function

Use MRzero (MR0) to simulate the MR signal of a spiral acquisition.

In [None]:
%pip install git+https://github.com/mrphysics-bonn/spiraltraj.git
%pip install MRzeroCore
%pip install pypulseq

import MRzeroCore as mr0
import pypulseq as pp
import spiraltraj

In [None]:
# Show the spiraltraj API (e.g. the calc_traj arguments used below).
# Exploratory only; can be removed from a finished notebook.
help(spiraltraj)

# Build a simple 2D single-shot spiral-out GRE sequence

In [None]:
# choose the scanner limits
system = pp.Opts(
    max_grad=28,
    grad_unit="mT/m",
    max_slew=150,
    slew_unit="T/m/s",
    rf_ringdown_time=20e-6,
    rf_dead_time=100e-6,
    adc_dead_time=0,
    grad_raster_time=10e-6,
    adc_samples_limit=8192,
    adc_samples_divisor=4,
)

seq = pp.Sequence()

# Define FOV and resolution (lengths in metres)
slice_thickness = 3e-3
te_fill = 10e-3  # extra delay (s) inserted between excitation and readout
fov = [0.24, 0.24, slice_thickness]
matrix = [96, 96, 1]

# Define rf events: slice-selective 90 degree sinc pulse plus slice rephaser
rf, gz, gzr = pp.make_sinc_pulse(
    flip_angle=90 * np.pi / 180,
    duration=2e-3,
    slice_thickness=slice_thickness,
    apodization=0.5,
    time_bw_product=4,
    system=system,
    return_gz=True,
)

# Define readout trajectory.
# spiraltraj takes mm, mT/m and us/(mT/m); pypulseq stores gradients in Hz/m,
# hence the conversions through system.gamma.
grad_samples = spiraltraj.calc_traj(
    nitlv=1,
    res=fov[0] / matrix[0] * 1e3,  # Resolution in mm
    fov=fov[0] * 1e3,  # FOV in mm
    max_amp=system.max_grad / system.gamma * 1e3,  # max gradient in mT/m
    min_rise=1 / (system.max_slew / system.gamma) * 1e3,  # min rise time in us/mT/m
    spiraltype=1,
)

grad_samples = np.array(grad_samples).T * 1e-3  # convert to T/m

# One arbitrary gradient per waveform column. Column i goes to channel
# "x", "y", "z" in order, so no two gradients share a channel (which would
# make add_block fail).
read_grads = [
    pp.make_arbitrary_grad(
        channel="xyz"[i],
        waveform=wave * system.gamma,
        first=0,
        last=0,
    )
    for i, wave in enumerate(grad_samples)
]

# Define readout adc

# Largest dwell time that still satisfies Nyquist for the FOV:
#   dk = gamma * |G| * dt <= 1 / FOV  =>  dt <= 1 / (gamma * max|G| * FOV)
# NOTE: despite its name, `min_dwell` holds this *maximal* allowed dwell;
# the name is kept because later cells use it.
max_g = np.max(np.linalg.norm([g.waveform for g in read_grads], axis=0)) / system.gamma
min_dwell = 1 / (system.gamma * max_g * fov[0])
# Round down to the ADC raster so the Nyquist criterion still holds
min_dwell = (
    math.floor(min_dwell / system.adc_raster_time) * system.adc_raster_time
)
if min_dwell <= 0:
    # Otherwise the sample-count computation below divides by zero.
    raise ValueError(
        "Nyquist dwell time is shorter than the ADC raster time; "
        "reduce the gradient amplitude or the FOV."
    )

# Get total number of samples for readout
num_samples = math.ceil(pp.calc_duration(*read_grads) / min_dwell)

# Create segmented adc object (segment count/size chosen by pypulseq so the
# total respects the system's ADC sample limits)
num_adc_seg, num_seg_samples = pp.calc_adc_segments(
    num_samples=num_samples,
    dwell=min_dwell,
    system=system,
)

adc = pp.make_adc(
    num_samples=num_seg_samples * num_adc_seg, dwell=min_dwell, system=system
)
# NOTE(review): as written this is always 0, because `adc` was created with
# exactly num_seg_samples * num_adc_seg samples; it appears unused in this
# notebook. Perhaps `adc.num_samples - num_samples` was intended — confirm.
adc_samples_overlap = adc.num_samples - num_seg_samples * num_adc_seg

# Create sequence: excitation, rephaser, TE fill delay, spiral readout
seq.add_block(rf, gz)
seq.add_block(gzr)

seq.add_block(pp.make_delay(te_fill))

seq.add_block(*read_grads, adc)

seq.write("spiral.seq")

# Create MR0 Phantom

In [None]:
# Numerical phantom (subject05) from the MRzero-Core documentation
url = "https://github.com/MRsources/MRzero-Core/raw/main/documentation/playground_mr0/subject05.npz"

phantom = mr0.util.load_phantom(url=url)

# Give the phantom the same physical size as the sequence FOV (metres) and
# interpolate it to the reconstruction matrix size.
phantom.size = torch.tensor(fov)
phantom = phantom.interpolate(*matrix)

# NOTE(review): scales the B0 map by 1.5, presumably to exaggerate
# off-resonance effects for this demo — confirm intent.
phantom.B0 *= 1.5

phantom.plot()
# Simulation data object passed to the MR0 graph computation/execution below
data = phantom.build()

# Simulate Signal

In [None]:
# Load the Pulseq file written by the sequence-building cell into MR0
seq_mr0 = mr0.Sequence.import_file("spiral.seq")
# Build the simulation graph; max_state_count / min_state_mag limit how many
# states are tracked (presumably an accuracy vs. speed trade-off; see MR0 docs)
graph = mr0.compute_graph(
    seq=seq_mr0, data=data, max_state_count=200, min_state_mag=1e-3
)
# Simulate the ADC signal for the imported sequence and phantom
signal = mr0.execute_graph(graph=graph, seq=seq_mr0, data=data, print_progress=True)

# Reconstruction with CG-NUFFT

Prepare the k-space trajectory and signal for the JUART reconstruction.

In [None]:
# Get k-space trajectory at the ADC sample times (units: 1/m)
k, _, _, _, _ = seq.calculate_kspace()
k *= fov[0] / matrix[0]  # Scale from -0.5 to 0.5

# We can save computation time by limiting trajectory to 2D
k = k[:2]

# Move the channel axis (axis 1 of the MR0 signal) to the front so the signal
# is (channels, samples). Only do so while the sample axis is still first;
# otherwise re-running this cell would transpose the signal back.
if signal.shape[0] == k.shape[1]:
    signal = torch.moveaxis(signal, 0, 1)

k = torch.tensor(k, dtype=torch.float32)

In [None]:
from juart.recon.offres import OffResonanceCorrection

In [None]:
# Off-resonance map with a leading singleton axis, giving the 4D layout of the
# (1, Nx, Ny, 1) data_shape used by the reconstruction operators.
# The result is a flipped view that shares memory with phantom.B0.
# TODO: Check why flipping both in-plane axes is necessary
B0 = np.flip(phantom.B0.numpy()[None, ...], axis=(1, 2))

In [None]:
# k is (2, samples) after the 2D truncation, so its last axis counts the ADC samples
num_samples = k.shape[-1]
# Number of segments for the off-resonance model (one weight/phase per segment)
num_seg = 10

# min_dwell is the ADC dwell time used when building the sequence
offres = OffResonanceCorrection(B0, num_seg, num_samples, min_dwell)

In [None]:
# Per-segment signal weights from the off-resonance model, converted to torch
weights = list(map(torch.from_numpy, map(offres.get_signal_weights, range(num_seg))))

In [None]:
# Per-segment image-domain phase terms from the off-resonance model, as torch tensors
phases = list(map(torch.from_numpy, map(offres.get_img_phase, range(num_seg))))

Perform CG-NUFFT: set up the operators and right-hand sides, first without and then with off-resonance correction.

In [None]:
# Transfer function of the normal operator A^H A for trajectory k
# (it becomes the ATA operator of the CG solve)
transfer_function = nonuniform_transfer_function(
    k,
    data_shape=(1, *matrix[:2], 1),
)

# Adjoint NUFFT of the measured signal: the right-hand side A^H y of the CG solve
regridded_data = nonuniform_fourier_transform_adjoint(
    k,
    signal,
    n_modes=(matrix[0], matrix[1]),
    modeord=0,
)

# Both operators act on images shaped like the regridded data
transfer_function_operator = TransferFunctionOperator(
    transfer_function,
    regridded_data.shape,
    axes=(1, 2),
)
ident_operator = IdentityOperator(regridded_data.shape)

In [None]:
from juart.conopt.linops.offres import OffresonaceTransferFunctionOperator

In [None]:
# One weighted transfer function per off-resonance segment
transfer_functions = [
    nonuniform_transfer_function(
        k,
        data_shape=(1, matrix[0], matrix[1], 1),
        weights=weight,
    )
    for weight in weights
]

In [None]:
# Normal operator combining the per-segment transfer functions and image phases
offres_transfer_function_operator = OffresonaceTransferFunctionOperator(
    transfer_functions,
    phases,
    regridded_data.shape,
    axes=(1, 2),
)

In [None]:
# Off-resonance-aware right-hand side:
#   sum over segments of conj(phase) * NUFFT^H(weight * signal)
regridded_data_offres = torch.zeros_like(regridded_data)
for seg_weight, seg_phase in zip(weights, phases):
    seg_adjoint = nonuniform_fourier_transform_adjoint(
        k,
        seg_weight * signal,
        n_modes=tuple(matrix[:2]),
        modeord=0,
    )
    regridded_data_offres += seg_phase.conj() * seg_adjoint

In [None]:
# Tikhonov-regularised CG-NUFFT: solve (A^H A + reg_param * I) x = A^H y
reg_param = 0.001
ATA = transfer_function_operator + reg_param * ident_operator

# conjugate_gradient is fed a real float32 view of the complex right-hand side
d_vec = regridded_data.view(torch.float32).ravel()
init_guess = torch.zeros_like(d_vec)

# Take the solution vector and reinterpret it as a complex image
img = (
    conjugate_gradient(
        A=ATA, b=d_vec, residual=[], x=init_guess, maxiter=40, verbose=True
    )[0]
    .view(torch.complex64)
    .reshape(regridded_data.shape)
)

In [None]:
# Tikhonov-regularised CG-NUFFT with off-resonance correction:
# same solve as the plain case, using the segmented off-resonance operator
reg_param = 0.001
ATA = offres_transfer_function_operator + reg_param * ident_operator

# conjugate_gradient is fed a real float32 view of the complex right-hand side
d_vec = regridded_data_offres.ravel().view(torch.float32)
init_guess = torch.zeros_like(d_vec)

# Take the solution vector and reinterpret it as a complex image
img_offres = (
    conjugate_gradient(
        A=ATA, b=d_vec, residual=[], x=init_guess, maxiter=40, verbose=True
    )[0]
    .view(torch.complex64)
    .reshape(regridded_data.shape)
)

In [None]:
# Side-by-side magnitude images. Both panels share the same grey-scale window
# (vmin/vmax) so the effect of the off-resonance correction is directly
# comparable; vmax=150 is a fixed display window tied to the simulated signal scale.
display_window = dict(vmin=0, vmax=150, cmap="gray")

fig, (ax_plain, ax_offres) = plt.subplots(1, 2, figsize=(10, 5))

ax_plain.imshow(torch.abs(img[0, ..., 0]), **display_window)
ax_plain.set_title("CG-NUFFT")

ax_offres.imshow(torch.abs(img_offres[0, ..., 0]), **display_window)
ax_offres.set_title("CG-NUFFT with off-resonance correction")

for ax in (ax_plain, ax_offres):
    ax.axis("off")

plt.show()