In [None]:
import math

import numpy as  np
import scipy.io as sio
import torch
import yaml

from effop import FiniteDifferenceL1, SenseNufftOp, PrimalDualL1

# Device on which all tensors and operators in this notebook are placed.
# Change DEVICE_NAME to "cuda" to run on a GPU instead of the CPU.
DEVICE_NAME = "cpu"
device = torch.device(DEVICE_NAME)

In [None]:
# load the data
# The YAML file is parsed with safe_load (no arbitrary object construction);
# its content is handed to scipy.io.loadmat, so it presumably names the .mat file.
with open("../data_loc.yaml", "r") as f:
    data_file = yaml.safe_load(f)

raw_data = sio.loadmat(data_file)

# Density-compensation weights; after the transpose, dim 0 is indexed by spoke
# (it is sliced/regrouped with the same spoke count as kdata below).
dcomp = torch.tensor(raw_data["w"]).permute(1, 0)
# k-space is NOT pre-multiplied by sqrt(dcomp) here; the weights are applied
# later instead (in the adjoint for the initial estimate and as data weights
# in the solver).
kdata = torch.tensor(raw_data["kdata"]).permute(1, 0, 2)
# Trajectory is stored as complex numbers (real/imag = the two k-space axes);
# scaled by 2*pi (presumably from cycles/pixel to radians -- TODO confirm).
ktraj = torch.tensor(raw_data["k"]).permute(1, 0) * 2 * np.pi
sensitivity_maps = (
    torch.tensor(np.transpose(raw_data["b1"], (2, 1, 0))).unsqueeze(0).contiguous()
).to(device)
# Normalize so the largest coil-sensitivity magnitude is exactly 1.
sensitivity_maps = sensitivity_maps / sensitivity_maps.abs().max()

# Resort k-space by temporal resolution: consecutive groups of `nspokes` spokes
# form one frame; trailing spokes that do not fill a whole frame are dropped.
nspokes = 21
num_timepoints = kdata.shape[0] // nspokes
if num_timepoints == 0:
    raise ValueError(
        f"need at least nspokes={nspokes} spokes to form one frame, "
        f"got {kdata.shape[0]}"
    )
num_coils = kdata.shape[-1]
# Result: (num_timepoints, num_coils, samples per frame).
kdata = (
    kdata[: nspokes * num_timepoints]
    .reshape(num_timepoints, -1, num_coils)
    .permute(0, 2, 1)
    .contiguous()
).to(device)
# Rescale weights by (total spokes / spokes per frame), then group per frame as
# (num_timepoints, 1, samples per frame) so they broadcast over coils.
dcomp = dcomp * dcomp.shape[0] / nspokes
dcomp = torch.real(dcomp[: nspokes * num_timepoints].reshape(num_timepoints, 1, -1).contiguous()).to(device)
ktraj = ktraj[: nspokes * num_timepoints].reshape(num_timepoints, -1).contiguous()
# Split the complex trajectory into (num_timepoints, 2, samples), imaginary part first.
ktraj = torch.stack((torch.imag(ktraj), torch.real(ktraj)), dim=1).contiguous().to(device)

In [None]:
# create the operators
data_op = SenseNufftOp(sensitivity_maps, ktraj).to(device)

# Initial estimate: density-weighted adjoint, normalized by the summed squared
# coil-sensitivity magnitude. Where every coil map is exactly zero that
# normalization would divide by zero (NaN/inf), which would also make `lam`
# below NaN, so those pixels are set to 0 instead.
with torch.no_grad():
    sens_energy = torch.sum(sensitivity_maps.abs() ** 2, dim=1, keepdim=True)
    adjoint_est = data_op.adjoint(dcomp * kdata)
    orig_est = torch.where(
        sens_energy > 0,
        adjoint_est / sens_energy,
        torch.zeros_like(adjoint_est),
    )

# Regularization weight: 25% of the initial estimate's peak magnitude.
reg_op = FiniteDifferenceL1(lam=0.25 * orig_est.abs().max()).to(device)

