In [1]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import glob
from model import RED_CNN
from CTImage import CTImage
from tqdm.auto import tqdm
from utils import cal_SSIM,cal_PSNR,SSIM
from thop import profile


        

In [21]:

device = ('cuda' if torch.cuda.is_available() else 'cpu')
model = RED_CNN().to(device)
model_path = 'saved_file/ssim_model2_100.pth'
model.load_state_dict(torch.load(model_path))
ct_dataset = CTImage(patch_n=None,patch_size=None,root_dir='../data/TestSet/')


def test(testset,batch_size=16):
    dataloader = DataLoader(testset,batch_size=16,shuffle=False)
    total_origin_PSNR = 0
    total_PSNR = 0
    total_origin_SSIM = 0
    total_SSIM = 0
    cnt = 0
    with torch.no_grad():
        for _,batch in tqdm(enumerate(dataloader),total=len(dataloader)):
            targets = batch[1].to(device)
            inputs = batch[0].to(device)
            outputs = model(inputs).detach()
            for i in range(targets.shape[0]):
                total_origin_PSNR += cal_PSNR(inputs[i],targets[i])
                total_origin_SSIM += cal_SSIM(inputs[i],targets[i])
                total_PSNR += cal_PSNR(outputs[i],targets[i])
                total_SSIM += cal_SSIM(outputs[i],targets[i])
            
            cnt += targets.shape[0]
    avg_origin_PSNR = total_origin_PSNR / cnt
    avg_origin_SSIM = total_origin_SSIM / cnt
    avg_PSNR = total_PSNR / cnt
    avg_SSIM = total_SSIM / cnt
    return avg_PSNR,avg_SSIM,avg_origin_PSNR,avg_origin_SSIM


In [22]:
test(ct_dataset)

  0%|          | 0/33 [00:00<?, ?it/s]

(40.67762046712433,
 tensor(0.9979, device='cuda:0'),
 36.196819067907875,
 tensor(0.9941, device='cuda:0'))

In [20]:
input = torch.randn((1,1,256,256)).to(device)
macs, params = profile(model, inputs=(input, ))
print("GFLOPs:{},#Parameters:{}".format(int(macs)>>29 ,params))# flops = macs * 2 G = 2^30

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
GFLOPs:205,#Parameters:1848865.0


In [6]:

from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(model, (1, 256, 256), as_strings=True,
                                           print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

RED_CNN(
  1.85 M, 100.000% Params, 110.19 GMac, 100.000% MACs, 
  (conv1): Conv2d(2.5 k, 0.135% Params, 158.51 MMac, 0.144% MACs, 1, 96, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(230.5 k, 12.467% Params, 14.18 GMac, 12.865% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(230.5 k, 12.467% Params, 13.72 GMac, 12.453% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (conv4): Conv2d(230.5 k, 12.467% Params, 13.28 GMac, 12.048% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (conv5): Conv2d(230.5 k, 12.467% Params, 12.84 GMac, 11.650% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (tconv1): ConvTranspose2d(230.5 k, 12.467% Params, 13.28 GMac, 12.048% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (tconv2): ConvTranspose2d(230.5 k, 12.467% Params, 13.72 GMac, 12.453% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (tconv3): ConvTranspose2d(230.5 k, 12.467% Params, 14.18 GMac, 12.865% MACs, 96, 96, kernel_size=(5, 5), stride=(1, 1))
  (tconv4): C