In [None]:
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging
from LFDataset import LFDataset
from LFDatatest import LFDatatest
from Functions import weights_init,SetupSeed,CropLF, MergeLF,ComptPSNR,rgb2ycbcr
from DeviceParameters import to_device
from MainNet_pfe_ver0 import MainNet
import itertools,argparse
from skimage.metrics import structural_similarity
import numpy as np
import scipy.io as scio 
import scipy.misc as scim
import matplotlib.pyplot as plt
from collections import defaultdict
import os
# Expose only GPU 0 to this process. NOTE(review): CUDA reads this variable when it
# is first initialised, so it must be set before any `.cuda()` call in the kernel.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
# Training settings
# Every hyper-parameter used by the notebook lives on `opt`.
# `parse_known_args()[0]` (instead of `parse_args()`) discards unrecognised
# command-line arguments, e.g. the `-f <connection file>` a Jupyter kernel is
# started with, so this cell works both in a notebook and as a script.
# NOTE: the --preTrain, --testPath and --dataPath defaults are placeholders and
# must be overridden with real paths.
# NOTE(review): --step and --reduce are not read by this notebook's training loop
# (it uses OneCycleLR); confirm whether MainNet or the datasets rely on them.
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.5, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=6, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=8, help="The number of SAS layers")
parser.add_argument("--temperature_1", type=float, default=1, help="The number of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The number of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of nas component")
parser.add_argument("--batchSize", type=int, default=5, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=55, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of croped LF patch")
parser.add_argument("--num_cp", type=int, default=1000, help="Number of epoches for saving checkpoint")
parser.add_argument("--measurementNum", type=int, default=2, help="The number of measurements")
parser.add_argument("--angResolution", type=int, default=5, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
parser.add_argument("--epochNum", type=int, default=10000, help="The number of epoches")
parser.add_argument("--overlap", type=int, default=4, help="The overlap between neighbouring cropped LF patches")
parser.add_argument("--summaryPath", type=str, default='./', help="Path for saving training log ")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--preTrain", type=str, default='./model/***pretrained model***', help="Path for loading pretrained model ")
parser.add_argument("--testPath", type=str, default='path_to/test_LFCA_synthetic_5.mat', help="Path for loading test data ")
parser.add_argument("--dataPath", type=str, default='path_to/train_LFCA_synthetic_5.mat', help="Path for loading training data ")

opt = parser.parse_known_args()[0]

# Send log records to the console (configured by basicConfig) and to a log file
# whose name encodes the dataset / measurement / stage / SAS-layer settings.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
logFile = os.path.abspath('Training_pfe_{}_{}_{}_{}_pfe.log'.format(opt.dataName, opt.measurementNum, opt.stageNum, opt.sasLayerNum))
# Without this, re-running the cell in the same kernel would attach one more
# FileHandler to the root logger each time, duplicating every message in the file.
# FileHandler stores the absolute path in `baseFilename`, so compare against that.
for handler in list(log.handlers):
    if isinstance(handler, logging.FileHandler) and handler.baseFilename == logFile:
        log.removeHandler(handler)
        handler.close()
fh = logging.FileHandler(logFile)
log.addHandler(fh)
logging.info(opt)

In [None]:
if __name__ == '__main__':

    SetupSeed(1)
    savePath = './model/lfca_{}_{}_{}_{}_{}_{}-pfe'.format(opt.dataName, opt.measurementNum, opt.stageNum, opt.sasLayerNum, opt.epochNum, opt.learningRate)
    lfDataset = LFDataset(opt)
    dataloader = DataLoader(lfDataset, batch_size=opt.batchSize,shuffle=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model=MainNet(opt)
    model.load_state_dict(torch.load(opt.preTrain)['model'])
    model = model.cuda()
    # total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # log.info("Training parameters: %d" %total_trainable_params)

    criterion = torch.nn.L1Loss() # Loss 
    optimizer = torch.optim.Adam(itertools.chain(model.parameters()), lr=opt.learningRate) #optimizer
    scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr = opt.learningRate,steps_per_epoch=len(dataloader),
                                                  epochs=opt.epochNum,pct_start = 0.2, div_factor = 10, final_div_factor = 10)

    lossLogger = defaultdict(list)
    for epoch in range(opt.epochNum):
        batch = 0
        lossSum = 0
        for _,sample in enumerate(dataloader):
            batch = batch +1
            lf=sample['lf']
            lf = lf.cuda()
            
            estimatedLF=model(lf,epoch)
            loss = criterion(estimatedLF,lf)
            lossSum += loss.item()
            print("Epoch: %d Batch: %d Loss: %.6f" %(epoch,batch,loss.item()))
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model._modules['proj_init'].weight.data[model._modules['proj_init'].weight.data<0.0]=0.0
            model._modules['proj_init'].weight.data[model._modules['proj_init'].weight.data>1.0]=1.0
            scheduler.step()     #ONE
    
        if epoch % opt.num_cp == 0:
            model_save_path = join(savePath,"pfe_model_epoch_{}.pth".format(epoch))
            state = {'epoch':epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()}
            torch.save(state,model_save_path)
            print("checkpoint saved to {}".format(model_save_path))
        log.info("Epoch: %d Loss: %.6f" %(epoch,lossSum/len(dataloader)))

        #Record the training loss
        lossLogger['Epoch'].append(epoch)
        lossLogger['Loss'].append(lossSum/len(dataloader))
        lossLogger['Lr'].append(optimizer.state_dict()['param_groups'][0]['lr'])
        #lossLogger['Psnr'].append(avg_psnr)
        plt.figure()
        plt.title('Loss')
        plt.plot(lossLogger['Epoch'],lossLogger['Loss'])
        plt.savefig('Training_{}_{}_{}_{}_{}_{}_pfe.jpg'.format(opt.dataName, opt.measurementNum, opt.stageNum,opt.sasLayerNum, opt.epochNum, opt.learningRate))
        plt.close()