In [None]:
# Reload imported modules automatically before each cell executes, so edits
# to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib backend, needed by the interactive viewers used below.
%matplotlib widget

import sys

# Put the relative "../../src" directory first on the import path —
# presumably so the `juart` imports below resolve to the local source tree.
# NOTE(review): this depends on the kernel's working directory.
sys.path.insert(0, "../../src")

import numpy as np
import torch
import zarr as z

from juart.conopt.functional.fourier import (
    nonuniform_fourier_transform_adjoint,
)
from juart.dl.model.dc import DataConsistency
from juart.dl.model.resnet import ResNet
from juart.vis.interactive import InteractiveFigure3D, InteractiveMultiPlotter3D

In [None]:
# --- Configuration -----------------------------------------------------------
# Cut-off settings; not referenced by any cell visible in this notebook.
kspace_cutoff = True
nX_cutoff, nY_cutoff, nZ_cutoff = 64, 64, 64

# Reconstruction shape: three spatial extents followed by two singleton axes
# (nTI, nTE).
nX, nY, nZ, nTI, nTE = 128, 128, 128, 1, 1
shape = (nX, nY, nZ, nTI, nTE)

# Preferred accelerator. Fall back to the CPU when that GPU index does not
# exist on the current machine instead of failing later with a CUDA error.
_GPU_INDEX = 3
device = (
    f"cuda:{_GPU_INDEX}"
    if torch.cuda.is_available() and torch.cuda.device_count() > _GPU_INDEX
    else "cpu"
)

In [None]:
# Open the phantom dataset read-only. Without an explicit mode, zarr may open
# the path in append mode and silently create an empty store when the path is
# wrong; with mode="r" a missing dataset fails right here instead.
# NOTE(review): the path is machine-specific.
store = z.open("/home/jovyan/datasets/num_phantom_128_R1", mode="r")

# Materialise each array in memory and wrap it as a torch tensor.
C = torch.from_numpy(np.array(store["C"]))
# `k` and `d` get two trailing singleton axes appended
# (presumably matching the nTI / nTE axes of `shape` — confirm).
k = torch.from_numpy(np.array(store["k"]))[..., None, None]
d = torch.from_numpy(np.array(store["d"]))[..., None, None]

In [None]:
# Trajectory divided by twice its largest element.
# NOTE(review): in the cells visible here only the shape of `k_scaled` is
# used; the adjoint transform below receives the unscaled `k` — confirm that
# this is intended.
k_scaled = k / (k.max() * 2)

# Seeded generator so the random mask is identical on every run.
generator = torch.Generator().manual_seed(0)

# Random 0/1 integer mask with one entry per position along dim 1 of the
# trajectory; the remaining axes are singletons so that it broadcasts.
kspace_mask = torch.randint(
    0,
    2,
    (1, k_scaled.size(1), 1, 1),
    generator=generator,
)

In [None]:
# Adjoint non-uniform Fourier transform of `d` onto an (nX, nY, nZ) grid,
# using the trajectory multiplied element-wise by the 0/1 mask.
AHd = nonuniform_fourier_transform_adjoint(
    k * kspace_mask,
    d,
    (nX, nY, nZ),
)

# Weight by the conjugate of C and sum over dim 0
# (presumably the coil axis — confirm).
AHd = (C[..., None, None].conj() * AHd).sum(dim=0)

# Quick shape check of everything that went into the regridding.
print(AHd.shape, C.shape, d.shape, (k * kspace_mask).shape)

In [None]:
# Browse the magnitude of the regridded volume (index 0 on both trailing axes).
InteractiveFigure3D(
    torch.abs(AHd[..., 0, 0]).numpy(),
    cmap="gray",
).interactive

In [None]:
# Bundle the inputs of the data-consistency block under descriptive keys.
data = dict(
    images_regridded=AHd,
    kspace_trajectory=k,
    sensitivity_maps=C,
    kspace_mask=kspace_mask,
)

In [None]:
# Data-consistency block for the configured shape, running on `device`.
dc_block = DataConsistency(
    shape,
    axes=(1, 2, 3),
    niter=80,
    verbose=True,
    device=device,
)

In [None]:
# Hand the regridded image, trajectory, coil sensitivities and sampling mask
# to the data-consistency block before it is applied.
dc_block.init(
    data["images_regridded"],
    data["kspace_trajectory"],
    kspace_mask=data["kspace_mask"],
    sensitivity_maps=data["sensitivity_maps"],
)

In [None]:
# Shapes of the three main inputs, printed on a single line.
print(
    *(
        data[name].shape
        for name in ("images_regridded", "kspace_trajectory", "sensitivity_maps")
    )
)

In [None]:
# Apply the data-consistency block to the regridded image. Gradient tracking
# is disabled because the result is only inspected, not trained on.
with torch.no_grad():
    cg_sense = dc_block(data["images_regridded"])

In [None]:
# Drop axes 3 and 4 of the reconstruction (removed only if they have size 1).
resnet_image = cg_sense.squeeze(dim=(3, 4))

# Variant with one trailing singleton axis, and the magnitude image.
# Neither is referenced by the cells visible in this notebook.
resnet_image_unsqueezed = resnet_image[..., None]
cg_image = torch.abs(resnet_image)

In [None]:
# Magnitude volumes to display, keyed by panel title.
# The original cell also plotted `cg_sense2` as "DC2", but no cell in this
# notebook defines it, so the cell raised NameError on a fresh kernel. Add
# further reconstructions to this dict once they are actually computed.
panels = {
    "DC1": cg_sense[..., 0, 0].abs().cpu().numpy(),
}

InteractiveMultiPlotter3D(
    list(panels.values()),
    title=list(panels),
    vmin=0,
    # Plain Python float (not a 0-d tensor that may live on the GPU), taken
    # from the "DC1" volume as before.
    vmax=float(panels["DC1"].max()),
    cmap="gray",
).interactive

In [None]:
# 3D ResNet with a single contrast, built on the configured device.
# NOTE(review): no weights are loaded in this cell, so the output presumably
# comes from freshly initialised parameters — confirm.
resnet = ResNet(
    contrasts=1,
    features=32,
    dim=3,
    num_of_resblocks=10,
    device=device,
    kernel_size=(3, 3, 3),
    ResNetCheckpoints=True,
)

# Inference only: run the forward pass without autograd bookkeeping, the same
# way the data-consistency block is applied, so no graph is kept alive for
# the full volume. The two trailing singleton axes are restored on the input.
with torch.no_grad():
    image = resnet(resnet_image[..., None, None])
print(image.shape)

In [None]:
# Magnitude of the network output (index 0 on both trailing axes), shown with
# a fixed upper display limit of 10.
InteractiveMultiPlotter3D(
    image[..., 0, 0].detach().cpu().abs().numpy(),
    cmap="gray",
    vmax=10,
).interactive