In [25]:
%matplotlib notebook
import os, sys
import logging
import random
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils.flare_utils import torch2np
import h5py
from utils.resnet2p1d import generate_model
# import matplotlib
# matplotlib.use('TkAgg')
import sigpy as sp
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx

# import custom classes
from utils.datasets import SliceData
from utils.subsample import VDktMaskFunc
from models.unrolled3D.unrolled3D_MoDL import UnrolledModel as UnrolledModelM
# UnrolledModelM = UnrolledModel
from models.unrolled3D.unrolled3D import UnrolledModel
os.environ["CUDA_VISIBLE_DEVICES"]="5"
device = torch.device('cuda:0')
# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [500]:
slice = 5
da = h5py.File("/home/kewang/cardiac_cine/validate/Exam2200_Series5_Phases20.h5", 'r')
kspace = np.array(da['kspace'])[slice,...]
maps = np.array(da['maps'])[slice,...]
target = np.array(da['target'] )[slice,...]

In [501]:
pl.ImagePlot(target)

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f22c5afba10>

In [502]:
dev_mask = VDktMaskFunc([10,15])
maps = maps[...,0][...,None]
target = target[...,0][...,None]
# Convert everything from numpy arrays to tensors
kspace = cplx.to_tensor(kspace).unsqueeze(0)
maps   = cplx.to_tensor(maps).unsqueeze(0)
target = cplx.to_tensor(target).unsqueeze(0)
norm = torch.sqrt(torch.mean(cplx.abs(target)**2))
seed = None
# kspace, target = self.augment(kspace, target, seed)
#         pl.ImagePlot(kspace)
#         pl.ImagePlot(target)
# Undersample k-space data
masked_kspace, mask = ss.subsample(kspace, dev_mask, seed)

In [503]:
A = T.SenseModel(maps)

# Compute normalization factor (based on 95% max signal level in view-shared dataa)
averaged_kspace = T.time_average(masked_kspace, dim=3)
image = A(averaged_kspace, adjoint=True)
magnitude_vals = cplx.abs(image).reshape(-1)
k = int(round(0.05 * magnitude_vals.numel()))
scale = torch.min(torch.topk(magnitude_vals, k).values)

# Normalize k-space and target images
masked_kspace /= scale
target /= scale
mean = torch.tensor([0.0], dtype=torch.float32)
std = scale


In [504]:
mask.shape

torch.Size([1, 64, 180, 20, 1, 1])

In [505]:
masked_kspace = masked_kspace
maps = maps
target = target

In [506]:
print(mask[0,30,...,0,0].numpy().sum()/(180*20))

0.07611111111111112


In [507]:
# Load the checkpoint:
checkpoint_file = "/home/kewang/cardiac_cine/summary/train-3D_5steps_2resblocks_64features_MoDLflag0_CGsteps_8date_6084_ufloss0/best_model.pt"
checkpoint = torch.load(checkpoint_file,map_location=device)
args = checkpoint['args']
# args.device_num='-1'
# args.device='cpu'
model = UnrolledModel(args).to(device)
model.load_state_dict(checkpoint['model'])
print(checkpoint['epoch'])

254


In [534]:
checkpoint_file_modl = "/home/kewang/cardiac_cine/summary/train-3D_5steps_2resblocks_64features_MoDLflag0_CGsteps_8date_6082/best_model.pt"
checkpoint_modl = torch.load(checkpoint_file_modl,map_location=device)
args_modl = checkpoint_modl['args']
# args_modl.device_num='-1'
# args_modl.device='cpu'
model_modl = UnrolledModelM(args_modl).to(device)
model_modl.load_state_dict(checkpoint_modl['model'])
print(checkpoint_modl['epoch'])

No shared weights
140


In [521]:
checkpoint_file_modl_ufloss = "/home/kewang/cardiac_cine/summary/train-3D_5steps_2resblocks_64features_MoDLflag0_CGsteps_8date_6082_ufloss/best_model.pt"
checkpoint_modl_ufloss = torch.load(checkpoint_file_modl_ufloss,map_location=device)
args_modl_ufloss = checkpoint_modl_ufloss['args']
# args_modl_ufloss.device_num='-1'
# args_modl_ufloss.device='cpu'
model_modl_ufloss = UnrolledModelM(args_modl_ufloss).to(device)
model_modl_ufloss.load_state_dict(checkpoint_modl_ufloss['model'])
print(checkpoint_modl_ufloss['epoch'])

No shared weights
140


In [522]:
input_torch = masked_kspace.to(device)
maps_torch = maps.to(device)
target_torch = target.to(device)

