In [3]:
import os
import cv2
import scipy.io as scio

import torch
import ctlib
import torch.nn as nn
from torch.utils.data import DataLoader

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from model import *
from dataset import *
# Pick the compute device once for the whole notebook: the first CUDA GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Record the versions of the main third-party packages this notebook was run
# with, one per line, in the order: cv2, scipy, torch, skimage.
import scipy
import skimage

print(*(lib.__version__ for lib in (cv2, scipy, torch, skimage)), sep="\n")

4.6.0
1.10.0
1.13.1
0.19.3


In [15]:
def demo(**kwargs):
    """Evaluate a trained LAMA checkpoint on a test set and report PSNR/SSIM.

    Builds a ``LAMA`` model from ``kwargs``, restores its weights from
    ``./checkpoints/{dataset[:4]}_{n_views}views_{input_type}/checkpoint.pth.tar``,
    runs every sample produced by ``LAMA_loader`` through the model and prints
    the per-file metrics followed by their mean and variance.

    Required keyword arguments:
        n_iter, start_iter, n_views, n_Ifeats, n_Sfeats, n_Iconvs, n_Sconvs,
        views, dets, width, height, dImg, dDet, Ang0, dAng, s2r, d2r,
        binshift, scan_type, alpha, beta, mu, nu, lam, eta:
            forwarded unchanged to the ``LAMA`` constructor.
        Iksize, Sksize, Ipadding, Spadding:
            converted with ``tuple(...)`` and forwarded to ``LAMA``.
        dataset: sub-directory of ``./dataset`` holding the test data; its
            first four characters also select the checkpoint directory.
        input_type: third component of the checkpoint directory name.
        file_dir, file_prj_dir: forwarded to ``LAMA_loader``.

    Returns:
        tuple: ``(avg_psnr, psnr_var, avg_ssim, ssim_var)`` - mean and
        variance (``np.var``) of the per-sample PSNR and SSIM, as float32.

    Raises:
        KeyError: if a required keyword argument is missing.
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the test loader yields no samples.
    """
    # Keys handed to LAMA as-is, and keys that must be converted to tuples.
    passthrough_keys = (
        'n_iter', 'start_iter', 'n_views',
        'n_Ifeats', 'n_Sfeats', 'n_Iconvs', 'n_Sconvs',
        'views', 'dets', 'width', 'height', 'dImg', 'dDet', 'Ang0', 'dAng',
        's2r', 'd2r', 'binshift', 'scan_type',
        'alpha', 'beta', 'mu', 'nu', 'lam', 'eta',
    )
    tuple_keys = ('Iksize', 'Sksize', 'Ipadding', 'Spadding')

    model_args = {key: kwargs[key] for key in passthrough_keys}
    model_args.update((key, tuple(kwargs[key])) for key in tuple_keys)
    model = LAMA(**model_args)

    # The model is wrapped before the weights are restored, so the stored
    # state dict is loaded into the DataParallel wrapper.
    model = nn.DataParallel(model)
    model.to(device)

    dataset = kwargs['dataset']
    n_views = kwargs['n_views']
    input_type = kwargs['input_type']
    filename = f"./checkpoints/{dataset[:4]}_{n_views}views_{input_type}/checkpoint.pth.tar"
    if not os.path.isfile(filename):
        # Fail with a clear message instead of continuing without weights.
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(filename))
    # map_location lets a checkpoint saved on GPU be restored when `device`
    # fell back to the CPU.
    checkpoint = torch.load(filename, map_location=device)
    print("=> loaded checkpoint '{}' (iter {}, epoch {})"
              .format(filename, checkpoint['iter'], checkpoint['epoch']))
    model.load_state_dict(checkpoint['state_dict'])
    # NOTE(review): the model is left in its default (training) mode, as in
    # the original code; confirm whether model.eval() is wanted here.

    root = os.path.join('./dataset', dataset)
    file_dir = kwargs['file_dir']
    file_prj_dir = kwargs['file_prj_dir']
    test_loader = DataLoader(dataset=LAMA_loader(
                             root, file_dir, file_prj_dir, n_views, False),
                             batch_size=1, num_workers=8, shuffle=False)
    # NOTE(review): these two literals equal the dImg and dDet values used by
    # the call in this notebook; presumably they should follow kwargs -
    # confirm against generate_mask before changing.
    mask = generate_mask(0.006641, 0.0072)
    mask = torch.FloatTensor(mask[None, None, :, :]).to(device)

    # Collect one value per evaluated sample so the statistics cover exactly
    # the samples the loader produced, whatever the dataset size.
    psnr_values = []
    ssim_values = []

    with torch.no_grad():
        for x, x_label, z, file_name in test_loader:
            # prepare data
            x = x.to(device)
            z = z.to(device)

            # Only the last entry of the first returned list is scored; the
            # second returned list is not needed for the metrics.
            x_list, _ = model(x, z)
            x_out = x_list[-1].clip(0, 1) * mask
            x_out = x_out.squeeze().detach().cpu().numpy()
            # The label is only needed as a NumPy array on the host.
            x_label = x_label.squeeze().detach().cpu().numpy()

            p = psnr(x_out, x_label, data_range=1)
            s = ssim(x_out, x_label, data_range=1)
            psnr_values.append(p)
            ssim_values.append(s)

            print(file_name[0], "\t psnr: {:.4f}".format(p), "\t ssim: {:.4f}".format(s))

    if not psnr_values:
        raise RuntimeError("test loader for dataset '{}' produced no samples".format(dataset))

    PSNR_All = np.asarray(psnr_values, dtype=np.float32)
    SSIM_All = np.asarray(ssim_values, dtype=np.float32)

    avg_psnr = np.mean(PSNR_All)
    psnr_var = np.var(PSNR_All)
    avg_ssim = np.mean(SSIM_All)
    ssim_var = np.var(SSIM_All)
    print()
    # These are variances (np.var), so they are labelled as such.
    print("avg PSNR:", avg_psnr, "\t var:", psnr_var)
    print("avg SSIM:", avg_ssim, "\t var:", ssim_var)
    print()

    return avg_psnr, psnr_var, avg_ssim, ssim_var


