# Inference Notebook

## Import all required modules

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "../../src")

import matplotlib.pyplot as plt
import numpy as np
import torch

from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.data.inference import DatasetInference
from juart.dl.model.dc import DataConsistency
from juart.dl.model.resnet import ResNet
from juart.dl.model.unrollnet import ExponentialMovingAverageModel, UnrolledNet
from juart.dl.operation.modules import inference

In [None]:
# To improve performance, manually limit the number of threads
# torch.set_num_threads(16)
# torch.set_num_interop_threads(16)

## Define important variables

In [None]:
# --- dataset options ---
i_slice = [80]  # list of slice indices to reconstruct
num_spokes = 64  # number of spokes used for the reconstruction

# --- device options ---
device = "cpu"  # device used by the dataset, the network and CG-SENSE (e.g. "cpu")

# --- CheckpointManager options ---
directory = "model_64spokes_1epoch"  # name of the trained dl-qrage model directory
root_dir = "/home/jovyan/models"  # path under which `directory` is located
backend = "local"  # storage backend of the model directory

# --- model options ---
# Grid size (nX, nY, nZ) and number of contrasts along the TI and TE dimensions
# (nTI, nTE). `shape` bundles them in exactly this order for the model.
shape = nX, nY, nZ, nTI, nTE = 256, 256, 1, 2, 2
features = 64  # number of hidden features of the network (passed as `features`)
cg_iter = 10  # number of CG iterations inside the network (passed as `CG_Iter`)

# --- display options ---
vmax = 2  # upper limit of the grey-value window passed to imshow as `vmax`
iTI, iTE = 1, 0  # TI / TE indices of the contrast that is displayed

## Load data
The DatasetInference class loads a dataset that is ready for inference. Besides the arguments it needs to locate the data, you also have to specify the slice(s) and the number of spokes that should be loaded from the dataset.

In [None]:
# Load the preprocessed zarr store of session 7T1566 from the S3 endpoint,
# restricted to the configured slice(s) and number of spokes.
dataset = DatasetInference(
    "qrage/sessions/%s/preproc.zarr/preproc.zarr",  # path template; %s presumably = session id
    ["7T1566"],  # session id(s) to load
    i_slice,
    num_spokes,
    backend="s3",
    endpoint_url="https://s3.fz-juelich.de",
    device=device,
)

Extract the data needed for the reconstruction from the dataset (the first item, selected with the slice `0:1`).

In [None]:
# Take only the first item of the dataset (slice 0:1). The result is indexed by
# string keys below ("images_regridded", "kspace_trajectory", "sensitivity_maps").
data = dataset[0:1]

# Initialize CheckpointManager and model and run the inference

## Setup CheckpointManager
The CheckpointManager is used to load a trained model. First we set it up and tell it where to look for the model files; then we can load specific entries from the model directory.<br><br>
#### Description of the complete CheckpointManager process:
In the training notebook, the CheckpointManager is told where to save the model every n iterations. Once the model has been trained and saved at that location, the CheckpointManager in the inference notebook can load it from there. This way it is possible not only to use the fully trained model after the final iteration (nmax), but also to use intermediate checkpoints to check whether training is heading in the right direction.<br><br>

In [None]:
# Point the manager at the trained model: `directory` below `root_dir`,
# accessed through the configured storage `backend`.
checkpoint_manager = CheckpointManager(
    backend=backend,
    root_dir=root_dir,
    directory=directory,
)

# Fetch only the averaged weights and the iteration they were saved at.
# map_location="cpu" is presumably forwarded to torch.load — confirm.
checkpoint = checkpoint_manager.load(
    ["averaged_model_state", "iteration"],
    map_location="cpu",
)

## Define the model and load a trained one with the CheckpointManager
The model is defined and configured for the shape of the image to be reconstructed. Afterwards, a trained model is loaded with the help of the CheckpointManager, which overwrites the parameters of the freshly initialized model. At the end of the cell, the number of iterations the loaded model was trained for is printed.

