In [1]:
from __future__ import print_function, division
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import data as dt
import model as md
import copy
from pytorch_ssim import ssim
from torch.utils.data import DataLoader

In [2]:
# Select and report the GPU used by this notebook.
# GPU_INDEX is the physical CUDA device to use (the recorded run reported 8 GPUs);
# change it to suit your host. The CUDA calls are guarded so the cell also runs
# on machines without CUDA or with fewer GPUs instead of raising.
GPU_INDEX = 3
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
    currentDevice = torch.cuda.current_device()
    print("Current GPU: " + str(currentDevice))
    print(str(torch.cuda.device_count()))
    print(str(torch.cuda.get_device_capability(currentDevice)))
else:
    print("GPU {} is not available; CUDA device selection skipped".format(GPU_INDEX))
print(torch.__version__)

Current GPU: 3
8
(6, 1)
1.0.0


In [3]:
# Choose where the model and tensors live: the current CUDA device (selected with
# torch.cuda.set_device earlier in the notebook) when USE_GPU is set and CUDA is
# available, otherwise the CPU. Reusing current_device() avoids repeating the GPU
# index here, and both branches now produce a torch.device.
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')
print(device)

cuda:3


In [None]:
# generate csv file, run only for the first time
# dt.generate_csv()

In [4]:
csvFilePath = dt.get_csv_path()
# Per-sample preprocessing: rescale to 96, convert to tensors, then normalize
# each channel with the parameters [0.5]*3 / [0.5]*3 (presumably mean / std of
# the project's dt.Normalize — confirm in data.py).
# TODO: change the normalization parameters
sample_transform = dt.Compose([
    dt.Rescale(96),
    dt.ToTensor(),
    dt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
transformed_dataset = dt.HE_SHG_Dataset(csv_file=csvFilePath,
                                        transform=sample_transform)

In [5]:
# Mini-batch size in use is 32. An earlier note here suggested lowering it to 16;
# that change was never applied.
BATCH_SIZE = 32
dataloader = DataLoader(transformed_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, num_workers=0)

In [None]:
# Preview data from the loader; dt.show_patch is a project helper (data.py) that
# presumably plots patches from one batch — confirm there.
# TODO: undo the dt.Normalize([0.5]*3, [0.5]*3) step applied in the dataset
# transform before plotting, so images are shown in their original intensity range.
dt.show_patch(dataloader) 

In [6]:
# Build the network, the MSE criterion (combined with SSIM inside train()) and the optimizer.
print('===> Building model')
model = md.Net().to(device)
criterionMSE = nn.MSELoss()

# NOTE(review): in the recorded run, both the per-batch MSE and 1 - SSIM settle
# near 1.0 within the first ~9000 iterations of epoch 1 and stop improving.
# 0.01 is high for Adam (its default is 1e-3), so a smaller rate may be worth trying.
LEARNING_RATE = 0.01
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


===> Building model


In [7]:
def train(epoch, mse_weight=0.25, log_every=50):
    """Train the global ``model`` for one epoch over the global ``dataloader``.

    Every batch is scored with
    ``mse_weight * MSE + (1 - mse_weight) * (1 - SSIM)``, and the global
    ``optimizer`` takes one step per batch.

    Args:
        epoch: Epoch number; used only in the log messages.
        mse_weight: Weight of the MSE term (the SSIM term gets ``1 - mse_weight``).
            Defaults to 0.25, the value previously hard-coded here.
        log_every: Print per-batch losses every ``log_every`` iterations
            (0 disables per-batch logging).

    Returns:
        The mean combined loss over the epoch's batches (0.0 if the loader is empty).
    """
    # Make sure layers run in training mode even if an earlier cell switched the
    # model to evaluation mode.
    model.train()
    num_batches = len(dataloader)
    epoch_loss = 0.0
    for iteration, batch in enumerate(dataloader):
        inputs = batch['input'].to(device)
        target = batch['output'].to(device)

        optimizer.zero_grad()
        output = model(inputs)
        # Insert a channel axis so the target lines up with the model output;
        # presumably target is (N, H, W) and output is (N, 1, H, W) — TODO confirm.
        targetf = target.float()[:, None]

        lossMSE = criterionMSE(output, targetf)
        # ssim() is a similarity (higher = more alike), so 1 - ssim is what we minimize.
        lossSSIM = 1 - ssim(output, targetf)
        loss = mse_weight * lossMSE + (1 - mse_weight) * lossSSIM

        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor, so no autograd graph is retained.
        epoch_loss += loss.item()

        if log_every and iteration % log_every == 0:
            print("lossMSE: " + str(lossMSE.item()) +
                  " " + "lossSSIM: " + str(lossSSIM.item()))
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, loss.item()))

    avg_loss = epoch_loss / num_batches if num_batches else 0.0
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))
    return avg_loss