In [523]:
with torch.no_grad():
#     start = time.time()
    output = model(input_torch, maps_torch, init_image=None)
#     print(time.time()-start)
    output_modl = model_modl(input_torch, maps_torch, init_image=None)    

In [524]:
ti = []
with torch.no_grad():

    start = time.time()
    output = model_modl(input_torch, maps_torch, init_image=None)
    end = time.time()-start
    print(end)
    ti.append(end)
#     output_modl = model_modl(input_torch, maps_torch, init_image=None)    

0.3062171936035156


In [525]:
with torch.no_grad():
    output_modl_ufloss = model_modl_ufloss(input_torch, maps_torch, init_image=None)    

In [526]:
output_permute = output.squeeze(0).permute(3,4,2,1,0)

In [527]:
A = T.SenseModel(maps_torch)
zf_image = A(input_torch,adjoint=True)

In [528]:
output_np = torch2np(output)
output_modl_ufloss_np = torch2np(output_modl_ufloss)
output_modl_np = torch2np(output_modl)
output_gt = torch2np(target)
output_zf = torch2np(zf_image)
output_all = np.concatenate((output_gt,output_modl_np,output_modl_ufloss_np,output_np))

In [529]:
def compute_metrics(output,target_torch):
    cplx_error = cplx.abs(output - target_torch)
    cplx_l1 = torch.mean(cplx_error)
    cplx_l2 = torch.sqrt(torch.mean(cplx_error**2))
    cplx_psnr = 20 * torch.log10(scale / cplx_l2)
    return cplx_l1,cplx_l2,cplx_psnr

In [532]:
compute_metrics(output,target_torch)

(tensor(0.0483, device='cuda:0'),
 tensor(0.0611, device='cuda:0'),
 tensor(180.8729))

In [533]:
compute_metrics(output_modl,target_torch)

(tensor(0.0483, device='cuda:0'),
 tensor(0.0611, device='cuda:0'),
 tensor(180.8729))

In [519]:
compute_metrics(output_modl_ufloss,target_torch)

(tensor(0.0541, device='cuda:0'),
 tensor(0.0682, device='cuda:0'),
 tensor(179.9220))

In [520]:
pl.ImagePlot(output_all,interpolation="lanczos")

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f22cc47ddd0>

In [55]:
input_torch.shape

tensor([[[[[-0.1057, -0.1347]],

          [[-0.1067, -0.1170]],

          [[-0.1040, -0.1251]],

          ...,

          [[-0.0994, -0.1283]],

          [[-0.1158, -0.1394]],

          [[-0.1084, -0.1458]]],


         [[[-0.0916, -0.1445]],

          [[-0.1020, -0.1759]],

          [[-0.1086, -0.1519]],

          ...,

          [[-0.1148, -0.1750]],

          [[-0.1053, -0.1544]],

          [[-0.0951, -0.1587]]],


         [[[-0.1335,  0.0226]],

          [[-0.1100,  0.0152]],

          [[-0.1320,  0.0156]],

          ...,

          [[-0.1168,  0.0008]],

          [[-0.1077,  0.0238]],

          [[-0.1105,  0.0315]]],


         ...,


         [[[-0.0780, -0.0990]],

          [[-0.0523, -0.0996]],

          [[-0.0854, -0.1068]],

          ...,

          [[-0.0732, -0.0930]],

          [[-0.0608, -0.0933]],

          [[-0.0683, -0.1006]]],


         [[[-0.0600,  0.0312]],

          [[-0.0536,  0.0048]],

          [[-0.0539,  0.0095]],

          ...,

     

In [51]:
args

Namespace(accelerations=[10, 15], batch_size=1, checkpoint=None, circular_pad=True, data_parallel=False, data_path='/home/kewang/cardiac_cine', device='cuda', device_num='2', drop_prob=0.0, exp_dir='/home/kewang/cardiac_cine/summary/train-3D_5steps_2resblocks_64features', fix_step_size=True, kernel_size=3, lr=0.0001, lr_gamma=0.5, lr_step_size=500, modl_flag=False, modl_lamda=0.05, num_cg_steps=6, num_emaps=1, num_epochs=2000, num_features=64, num_grad_steps=5, num_resblocks=2, patch_size=64, report_interval=20, resume=False, sample_rate=1.0, seed=42, share_weights=False, slwin_init=True, weight_decay=0.0)

In [44]:
pl.ImagePlot(image)

<IPython.core.display.Javascript object>

<sigpy.plot.ImagePlot at 0x7f388f405050>