In [18]:
# Number of scanning views in the sparse-view sinogram. It is passed to the
# network and also selects the data sub-directories below, so changing it
# here keeps all three in step.
n_views = 64
demo(
     # network hyperparameters
     n_iter=15, start_iter=15, n_views=n_views, input_type="Ix",
     n_Ifeats=32, n_Sfeats=32, n_Iconvs=4, n_Sconvs=4,
     Iksize=[3, 3], Sksize=[3, 15], Ipadding=[1, 1], Spadding=[1, 7],
     alpha=1e-12, beta=1e-12, mu=1e-12, nu=1e-12, lam=10., eta=1e-10,

     # test data location
     dataset='mayo_data_low_dose_256',
     file_dir=f"Ix_{n_views}views", file_prj_dir=f"Iz_{n_views}views",

     # the following are Radon transform parameters
     views=512, dets=512, width=256, height=256, dImg=0.006641, dDet=0.0072, Ang0=0,
     dAng=0.006134 * 2, s2r=2.5, d2r=2.5, binshift=0, scan_type=0)

=> loaded checkpoint './checkpoints/mayo_64views_Ix/checkpoint.pth.tar' (iter 13, epoch 12)
data_1927.mat 	 psnr: 43.0847 	 ssim: 0.9842
data_1931.mat 	 psnr: 42.9911 	 ssim: 0.9842
data_1935.mat 	 psnr: 43.6795 	 ssim: 0.9867
data_1939.mat 	 psnr: 43.9105 	 ssim: 0.9865
data_1943.mat 	 psnr: 44.2667 	 ssim: 0.9866
data_1947.mat 	 psnr: 44.8332 	 ssim: 0.9883
data_1951.mat 	 psnr: 44.8905 	 ssim: 0.9878
data_1955.mat 	 psnr: 45.0907 	 ssim: 0.9880
data_1959.mat 	 psnr: 45.0771 	 ssim: 0.9875
data_1963.mat 	 psnr: 44.5686 	 ssim: 0.9862
data_1967.mat 	 psnr: 44.0141 	 ssim: 0.9850
data_1971.mat 	 psnr: 45.4339 	 ssim: 0.9867
data_1975.mat 	 psnr: 45.4589 	 ssim: 0.9860
data_1979.mat 	 psnr: 45.4345 	 ssim: 0.9863
data_1983.mat 	 psnr: 45.8685 	 ssim: 0.9877
data_1987.mat 	 psnr: 45.6672 	 ssim: 0.9871
data_1991.mat 	 psnr: 46.0399 	 ssim: 0.9883
data_1995.mat 	 psnr: 45.9265 	 ssim: 0.9889
data_1999.mat 	 psnr: 46.0142 	 ssim: 0.9883
data_2003.mat 	 psnr: 45.9784 	 ssim: 0.9881
data_200

(44.599056, 1.1225926, 0.98617584, 7.142912e-06)