In [None]:
# def test():
#     avg_psnr = 0
#     with torch.no_grad():
#         for batch in testing_data_loader:
#             input, target = batch[0].to(device), batch[1].to(device)

#             prediction = model(input)
#             mse = criterion(prediction, target)
#             psnr = 10 * log10(1 / mse.item())
#             avg_psnr += psnr
#     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))


In [None]:
# def checkpoint(epoch):
#     model_out_path = "model_epoch_{}.pth".format(epoch)
#     torch.save(model, model_out_path)
#     print("Checkpoint saved to {}".format(model_out_path))

In [None]:
# Main training run over epochs 1..10; train() prints a per-epoch average loss.
# Evaluation and checkpointing hooks remain disabled below.
EPOCH_RANGE = range(1, 11)
for epoch in EPOCH_RANGE:
    train(epoch)
#     test()
#     checkpoint(epoch)

lossMSE: 2.2096481323242188 lossSSIM: 1.0789344310760498
===> Epoch[1](0/25343): Loss: 1.3616
lossMSE: 2.1858043670654297 lossSSIM: 1.0360506772994995
===> Epoch[1](50/25343): Loss: 1.3235
lossMSE: 2.1437644958496094 lossSSIM: 1.034879207611084
===> Epoch[1](100/25343): Loss: 1.3121
lossMSE: 2.0976638793945312 lossSSIM: 1.0340828895568848
===> Epoch[1](150/25343): Loss: 1.3000
lossMSE: 2.1018261909484863 lossSSIM: 1.0319061279296875
===> Epoch[1](200/25343): Loss: 1.2994
lossMSE: 2.0222904682159424 lossSSIM: 1.0291587114334106
===> Epoch[1](250/25343): Loss: 1.2774
lossMSE: 2.0070016384124756 lossSSIM: 1.0280711650848389
===> Epoch[1](300/25343): Loss: 1.2728
lossMSE: 1.9654513597488403 lossSSIM: 1.0300029516220093
===> Epoch[1](350/25343): Loss: 1.2639
lossMSE: 1.9446958303451538 lossSSIM: 1.0274171829223633
===> Epoch[1](400/25343): Loss: 1.2567
lossMSE: 1.7757505178451538 lossSSIM: 1.0261132717132568
===> Epoch[1](450/25343): Loss: 1.2135
lossMSE: 1.8746812343597412 lossSSIM: 1.0288

lossMSE: 1.1003761291503906 lossSSIM: 1.0047268867492676
===> Epoch[1](4300/25343): Loss: 1.0286
lossMSE: 1.049142599105835 lossSSIM: 1.004579782485962
===> Epoch[1](4350/25343): Loss: 1.0157
lossMSE: 1.0749120712280273 lossSSIM: 1.004716157913208
===> Epoch[1](4400/25343): Loss: 1.0223
lossMSE: 1.0845999717712402 lossSSIM: 1.0047023296356201
===> Epoch[1](4450/25343): Loss: 1.0247
lossMSE: 1.1119210720062256 lossSSIM: 1.004279613494873
===> Epoch[1](4500/25343): Loss: 1.0312
lossMSE: 1.0710504055023193 lossSSIM: 1.0045636892318726
===> Epoch[1](4550/25343): Loss: 1.0212
lossMSE: 1.100809097290039 lossSSIM: 1.0044628381729126
===> Epoch[1](4600/25343): Loss: 1.0285
lossMSE: 1.05734384059906 lossSSIM: 1.0042668581008911
===> Epoch[1](4650/25343): Loss: 1.0175
lossMSE: 1.1011220216751099 lossSSIM: 1.0044922828674316
===> Epoch[1](4700/25343): Loss: 1.0286
lossMSE: 1.0729526281356812 lossSSIM: 1.0040820837020874
===> Epoch[1](4750/25343): Loss: 1.0213
lossMSE: 1.1010494232177734 lossSSIM:

lossMSE: 1.002738118171692 lossSSIM: 1.0019875764846802
===> Epoch[1](8550/25343): Loss: 1.0022
lossMSE: 1.0046160221099854 lossSSIM: 1.0021626949310303
===> Epoch[1](8600/25343): Loss: 1.0028
lossMSE: 1.0244985818862915 lossSSIM: 1.0020920038223267
===> Epoch[1](8650/25343): Loss: 1.0077
lossMSE: 1.01283860206604 lossSSIM: 1.002210259437561
===> Epoch[1](8700/25343): Loss: 1.0049
lossMSE: 1.0344271659851074 lossSSIM: 1.002276062965393
===> Epoch[1](8750/25343): Loss: 1.0103
lossMSE: 1.0223729610443115 lossSSIM: 1.0020748376846313
===> Epoch[1](8800/25343): Loss: 1.0071
lossMSE: 1.043779969215393 lossSSIM: 1.0021942853927612
===> Epoch[1](8850/25343): Loss: 1.0126
lossMSE: 1.0098631381988525 lossSSIM: 1.0022526979446411
===> Epoch[1](8900/25343): Loss: 1.0042
lossMSE: 0.9983373284339905 lossSSIM: 1.002130150794983
===> Epoch[1](8950/25343): Loss: 1.0012
lossMSE: 0.9917740225791931 lossSSIM: 1.0021402835845947
===> Epoch[1](9000/25343): Loss: 0.9995
lossMSE: 1.0322937965393066 lossSSIM:

lossMSE: 1.0020146369934082 lossSSIM: 1.0019148588180542
===> Epoch[1](12800/25343): Loss: 1.0019
lossMSE: 0.9708382487297058 lossSSIM: 1.0018198490142822
===> Epoch[1](12850/25343): Loss: 0.9941
lossMSE: 1.034166932106018 lossSSIM: 1.0018230676651
===> Epoch[1](12900/25343): Loss: 1.0099
lossMSE: 0.9908655881881714 lossSSIM: 1.0020205974578857
===> Epoch[1](12950/25343): Loss: 0.9992
lossMSE: 0.9994843602180481 lossSSIM: 1.002109408378601
===> Epoch[1](13000/25343): Loss: 1.0015
lossMSE: 1.0204893350601196 lossSSIM: 1.0019664764404297
===> Epoch[1](13050/25343): Loss: 1.0066
lossMSE: 0.968281626701355 lossSSIM: 1.0016182661056519
===> Epoch[1](13100/25343): Loss: 0.9933
lossMSE: 1.0210479497909546 lossSSIM: 1.0020874738693237
===> Epoch[1](13150/25343): Loss: 1.0068
lossMSE: 0.9628094434738159 lossSSIM: 1.0018099546432495
===> Epoch[1](13200/25343): Loss: 0.9921
lossMSE: 0.9972183704376221 lossSSIM: 1.0019254684448242
===> Epoch[1](13250/25343): Loss: 1.0007
lossMSE: 1.002301812171936

lossMSE: 0.9998948574066162 lossSSIM: 1.0018500089645386
===> Epoch[1](17000/25343): Loss: 1.0014
lossMSE: 0.9753267765045166 lossSSIM: 1.001810073852539
===> Epoch[1](17050/25343): Loss: 0.9952
lossMSE: 1.0046236515045166 lossSSIM: 1.0017551183700562
===> Epoch[1](17100/25343): Loss: 1.0025
lossMSE: 0.9797242283821106 lossSSIM: 1.001834750175476
===> Epoch[1](17150/25343): Loss: 0.9963
lossMSE: 1.0029561519622803 lossSSIM: 1.0020358562469482
===> Epoch[1](17200/25343): Loss: 1.0023
lossMSE: 0.9926798343658447 lossSSIM: 1.0018537044525146
===> Epoch[1](17250/25343): Loss: 0.9996
lossMSE: 1.0407333374023438 lossSSIM: 1.0018364191055298
===> Epoch[1](17300/25343): Loss: 1.0116
lossMSE: 0.9838731288909912 lossSSIM: 1.0016969442367554
===> Epoch[1](17350/25343): Loss: 0.9972
lossMSE: 0.9998170137405396 lossSSIM: 1.0017813444137573
===> Epoch[1](17400/25343): Loss: 1.0013
lossMSE: 1.0416845083236694 lossSSIM: 1.001792550086975
===> Epoch[1](17450/25343): Loss: 1.0118
lossMSE: 1.023398518562