In [None]:
# Build the unrolled reconstruction network for the image `shape`.
# NOTE(review): num_unroll_blocks presumably determines the network depth and
# therefore the parameter set, so it has to match the checkpoint — confirm.
num_unroll_blocks = 10
model = UnrolledNet(
    shape,
    features=features,
    CG_Iter=cg_iter,
    num_unroll_blocks=num_unroll_blocks,
    spectral_normalization=False,
    activation="ReLU",
    disable_progress_bar=False,
    device=device,
)

# The checkpoint entry "averaged_model_state" is loaded into the EMA wrapper
# (not the bare UnrolledNet); this overwrites the freshly initialised weights.
# NOTE(review): 0.9 is presumably the EMA decay, relevant only while training.
model = ExponentialMovingAverageModel(model, 0.9)
model.load_state_dict(checkpoint["averaged_model_state"])

iteration = checkpoint["iteration"]
print(f"Loaded averaged model at iteration {iteration}.")

## Image reconstruction with dl-qrage 
The MR image is reconstructed from the data and the model loaded above. The computation can run on either the CPU or a GPU (selected via `device`). A progress output shows the state of the reconstruction.

In [None]:
# Run the loaded DL-QRAGE network on `data` on the selected device.
images = inference(data, model, device=device)

# Compare to CG SENSE

## Set up and initialize CG SENSE reconstruction algorithm

In [None]:
# CG-SENSE baseline: a DataConsistency block for the same image `shape`.
# NOTE(review): axes=(1, 2) presumably selects the spatial axes of the data — confirm.
dc_block = DataConsistency(shape, device=device, axes=(1, 2))

In [None]:
# Show the shapes of the three inputs that are handed to dc_block.init().
print(
    *(
        data[key].shape
        for key in ("images_regridded", "kspace_trajectory", "sensitivity_maps")
    )
)

In [None]:
# Initialise the CG-SENSE block with the regridded images, the k-space
# trajectory and the coil sensitivity maps of the loaded item.
dc_block.init(
    data["images_regridded"],
    data["kspace_trajectory"],
    sensitivity_maps=data["sensitivity_maps"],
)

## Image reconstruction with CG SENSE

In [None]:
# Apply the CG-SENSE block to the regridded images. no_grad: this baseline is
# only displayed, so no autograd graph is needed.
with torch.no_grad():
    cg_sense = dc_block(data["images_regridded"])

# Display results

In [None]:
def _to_display(array):
    """Return ``|array|`` as a NumPy array suitable for ``plt.imshow``.

    Torch tensors are detached and moved to the CPU first. Calling ``np.abs``
    directly on a tensor fails when the tensor lives on a GPU or requires grad.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()
    return np.abs(array)


# Side-by-side comparison of the selected (iTI, iTE) contrast at z-index 0.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(_to_display(cg_sense[:, :, 0, iTI, iTE]), cmap="gray", vmax=vmax)
plt.title("CG-SENSE Recon")
plt.subplot(1, 2, 2)
plt.imshow(_to_display(images[:, :, 0, iTI, iTE]), cmap="gray", vmax=vmax)
plt.title("DL-QRAGE Recon")
plt.show()

In [None]:
# Inspection cell: display the first entry of the DL-QRAGE reconstruction along axis 0.
images[0]

In [None]:
# Pass the DL-QRAGE reconstruction through a ResNet built here (no checkpoint is
# loaded into it in this notebook) and show one contrast of the output.
# NOTE(review): contrasts=4 presumably corresponds to nTI * nTE — confirm.
resnet = ResNet(contrasts=4, dim=3)
with torch.no_grad():  # display only; no autograd graph needed
    image = resnet(images)

# New figure: with `%matplotlib widget`, earlier figures stay open, and a bare
# plt.imshow would draw into the previous figure's active axes.
plt.figure()
plt.imshow(
    np.abs(image[:, :, 0, iTI, iTE].detach().cpu().numpy()),
    cmap="gray",
    vmax=vmax,
)
plt.title("ResNet output")
plt.show()