In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

import numpy as np
import torch
import zarr as z

sys.path.insert(0, "../../src")

from juart.conopt.functional.fourier import (
    nonuniform_fourier_transform_adjoint,
)
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.recon.sense import SENSE
from juart.vis.interactive import InteractiveFigure3D

# Location of the numerical phantom dataset (zarr group holding "C", "k" and "d").
DATA_PATH = "/home/jovyan/datasets/num_phantom_128_R1"

# Open read-only: zarr's default mode ("a") would silently create an empty group
# if the path were wrong, and would allow accidental writes to the source data.
store = z.open(DATA_PATH, mode="r")

C = torch.from_numpy(np.array(store["C"]))
# Two trailing singleton axes are appended to k and d
# (presumably the layout expected by the juart Fourier operators — TODO confirm).
k = torch.from_numpy(np.array(store["k"]))[..., None, None]
d = torch.from_numpy(np.array(store["d"]))[..., None, None]

In [None]:
# Complementary random 0/1 masks along axis 1 of k (named for two workers).
# A seeded generator keeps the split identical across kernel restarts.
MASK_SEED = 0
mask_rng = torch.Generator().manual_seed(MASK_SEED)
kspace_mask_worker0 = torch.randint(0, 2, (1, k.shape[1], 1, 1), generator=mask_rng)
kspace_mask_worker1 = 1 - kspace_mask_worker0

In [None]:
# NOTE(review): k_scaled_masked is only inspected below; the adjoint and the
# transfer function are computed from the full, unmasked k.
k_scaled_masked = k * kspace_mask_worker0

# Adjoint NUFFT onto a 128^3 grid, then combine coils by weighting with conj(C)
# and summing over axis 0.
AHd = torch.sum(
    torch.conj(C[..., None, None])
    * nonuniform_fourier_transform_adjoint(k, d, (128, 128, 128)),
    dim=0,
)

In [None]:
print(*(tensor.shape for tensor in (k_scaled_masked, AHd, C, d)))

In [None]:
# Magnitude of the coil-combined adjoint, first entry of the two trailing axes.
ahd_magnitude = AHd[..., 0, 0].abs().numpy()
InteractiveFigure3D(ahd_magnitude).interactive

In [None]:
# Preferred GPU for the CG solve. Fall back to CPU when this machine has no cuda:2
# so Restart & Run All still works elsewhere
# (assumes SENSE accepts a CPU device — TODO confirm).
device = "cuda:2" if torch.cuda.device_count() > 2 else "cpu"

# Transfer function for the full k-space trajectory on a (1, 128, 128, 128) grid;
# the (2, 2, 2) argument is passed through as-is (presumably an oversampling factor).
H = nonuniform_transfer_function(k, (1, 128, 128, 128), (2, 2, 2))

In [None]:
# Singleton axes appended to C (last) and prepended to AHd (first) for SENSE.
C2 = C.unsqueeze(-1)
AHd2 = AHd.unsqueeze(0)

In [None]:
# Move the sensitivities, adjoint data and transfer function to the solve device
# (in that positional order) and configure the CG-based SENSE reconstruction.
cg_solver = SENSE(
    *(operand.to(device) for operand in (C2, AHd2, H)),
    axes=(1, 2, 3),
    maxiter=200,
    verbose=True,
    device=device,
)

In [None]:
tuple(operand.shape for operand in (C2, AHd2, H))

In [None]:
tuple(operand.dtype for operand in (C2, AHd2, H))

In [None]:
cg_image = (
    cg_solver.solve()
    .view(torch.complex64)  # reinterpret the solver output's storage as complex64
    .reshape(128, 128, 128)
)

In [None]:
cg_image.size()

In [None]:
# Host-side magnitude image. vmax is taken from the same NumPy array and passed as
# a plain float, not as a 0-dim CUDA tensor that NumPy/widget code cannot consume.
cg_magnitude = cg_image.cpu().abs().numpy()
InteractiveFigure3D(
    cg_magnitude, cmap="gray", vmax=float(cg_magnitude.max())
).interactive