lossMSE: 1.0363514423370361 lossSSIM: 1.0017880201339722
===> Epoch[1](21250/25343): Loss: 1.0104
lossMSE: 1.04443359375 lossSSIM: 1.0017699003219604
===> Epoch[1](21300/25343): Loss: 1.0124
lossMSE: 0.9775893092155457 lossSSIM: 1.0017437934875488
===> Epoch[1](21350/25343): Loss: 0.9957
lossMSE: 1.0334022045135498 lossSSIM: 1.0017846822738647
===> Epoch[1](21400/25343): Loss: 1.0097
lossMSE: 0.9746768474578857 lossSSIM: 1.0016885995864868
===> Epoch[1](21450/25343): Loss: 0.9949
lossMSE: 1.0077310800552368 lossSSIM: 1.001705288887024
===> Epoch[1](21500/25343): Loss: 1.0032
lossMSE: 0.9907313585281372 lossSSIM: 1.0019330978393555
===> Epoch[1](21550/25343): Loss: 0.9991
lossMSE: 0.9901340007781982 lossSSIM: 1.0018463134765625
===> Epoch[1](21600/25343): Loss: 0.9989
lossMSE: 1.006934642791748 lossSSIM: 1.0019135475158691
===> Epoch[1](21650/25343): Loss: 1.0032
lossMSE: 1.0444644689559937 lossSSIM: 1.0018634796142578
===> Epoch[1](21700/25343): Loss: 1.0125
lossMSE: 1.0254873037338257

lossMSE: 1.023979663848877 lossSSIM: 1.0018503665924072
===> Epoch[2](100/25343): Loss: 1.0074
lossMSE: 0.9901686906814575 lossSSIM: 1.0018247365951538
===> Epoch[2](150/25343): Loss: 0.9989
lossMSE: 1.0147852897644043 lossSSIM: 1.0016852617263794
===> Epoch[2](200/25343): Loss: 1.0050
lossMSE: 0.9627934694290161 lossSSIM: 1.0019317865371704
===> Epoch[2](250/25343): Loss: 0.9921
lossMSE: 1.0242434740066528 lossSSIM: 1.0018137693405151
===> Epoch[2](300/25343): Loss: 1.0074
lossMSE: 0.9930347800254822 lossSSIM: 1.0017119646072388
===> Epoch[2](350/25343): Loss: 0.9995
lossMSE: 1.0170040130615234 lossSSIM: 1.0019183158874512
===> Epoch[2](400/25343): Loss: 1.0057
lossMSE: 0.9936351776123047 lossSSIM: 1.0017811059951782
===> Epoch[2](450/25343): Loss: 0.9997
lossMSE: 1.0126032829284668 lossSSIM: 1.0018010139465332
===> Epoch[2](500/25343): Loss: 1.0045
lossMSE: 1.0150949954986572 lossSSIM: 1.0018960237503052
===> Epoch[2](550/25343): Loss: 1.0052
lossMSE: 0.9949722290039062 lossSSIM: 1.0

lossMSE: 0.9729194641113281 lossSSIM: 1.0018764734268188
===> Epoch[2](4400/25343): Loss: 0.9946
lossMSE: 1.0153673887252808 lossSSIM: 1.001965045928955
===> Epoch[2](4450/25343): Loss: 1.0053
lossMSE: 0.9827901124954224 lossSSIM: 1.0018361806869507
===> Epoch[2](4500/25343): Loss: 0.9971
lossMSE: 0.9953075051307678 lossSSIM: 1.0017755031585693
===> Epoch[2](4550/25343): Loss: 1.0002
lossMSE: 1.0232290029525757 lossSSIM: 1.0018121004104614
===> Epoch[2](4600/25343): Loss: 1.0072
lossMSE: 1.0127134323120117 lossSSIM: 1.0019078254699707
===> Epoch[2](4650/25343): Loss: 1.0046
lossMSE: 1.0466786623001099 lossSSIM: 1.00188148021698
===> Epoch[2](4700/25343): Loss: 1.0131
lossMSE: 1.0288362503051758 lossSSIM: 1.001717448234558
===> Epoch[2](4750/25343): Loss: 1.0085
lossMSE: 1.0321485996246338 lossSSIM: 1.001753330230713
===> Epoch[2](4800/25343): Loss: 1.0094
lossMSE: 1.0155436992645264 lossSSIM: 1.0019400119781494
===> Epoch[2](4850/25343): Loss: 1.0053
lossMSE: 0.9697619676589966 lossSSI

