去噪NAS pretrain测试代码 (Denoising NAS pretrain — test script)

In [None]:
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging,argparse
import warnings
from LFDatatest import LFDataset
from Functions import weights_init,SetupSeed,CropLF, MergeLF,ComptPSNR,rgb2ycbcr
import itertools,argparse
from skimage.metrics import structural_similarity
import numpy as np
import scipy.io as scio 
import scipy.misc as scim
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import time
from os.path import join
from MainNet_pfe_pretrain import MainNet
# Expose only physical GPU 0 to this process; inside the process it is then
# addressed as cuda:0. This only takes effect if it runs before CUDA is first
# initialised by torch.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Testing settings.
# Several options below (learningRate, step, reduce, ...) are training
# hyper-parameters; they are kept because `opt` is handed to LFDataset and
# MainNet, which may read them — TODO confirm which are actually needed here.
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.5, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=2, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=6, help="The number of SAS layers")
parser.add_argument("--temperature_1", type=float, default=1, help="The value of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The value of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of nas component")
parser.add_argument("--noiselevel", type=int, default=20, help="Noise level 10 20 50")
parser.add_argument("--batchSize", type=int, default=1, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=55, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of cropped LF patch")

parser.add_argument("--angResolution", type=int, default=7, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
# epochNum is also passed as the second argument of the model's forward call
# at test time, so it is not only a training setting.
parser.add_argument("--epochNum", type=int, default=11000, help="The number of epochs")
parser.add_argument("--overlap", type=int, default=4, help="The overlap between neighbouring cropped LF patches")
parser.add_argument("--summaryPath", type=str, default='./', help="Directory for saving the testing log")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--modelPath", type=str, default='./model/*** model path***', help="Path for loading trained model ")
parser.add_argument("--dataPath", type=str, default='/***dataroot***/test_synthetic_noiselevel_10_20_50.mat', help="Path for loading test data ")
parser.add_argument("--savePath", type=str, default='./results/', help="Path for saving results ")
# parse_known_args ignores argv entries this parser does not define (e.g. the
# extra arguments present when running inside a notebook kernel).
opt = parser.parse_known_args()[0]

# NOTE: this silences *all* warnings, including library deprecation notices.
warnings.filterwarnings("ignore")
plt.ion()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()

# The testing log goes into --summaryPath; with the default './' this is the
# same 'Testing_original.log' in the working directory as before.
os.makedirs(opt.summaryPath, exist_ok=True)
log_file = os.path.abspath(os.path.join(opt.summaryPath, 'Testing_original.log'))
# Re-executing this cell must not attach a second handler for the same file,
# otherwise every log line would be written to it twice.
fh = next((h for h in log.handlers
           if isinstance(h, logging.FileHandler) and h.baseFilename == log_file),
          None)
if fh is None:
    fh = logging.FileHandler(log_file)
    log.addHandler(fh)

In [None]:
# Build the (unshuffled) test loader and restore the trained model.
lf_dataset = LFDataset(opt)
dataloader = DataLoader(lf_dataset, batch_size=opt.batchSize, shuffle=False)
# CUDA_VISIBLE_DEVICES exposes a single GPU to this process, so the plain
# "cuda" device (ordinal 0) is the only valid CUDA target here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = MainNet(opt)
# map_location lets a checkpoint saved on a different GPU (or on the CPU) be
# loaded onto the device actually available in this process.
checkpoint = torch.load(opt.modelPath, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
model.to(device)

In [None]:
def _to_uint8(img):
    """Convert a float view (assumed to lie in [0, 1]) to uint8 for SSIM.

    Values are clipped first: casting a float outside [0, 255] to uint8
    overflows (platform-dependent, typically wrapping, e.g. 257 -> 1), which
    would corrupt the SSIM of views whose denoised output slightly exceeds
    the [0, 1] range.
    """
    return np.clip(img * 255.0, 0, 255).astype(np.uint8)


with torch.no_grad():
    # SetupSeed(50)
    # Denoise every test LF patch-wise, merge the patches, score each
    # sub-aperture view with PSNR / SSIM and save the reconstruction.
    os.makedirs(opt.savePath, exist_ok=True)
    # Send inputs to wherever the model's weights live, so this cell works
    # whether the model was placed on a GPU or left on the CPU.
    model_device = next(model.parameters()).device
    num = 0        # number of light fields evaluated so far
    sum_psnr = 0
    sum_ssim = 0
    for sample in dataloader:
        LF = sample['lf']
        noilf = sample['noiself']
        lfName = sample['lfname']
        b, u, v, x, y = LF.shape
        # Crop the noisy input LF into overlapping spatial patches.
        LFStack, coordinate = CropLF(noilf, opt.patchSize, opt.overlap)
        n = LFStack.shape[1]
        # [b, n, u, v, patchSize, patchSize]
        estiLFStack = torch.zeros(b, n, u, v, opt.patchSize, opt.patchSize)
        for i in range(n):
            estiLFStack[:, i] = model(LFStack[:, i].to(model_device), opt.epochNum)
        estiLF = MergeLF(estiLFStack, coordinate, opt.overlap, x, y)
        b, u, v, xCrop, yCrop = estiLF.shape
        # Compare against the ground-truth window starting overlap//2 pixels
        # in, with the same spatial size as the merged estimate.
        half = opt.overlap // 2
        LF = LF[:, :, :, half:half + xCrop, half:half + yCrop]

        # Flatten the angular grid once: [b, u*v, xCrop, yCrop] numpy arrays.
        n_views = u * v
        est_views = estiLF.reshape(b, n_views, xCrop, yCrop).cpu().numpy()
        gt_views = LF.reshape(b, n_views, xCrop, yCrop).cpu().numpy()

        for k in range(b):
            num += 1
            lf_psnr = 0
            lf_ssim = 0
            # Per-LF scores are the mean over all u*v sub-aperture views.
            for est, gt in zip(est_views[k], gt_views[k]):
                lf_psnr += ComptPSNR(est, gt) / n_views
                lf_ssim += structural_similarity(
                    _to_uint8(est), _to_uint8(gt),
                    gaussian_weights=True, sigma=1.5,
                    use_sample_covariance=False) / n_views
            sum_psnr += lf_psnr
            sum_ssim += lf_ssim
            log.info('Index: %d  Scene: %s  PSNR: %.2f  SSIM: %.3f'%(num,lfName[k],lf_psnr,lf_ssim))
            # Save the reconstructed LF.
            scio.savemat(os.path.join(opt.savePath, lfName[k] + '.mat'),
                         {'lf_recons': torch.squeeze(estiLF[k]).cpu().numpy()})
    # Guard against an empty loader instead of dividing by zero.
    avg_psnr = sum_psnr / max(num, 1)
    avg_ssim = sum_ssim / max(num, 1)
    log.info('Average PSNR: %.2f  SSIM: %.3f '%(avg_psnr,avg_ssim))