In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys
import gc
import time
import torch.multiprocessing as mp
import torch.distributed as dist

sys.path.insert(0, "../../src")

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr as z

from juart.dl.model.dc import DataConsistency
from juart.dl.model.resnet import ResNet

from juart.conopt.functional.fourier import (
    fourier_transform_adjoint,
    fourier_transform_forward,
    nonuniform_fourier_transform_adjoint,
)
from juart.conopt.tfs.fourier import nonuniform_transfer_function
from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.loss.loss import JointLoss
from juart.dl.operation.modules import training, validation
from juart.dl.utils.dist import GradientAccumulator
from juart.dl.model.unrollnet import (
    LookaheadModel,
    UnrolledNet,
)

from juart.vis.interactive import InteractiveFigure3D, InteractiveMultiPlotter3D

In [None]:
# Unrolled-network hyperparameters (consumed by train_loop / DCRES_Block below).
cgiter = 10            # CG iterations: passed as CG_Iter to UnrolledNet and niter to DataConsistency
features = 32          # passed as `features` to UnrolledNet / ResNet
num_resblocks = 10     # residual blocks per ResNet
num_unrollblocks = 10  # number of ResNet + data-consistency repetitions

# Reconstruction grid: three spatial axes plus two singleton axes
# (presumably inversion-time and echo-time, given the names — TODO confirm).
nX, nY, nZ, nTI, nTE = 128, 128, 128, 1, 1
shape = (nX, nY, nZ, nTI, nTE)
dtype = torch.complex64

DATA_PATH = "/home/jovyan/datasets/num_phantom_128_R1"

# Open read-only. zarr's default mode ("a") would silently create an empty
# store at a mistyped path (failing later with a confusing KeyError) and
# would leave the dataset writable from this notebook.
store = z.open(DATA_PATH, mode="r")

C = torch.from_numpy(np.array(store["C"]))
# Two trailing singleton axes are appended to k and d, presumably to line up
# with the (nTI, nTE) axes of `shape` — TODO confirm against the store layout.
k = torch.from_numpy(np.array(store["k"]))[..., None, None]
d = torch.from_numpy(np.array(store["d"]))[..., None, None]

# Alias for C; not referenced by the visible cells of this notebook.
coilsens = C

In [None]:
# Fixed seed so the source/target k-space split is reproducible across runs.
generator = torch.Generator()
generator.manual_seed(0)

# Random 0/1 mask over axis 1 of the trajectory, presumably the sample axis,
# splitting the acquired samples into a "source" subset and its complement,
# the "target" subset.
kspace_mask_source = torch.randint(0, 2, (1, k.shape[1], 1, 1), generator=generator)
kspace_mask_target = 1 - kspace_mask_source

# Mask the *data*, not the trajectory. Multiplying k by 0 would move the
# excluded samples to the k-space centre (k = 0), where their unmasked data
# would still be gridded and corrupt the DC term. Zeroing d removes their
# contribution entirely while keeping every sample at its true location.
d_masked = d * kspace_mask_source

AHd = nonuniform_fourier_transform_adjoint(k, d_masked, (nX, nY, nZ))
# Combine the per-channel adjoint images with conjugate sensitivity maps,
# summing over axis 0 (presumably the coil axis — TODO confirm).
AHd = torch.sum(torch.conj(C[..., None, None]) * AHd, dim=0)

# Single training/evaluation sample consumed by train_loop and DCRES_Block.
data = [
    {
        "images_regridded": AHd,
        "kspace_trajectory": k,
        "sensitivity_maps": C,
        "kspace_mask_source": kspace_mask_source,
        "kspace_mask_target": kspace_mask_target,
        "kspace_data": d,
    }
]