lossMSE: 1.0195674896240234 lossSSIM: 1.0017253160476685
===> Epoch[2](8650/25343): Loss: 1.0062
lossMSE: 1.0375943183898926 lossSSIM: 1.0018439292907715
===> Epoch[2](8700/25343): Loss: 1.0108
lossMSE: 1.061147689819336 lossSSIM: 1.0019097328186035
===> Epoch[2](8750/25343): Loss: 1.0167
lossMSE: 1.0250369310379028 lossSSIM: 1.001863718032837
===> Epoch[2](8800/25343): Loss: 1.0077
lossMSE: 0.9946648478507996 lossSSIM: 1.0019103288650513
===> Epoch[2](8850/25343): Loss: 1.0001
lossMSE: 0.9938962459564209 lossSSIM: 1.0017366409301758
===> Epoch[2](8900/25343): Loss: 0.9998
lossMSE: 1.0272047519683838 lossSSIM: 1.0018320083618164
===> Epoch[2](8950/25343): Loss: 1.0082
lossMSE: 1.0041269063949585 lossSSIM: 1.0016603469848633
===> Epoch[2](9000/25343): Loss: 1.0023
lossMSE: 0.9889326095581055 lossSSIM: 1.0017844438552856
===> Epoch[2](9050/25343): Loss: 0.9986
lossMSE: 1.0215458869934082 lossSSIM: 1.0017553567886353
===> Epoch[2](9100/25343): Loss: 1.0067
lossMSE: 0.9920517802238464 loss

lossMSE: 0.9859637022018433 lossSSIM: 1.0017403364181519
===> Epoch[2](12900/25343): Loss: 0.9978
lossMSE: 0.9946760535240173 lossSSIM: 1.0019396543502808
===> Epoch[2](12950/25343): Loss: 1.0001
lossMSE: 0.9951486587524414 lossSSIM: 1.0017727613449097
===> Epoch[2](13000/25343): Loss: 1.0001
lossMSE: 1.0063992738723755 lossSSIM: 1.0018359422683716
===> Epoch[2](13050/25343): Loss: 1.0030
lossMSE: 1.0112528800964355 lossSSIM: 1.0018891096115112
===> Epoch[2](13100/25343): Loss: 1.0042
lossMSE: 1.0108681917190552 lossSSIM: 1.001743197441101
===> Epoch[2](13150/25343): Loss: 1.0040
lossMSE: 1.0006999969482422 lossSSIM: 1.001643419265747
===> Epoch[2](13200/25343): Loss: 1.0014
lossMSE: 1.007895827293396 lossSSIM: 1.0016523599624634
===> Epoch[2](13250/25343): Loss: 1.0032
lossMSE: 1.0005154609680176 lossSSIM: 1.0018715858459473
===> Epoch[2](13300/25343): Loss: 1.0015
lossMSE: 1.0182942152023315 lossSSIM: 1.0019795894622803
===> Epoch[2](13350/25343): Loss: 1.0061
lossMSE: 1.004306554794

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms, utils
def test():
    avg_psnr = 0
    with torch.no_grad():
        for iteration, batch in enumerate(dataloader):
            input, target = batch['input'].to(device), batch['output'].to(device)
                
            prediction = model(input)

            target = target.float()

            outdataloader = {'input':prediction,'output':target}
            
            print(outdataloader['input'].size(), 
                      outdataloader['output'].size())

            plt.figure()
            input_batch, label_batch = outdataloader['input'], outdataloader['output']
            batch_size = 32
            im_size = input_batch.size(2)
            label_batch=label_batch.reshape([batch_size,1,im_size,im_size])
            print(label_batch.size())
            for img in input_batch:
                for t, m, s in zip(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):
                    t.mul_(s).add_(m)
                            
            for img in label_batch:
                for t, m, s in zip(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):
                    t.mul_(s).add_(m)                           

            grid = utils.make_grid(input_batch).cpu()
            plt.imshow(grid.numpy().transpose((1, 2, 0)))
            plt.figure()

            grid = utils.make_grid(label_batch).cpu()
            plt.imshow(grid.numpy().transpose((1, 2, 0)))

            plt.axis('off')
            plt.ioff()
            plt.show()
            
            targetf = target[:, None]
            
            lossMSE = criterionMSE(prediction, targetf)      
            lossSSIM = -ssim(prediction, targetf)
        
            p = 0.25
            loss = p*lossMSE + (1-p)*lossSSIM
            combineLoss = p*lossMSE.item() + (1-p)*lossSSIM.item()
#             mse = criterion(prediction, target.float())

            psnr = 10 * torch.log10(1 / loss)
            avg_psnr += psnr
            if iteration == 16:
                break
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(dataloader)))

In [None]:
# Plot model predictions and targets for the first batches of the loader and print an average PSNR.
test()

In [None]:
# Second training run ("restart"): the model and optimizer are not re-created in
# this cell, so this continues training the current weights for 5 more epochs;
# only the epoch numbers printed in the logs start again from 1.
more_epochs = 5
for epoch in range(1, more_epochs + 1):
    train(epoch)
#     test()
#     checkpoint(epoch)