In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "../../src")

import matplotlib.pyplot as plt
import numpy as np
import torch

from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.data.inference import DatasetInference
from juart.dl.model.dc import DataConsistency
from juart.dl.model.unrollnet import ExponentialMovingAverageModel, UnrolledNet
from juart.dl.operation.modules import inference
# from juart.vis import MultiPlotter

In [None]:
# Optional: to improve performance, uncomment the lines below to manually limit the number of threads PyTorch uses
# torch.set_num_threads(16)
# torch.set_num_interop_threads(16)

### Load data

In [None]:
# Build the inference dataset from the preprocessed data behind the S3 endpoint.
# The "%s" in the path is presumably filled with each session ID from the list
# that follows — TODO confirm against DatasetInference.
# NOTE(review): `[80]` and `8` are passed positionally and their meaning is not
# visible from this notebook — confirm against the DatasetInference signature.
dataset = DatasetInference(
    "qrage/sessions/%s/preproc.zarr/preproc.zarr",
    ["7T1026"],
    [80],
    8,
    endpoint_url="https://s3.fz-juelich.de",
    backend="s3",
    device="cpu",
)

In [None]:
# Take the first entry as a length-one slice; later cells read `data` by
# string key (e.g. "images_regridded", "kspace_trajectory").
data = dataset[0:1]

### Initialize model, load checkpoint and run inference

In [None]:
# Point the checkpoint manager at the stored model checkpoint, accessed through
# the same S3 endpoint as the dataset.
checkpoint_manager = CheckpointManager(
    "qrage/models/ssl_512_features_ddp/hankel_dual_domain_v35_epoch_15",
    endpoint_url="https://s3.fz-juelich.de",
    backend="s3",
)

In [None]:
# Image grid size (nX, nY) and the two contrast dimensions (nTI, nTE); the
# network is built for nTI * nTE contrasts.
# NOTE(review): presumably these must match both the loaded data and the
# checkpoint — confirm.
nX, nY, nTI, nTE = (256, 256, 19, 9)

# Bare unrolled network. It gets its own name so that `model` only ever refers
# to the averaged wrapper that the later cells actually use.
unrolled_net = UnrolledNet(
    (nX, nY),
    contrasts=nTI * nTE,
    features=512,
    CG_Iter=10,
    num_unroll_blocks=10,
    spectral_normalization=False,
    activation="ReLU",
    disable_progress_bar=False,
    device="cpu",
)

# Fetch only the averaged weights and the iteration counter, mapped to the CPU.
checkpoint = checkpoint_manager.load(
    ["averaged_model_state", "iteration"], map_location="cpu"
)

# The checkpoint stores the averaged state, so it is loaded into the
# exponential-moving-average wrapper rather than into the bare network.
model = ExponentialMovingAverageModel(unrolled_net, 0.9)
model.load_state_dict(checkpoint["averaged_model_state"])
# NOTE(review): no explicit `model.eval()` here — confirm that `inference`
# switches the model to evaluation mode if the network needs it.

iteration = checkpoint["iteration"]
print(f"Loaded averaged model at iteration {iteration}.")

In [None]:
# Run the loaded model on the selected data, on the CPU.
images = inference(data, model, device="cpu")

In [None]:
# Inference already ran in the previous cell with identical arguments; inspect
# the result here instead of recomputing it.
images.shape

### Compare to CG-SENSE

In [None]:
# Stand-alone data-consistency block on the same (nX, nY) grid as the network;
# its output is shown as the "CG-SENSE Recon" baseline in the final figure.
dc_block = DataConsistency(
    (nX, nY),
    device="cpu",
)

In [None]:
# Initialise the block with the regridded images, the k-space trajectory and
# the sensitivity maps taken from `data`.
dc_block.init(
    data["images_regridded"],
    data["kspace_trajectory"],
    sensitivity_maps=data["sensitivity_maps"],
)

In [None]:
# Gradients are not needed for this comparison, so autograd tracking is
# disabled while the data-consistency block is applied.
with torch.no_grad():
    cg_sense = dc_block(data["images_regridded"])

### Display results

In [None]:
# Upper limit of the gray-scale window and the (TI, TE) contrast to display.
vmax = 3
iTI, iTE = 9, 4

# Left panel: baseline reconstruction; right panel: network output.
panels = (
    ("CG-SENSE Recon", cg_sense),
    ("DL-QRAGE Recon", images),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (title, volume) in zip(axes, panels):
    # Show the magnitude of the selected contrast at index 0 of the third axis.
    ax.imshow(np.abs(volume[:, :, 0, iTI, iTE]), cmap="gray", vmax=vmax)
    ax.set_title(title)
plt.show()