In [None]:
# Bound on the norm of the weighted normal operator A^H W A (A = data_op,
# W = dcomp), passed to the solver as `data_bound` to set the step size.
# Set this to None to estimate it with a power iteration instead.
sense_eig = 14.5
NUM_POWER_ITERATIONS = 30


def power_iteration(apply_normal, vec, num_iterations=30, verbose=True):
    """Estimate the dominant eigenvalue of a positive semi-definite operator.

    Repeatedly applies ``apply_normal`` to the normalized iterate; the norm of
    the final output approximates the largest eigenvalue.

    Args:
        apply_normal: callable mapping a tensor to a tensor of the same shape.
        vec: starting tensor (e.g. random noise); must be non-zero.
        num_iterations: number of power-iteration steps.
        verbose: print the running estimate after each step.

    Returns:
        The eigenvalue estimate as a Python float, matching the type of the
        hard-coded bound above.
    """
    for _ in range(num_iterations):
        vec = apply_normal(vec / torch.norm(vec))
        if verbose:
            print(torch.norm(vec))
    return torch.norm(vec).item()


if sense_eig is None:
    # Like the other operator applications, no autograd graph is needed here.
    with torch.no_grad():
        sense_eig = power_iteration(
            lambda x: data_op.adjoint(dcomp * data_op.forward(x)),
            torch.randn_like(orig_est),
            num_iterations=NUM_POWER_ITERATIONS,
        )

# Bound for the finite-difference regularizer (stated by the author to be analytical).
reg_eig = 4.0

In [None]:
import matplotlib.pyplot as plt

# Display frame 0 of the initial estimate with a clipped intensity window:
# normalize to peak magnitude 1, saturate above `frac_high`, zero below `frac_low`.
frac_high = 0.7
frac_low = 0.1
orig_display = orig_est / orig_est.abs().max()
orig_display[orig_display.abs() > frac_high] = frac_high
orig_display[orig_display.abs() < frac_low] = 0.0

fig, ax = plt.subplots()
# Transpose and flip vertically for display orientation.
ax.imshow(
    orig_display[0][0].abs().permute(1, 0).flip(0).cpu().numpy(),
    cmap="gray",
)
ax.set_title("Initial estimate (frame 0)")
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
# Primal-dual solver for the L1-regularized reconstruction. The operator-norm
# bounds from above presumably control its step sizes; `dcomp` weights the
# data-consistency term.
NUM_PD_ITERATIONS = 8

solver_kwargs = dict(
    data_operator=data_op,
    data_bound=sense_eig,
    reg_operator=reg_op,
    reg_bound=reg_eig,
    num_iterations=NUM_PD_ITERATIONS,
    data_weights=dcomp,
)
opt = PrimalDualL1(**solver_kwargs)

# Run the reconstruction without building an autograd graph.
with torch.no_grad():
    est = opt.solve(kdata, orig_est)

In [None]:
%matplotlib inline
from matplotlib import animation
from IPython.display import HTML

plt.figure(0)
frac_high = 0.7
frac_low = 0.1
new_est = est / est.abs().max()
new_est[new_est.abs() > frac_high] = frac_high
new_est[new_est.abs() < frac_low] = 0.0
new_est = new_est.permute(0, 3, 2, 1).flip(1)
new_est = torch.floor(new_est.abs() * 255.0 + 0.5).to(torch.uint8).repeat(1, 1, 1, 3)
video = new_est.cpu().numpy()
print(video.shape)

# confert to video)
fig = plt.figure(0)
fig.set_size_inches(5, 5)
ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
ax.set_axis_off()
fig.add_axes(ax)
im = ax.imshow(video[0, ...], aspect="equal")
plt.close()

def init():
    im.set_data(video[0,:,:,:])

def animate(i):
    im.set_data(video[i, ...])
    return im

anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=video.shape[0],
    interval=100,
)
# anim.save("video.mp4")
HTML(anim.to_html5_video())