In [None]:
def train_loop(device):
    """Build an UnrolledNet plus its loss, optimizer and gradient accumulator,
    then run ``training`` once over the notebook-level ``data``.

    Reads the notebook globals ``shape``, ``dtype``, ``features``, ``cgiter``,
    ``num_unrollblocks``, ``num_resblocks`` and ``data``.

    Parameters
    ----------
    device : str or torch.device
        Device passed to the network and the loss (e.g. ``"cuda:2"``).

    Returns
    -------
    Whatever ``training`` returns; the plotting cell treats it as an image tensor.
    """
    network_options = dict(
        features=features,
        CG_Iter=cgiter,
        num_unroll_blocks=num_unrollblocks,
        num_res_blocks=num_resblocks,
        activation="ReLU",
        disable_progress_bar=False,
        timing_level=0,
        validation_level=0,
        kernel_size=(3, 3, 3),
        axes=(1, 2, 3),
        ResNetCheckpoints=True,
        device=device,
        dtype=dtype,
    )
    network = UnrolledNet(shape, **network_options)

    # Only the k-space term is weighted; every other loss term is switched off.
    # NOTE(review): the meaning of the positional argument 3 is not visible
    # here — TODO document it.
    loss_options = dict(
        weights_kspace_loss=(0.5, 0.5),
        weights_ispace_loss=(0.0, 0.0),
        weights_wavelet_loss=(0.0, 0.0),
        weights_hankel_loss=(0.0, 0.0),
        weights_casorati_loss=(0.0, 0.0),
        normalized_loss=True,
        timing_level=0,
        validation_level=0,
        device=device,
        dtype=dtype,
    )
    criterion = JointLoss(shape, 3, **loss_options)

    adam = torch.optim.Adam(
        network.parameters(),
        lr=1e-4,
        betas=[0.9, 0.999],
        eps=1e-8,
        weight_decay=0.0,
    )

    grad_accumulator = GradientAccumulator(
        network,
        accumulation_steps=1,
        max_norm=1.0,
        normalized_gradient=False,
    )

    # NOTE(review): this lookahead wrapper is never used below. It is kept only
    # in case its constructor has side effects on `network` — confirm and
    # remove if it does not.
    averaged_model = LookaheadModel(network, alpha=0.5, k=5)

    return training(
        [0],
        data,
        network,
        criterion,
        adam,
        grad_accumulator,
        device=device,
    )

In [None]:
def DCRES_Block(device):
    """Apply ``num_unrollblocks`` alternating ResNet / data-consistency steps
    to the regridded image of ``data[0]``, without gradient tracking.

    The ResNet is freshly constructed and never trained here, so its weights
    are whatever ResNet's initialisation produces.

    Reads the notebook globals ``features``, ``num_resblocks``, ``shape``,
    ``cgiter``, ``num_unrollblocks`` and ``data``.

    Parameters
    ----------
    device : str or torch.device
        Device on which the ResNet and DataConsistency modules are built and
        onto which all input tensors are moved.

    Returns
    -------
    torch.Tensor
        The image after the final data-consistency step, on ``device``.
    """
    resnet = ResNet(
        contrasts=1,
        features=features,
        dim=3,
        num_of_resblocks=num_resblocks,
        device=device,
        kernel_size=(3, 3, 3),
        ResNetCheckpoints=True,
    )

    dc_block = DataConsistency(
        shape,
        axes=(1, 2, 3),
        device=device,
        verbose=True,
        niter=cgiter,
    )

    sample = data[0]
    # The tensors in `data` are CPU tensors (built via torch.from_numpy), while
    # the modules above are created on `device`. Move every input there so the
    # ResNet forward pass does not fail with a device mismatch. `.to` is a
    # no-op for a tensor that is already on `device`.
    image = sample["images_regridded"].to(device)

    dc_block.init(
        image,
        sample["kspace_trajectory"].to(device),
        sensitivity_maps=sample["sensitivity_maps"].to(device),
        kspace_mask=sample["kspace_mask_source"].to(device),
    )

    print(image.shape)
    with torch.no_grad():
        for _ in range(num_unrollblocks):
            image = resnet(image)
            image = dc_block(image)

    return image

In [None]:
# Single-process "gloo" group (world_size=1, rank=0) on a local TCP port,
# presumably required by the juart distributed utilities (GradientAccumulator
# is imported from juart.dl.utils.dist).
# Guarded because init_process_group raises if the default process group is
# already initialised, which would make re-running this cell fail.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:13579", world_size=1, rank=0
    )

In [None]:
# Run the training loop on one GPU and the freshly initialised ResNet +
# data-consistency chain on another; both images are compared in the next cell.
trn_image = train_loop(device="cuda:2")
dcres_image = DCRES_Block(device="cuda:3")

In [None]:
# Side-by-side magnitude view of the first (nTI, nTE) slice of both results.
InteractiveMultiPlotter3D(
    [
        volume[..., 0, 0].cpu().abs().detach().numpy()
        for volume in (trn_image, dcres_image)
    ],
    title=["trnimage", "dcresimage"],
    vmin=0,
    vmax=10,
).interactive