In [1]:
from __future__ import print_function, division
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import data as dt
import model_3 as md
import copy
from pytorch_ssim import ssim
from torch.utils.data import DataLoader

In [2]:
# Select and report the GPU. Guarded so that Restart & Run All also works on
# machines without CUDA or with fewer than GPU_INDEX + 1 devices (the
# device-selection cell already falls back to the CPU when CUDA is missing).
GPU_INDEX = 3
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
    currentDevice = torch.cuda.current_device()
    print("Current GPU: " + str(currentDevice))
    print(str(torch.cuda.device_count()))
    print(str(torch.cuda.get_device_capability(currentDevice)))
else:
    # No usable GPU: record that explicitly instead of crashing in set_device.
    currentDevice = None
    print("GPU {} not available; CUDA devices found: {}".format(
        GPU_INDEX, torch.cuda.device_count()))
print(torch.__version__)

Current GPU: 3
8
(6, 1)
1.0.0


In [3]:
# Pick the compute device: the chosen GPU when CUDA is usable and that GPU
# exists, otherwise the CPU. Both branches yield a torch.device so downstream
# code always sees the same type.
USE_GPU = True
GPU_INDEX = 3
if USE_GPU and torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device('cuda:{}'.format(GPU_INDEX))
else:
    device = torch.device('cpu')
print(device)

cuda:3


In [None]:
# generate csv file, run only for the first time
# dt.generate_csv()

In [4]:
# Build the dataset from the CSV index produced by dt.generate_csv(). Each
# sample goes through the project transforms in order: Rescale(96)
# (presumably the target side length — confirm in data.py), Normalize, ToTensor.
# TODO: change the normalization parameters
csvFilePath = dt.get_csv_path()
transformed_dataset = dt.HE_SHG_Dataset(
    csv_file=csvFilePath,
    transform=dt.Compose([dt.Rescale(96), dt.Normalize(), dt.ToTensor()]),
)

In [5]:
# Shuffled mini-batches of 32 samples, loaded in the main process
# (num_workers=0).
# NOTE(review): an earlier note read "batchsize 32->16", but the code uses 32;
# decide which is intended before changing it.
dataloader = DataLoader(transformed_dataset, batch_size=32,
                        shuffle=True, num_workers=0)

In [None]:
# Visual sanity check: display sample patches from the loader via the project
# helper dt.show_patch. The samples have gone through dt.Normalize(), so
# TODO: insert back mean and variance to plot the image appropriately.
dt.show_patch(dataloader) 

In [6]:
# Instantiate the network on the selected device, the MSE criterion used by
# the training loss, and an Adam optimizer (lr = 1e-3) over all parameters.
print('===> Building model')
model = md.Net().to(device)
criterionMSE = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


===> Building model


In [7]:
def train(epoch, loader=None, mse_weight=0.25, log_interval=50):
    """Run one training epoch and report the mean combined loss.

    The per-batch objective is a convex combination of MSE and SSIM loss::

        loss = mse_weight * MSE + (1 - mse_weight) * (1 - SSIM)

    Relies on the notebook-level ``model``, ``optimizer``, ``device``,
    ``criterionMSE`` and ``ssim`` (from pytorch_ssim).

    Args:
        epoch: epoch number, used only for logging.
        loader: iterable of dict batches with ``'input'`` and ``'output'``
            tensors that supports ``len()``; defaults to the global
            ``dataloader``.
        mse_weight: weight of the MSE term (0.25 by default).
        log_interval: print per-batch losses every ``log_interval`` iterations.

    Returns:
        The mean combined loss over all batches, or ``nan`` for an empty loader.
    """
    if loader is None:
        loader = dataloader
    num_batches = len(loader)

    # Ensure dropout / batch-norm layers (if any) are in training mode, even if
    # an evaluation cell switched the model to eval mode earlier.
    model.train()

    epoch_loss = 0.0
    for iteration, batch in enumerate(loader):
        inputs = batch['input'].to(device)
        target = batch['output'].to(device)
        # Add a channel axis so the target lines up with the model output
        # (assumes target is (B, H, W) and output is (B, 1, H, W) — TODO confirm).
        targetf = target.float()[:, None]

        optimizer.zero_grad()
        output = model(inputs)

        lossMSE = criterionMSE(output, targetf)
        # SSIM is a similarity in which 1 means identical, so 1 - SSIM is a loss.
        lossSSIM = 1 - ssim(output, targetf)
        loss = mse_weight * lossMSE + (1 - mse_weight) * lossSSIM

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if iteration % log_interval == 0:
            print("lossMSE: " + str(lossMSE.item()) +
                  " " + "lossSSIM: " + str(lossSSIM.item()))
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, loss.item()))

    avg_loss = epoch_loss / num_batches if num_batches else float('nan')
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))
    return avg_loss


In [None]:
# def test():
#     avg_psnr = 0
#     with torch.no_grad():
#         for batch in testing_data_loader:
#             input, target = batch[0].to(device), batch[1].to(device)

#             prediction = model(input)
#             mse = criterion(prediction, target)
#             psnr = 10 * log10(1 / mse.item())
#             avg_psnr += psnr
#     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))


In [None]:
# def checkpoint(epoch):
#     model_out_path = "model_epoch_{}.pth".format(epoch)
#     torch.save(model, model_out_path)
#     print("Checkpoint saved to {}".format(model_out_path))

In [8]:
# Initial training run: 10 epochs, numbered 1..10 in the logs.
# Per-epoch evaluation and checkpointing are currently disabled:
#     test()
#     checkpoint(epoch)
for epoch in range(1, 11):
    train(epoch)

lossMSE: 0.2147173136472702 lossSSIM: 0.9475566148757935
===> Epoch[1](0/25343): Loss: 0.7643
lossMSE: 0.19983629882335663 lossSSIM: 0.9404192566871643
===> Epoch[1](50/25343): Loss: 0.7553
lossMSE: 0.1958138346672058 lossSSIM: 0.9456815719604492
===> Epoch[1](100/25343): Loss: 0.7582
lossMSE: 0.1781737506389618 lossSSIM: 0.9440086483955383
===> Epoch[1](150/25343): Loss: 0.7525
lossMSE: 0.17502453923225403 lossSSIM: 0.9451726078987122
===> Epoch[1](200/25343): Loss: 0.7526
lossMSE: 0.16243106126785278 lossSSIM: 0.9438861012458801
===> Epoch[1](250/25343): Loss: 0.7485
lossMSE: 0.15796515345573425 lossSSIM: 0.9444253444671631
===> Epoch[1](300/25343): Loss: 0.7478
lossMSE: 0.13620805740356445 lossSSIM: 0.9203538298606873
===> Epoch[1](350/25343): Loss: 0.7243
lossMSE: 0.1353950947523117 lossSSIM: 0.9382316470146179
===> Epoch[1](400/25343): Loss: 0.7375
lossMSE: 0.12759187817573547 lossSSIM: 0.9376736283302307
===> Epoch[1](450/25343): Loss: 0.7352
lossMSE: 0.11874116957187653 lossSSIM

lossMSE: 0.02981446124613285 lossSSIM: 0.29218024015426636
===> Epoch[1](4200/25343): Loss: 0.2266
lossMSE: 0.029836216941475868 lossSSIM: 0.27972036600112915
===> Epoch[1](4250/25343): Loss: 0.2172
lossMSE: 0.026980243623256683 lossSSIM: 0.330585241317749
===> Epoch[1](4300/25343): Loss: 0.2547
lossMSE: 0.025275476276874542 lossSSIM: 0.30609822273254395
===> Epoch[1](4350/25343): Loss: 0.2359
lossMSE: 0.02903839200735092 lossSSIM: 0.27308303117752075
===> Epoch[1](4400/25343): Loss: 0.2121
lossMSE: 0.028859509155154228 lossSSIM: 0.338173508644104
===> Epoch[1](4450/25343): Loss: 0.2608
lossMSE: 0.026932137086987495 lossSSIM: 0.30359476804733276
===> Epoch[1](4500/25343): Loss: 0.2344
lossMSE: 0.028944168239831924 lossSSIM: 0.3111388683319092
===> Epoch[1](4550/25343): Loss: 0.2406
lossMSE: 0.024187499657273293 lossSSIM: 0.2836829423904419
===> Epoch[1](4600/25343): Loss: 0.2188
lossMSE: 0.028180744498968124 lossSSIM: 0.2792942523956299
===> Epoch[1](4650/25343): Loss: 0.2165
lossMSE: 

KeyboardInterrupt: 

In [None]:
import itertools

import matplotlib.pyplot as plt
from torchvision import transforms, utils


def _denormalize(images, mean=0.5, std=0.5):
    """Return ``images * std + mean`` as a new tensor (inverse normalization).

    The defaults are the constants this notebook uses to undo the dataset
    normalization. NOTE(review): assumes ``dt.Normalize`` uses mean=0.5 and
    std=0.5, i.e. maps pixels from [0, 1] to [-1, 1] — confirm in data.py.
    The input tensor is not modified.
    """
    return images * std + mean


def _show_batch(prediction, target):
    """Show image grids of a predicted batch and its target batch side by side.

    Both arguments are expected to be (B, C, H, W) tensors already mapped back
    to display range; values are clamped to [0, 1] for plotting only. The
    figure is closed after being shown so repeated calls don't pile up figures.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, images, title in zip(axes, (prediction, target), ('Prediction', 'Target')):
        grid = utils.make_grid(images.clamp(0, 1)).cpu()
        ax.imshow(grid.numpy().transpose((1, 2, 0)))
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)


def test(max_batches=17, loader=None, show=True):
    """Evaluate the model on the first ``max_batches`` batches and report mean PSNR.

    For every evaluated batch, the prediction and target are mapped back with
    ``_denormalize``, optionally plotted, and scored with
    ``PSNR = 10 * log10(1 / MSE)`` (peak value 1 in the denormalized range).
    The model is put in eval mode for the duration and restored to its
    previous mode afterwards; no gradients are tracked.

    Args:
        max_batches: number of batches to evaluate.
        loader: iterable of dict batches with ``'input'`` and ``'output'``
            tensors; defaults to the global ``dataloader``.
            NOTE(review): that is the shuffled training loader and no held-out
            split is visible here, so the default gives a training-set PSNR.
        show: whether to plot each evaluated batch.

    Returns:
        Mean PSNR in dB over the evaluated batches, or ``nan`` if none were.
    """
    if loader is None:
        loader = dataloader

    was_training = model.training
    model.eval()
    psnr_sum = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for batch in itertools.islice(loader, max_batches):
                inputs = batch['input'].to(device)
                target = batch['output'].to(device)
                prediction = model(inputs)
                # Add a channel axis to the target the same way train() does;
                # this works for any batch size, including a final partial one.
                targetf = target.float()[:, None]

                prediction_img = _denormalize(prediction)
                target_img = _denormalize(targetf)
                if show:
                    _show_batch(prediction_img, target_img)

                # PSNR is defined from the MSE alone, not the MSE/SSIM mix.
                mse = criterionMSE(prediction_img, target_img)
                psnr_sum += (10 * torch.log10(1.0 / mse)).item()
                n_batches += 1
    finally:
        model.train(was_training)

    # Average over the batches actually evaluated, not len(loader).
    avg_psnr = psnr_sum / n_batches if n_batches else float('nan')
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
    return avg_psnr

In [None]:
# Plot a few predictions next to their targets and print the average PSNR
# (the trailing semicolon suppresses echoing the return value in the cell).
test();

In [10]:
# Lower the learning rate to 1e-5 for continued training. Updating the
# existing optimizer's param groups (instead of constructing a new Adam)
# keeps Adam's accumulated moment estimates from the first run.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-5

In [None]:
# restart
for epoch in range(1, 5 + 1):
    train(epoch)
#     test()
#     checkpoint(epoch)

lossMSE: 0.019913265481591225 lossSSIM: 0.2594919800758362
===> Epoch[1](0/25343): Loss: 0.1996
lossMSE: 0.022007493302226067 lossSSIM: 0.3356238007545471
===> Epoch[1](50/25343): Loss: 0.2572
lossMSE: 0.023773454129695892 lossSSIM: 0.281873881816864
===> Epoch[1](100/25343): Loss: 0.2173
lossMSE: 0.018449855968356133 lossSSIM: 0.30438232421875
===> Epoch[1](150/25343): Loss: 0.2329
lossMSE: 0.01957864873111248 lossSSIM: 0.2799588441848755
===> Epoch[1](200/25343): Loss: 0.2149
lossMSE: 0.018327942118048668 lossSSIM: 0.2819406986236572
===> Epoch[1](250/25343): Loss: 0.2160
lossMSE: 0.022603563964366913 lossSSIM: 0.2953016757965088
===> Epoch[1](300/25343): Loss: 0.2271
lossMSE: 0.022102627903223038 lossSSIM: 0.27602773904800415
===> Epoch[1](350/25343): Loss: 0.2125
lossMSE: 0.024097124114632607 lossSSIM: 0.3189759850502014
===> Epoch[1](400/25343): Loss: 0.2453
lossMSE: 0.021922804415225983 lossSSIM: 0.27214497327804565
===> Epoch[1](450/25343): Loss: 0.2096
lossMSE: 0.02331906929612

lossMSE: 0.0210416316986084 lossSSIM: 0.3372621536254883
===> Epoch[1](4150/25343): Loss: 0.2582
lossMSE: 0.01950400322675705 lossSSIM: 0.3030344843864441
===> Epoch[1](4200/25343): Loss: 0.2322
lossMSE: 0.020205356180667877 lossSSIM: 0.26709550619125366
===> Epoch[1](4250/25343): Loss: 0.2054
lossMSE: 0.019020535051822662 lossSSIM: 0.24213957786560059
===> Epoch[1](4300/25343): Loss: 0.1864
lossMSE: 0.02344992384314537 lossSSIM: 0.3406967520713806
===> Epoch[1](4350/25343): Loss: 0.2614
lossMSE: 0.01730542629957199 lossSSIM: 0.27526193857192993
===> Epoch[1](4400/25343): Loss: 0.2108
lossMSE: 0.01814928837120533 lossSSIM: 0.26835697889328003
===> Epoch[1](4450/25343): Loss: 0.2058
lossMSE: 0.026949230581521988 lossSSIM: 0.35033607482910156
===> Epoch[1](4500/25343): Loss: 0.2695
lossMSE: 0.028714967891573906 lossSSIM: 0.3873043656349182
===> Epoch[1](4550/25343): Loss: 0.2977
lossMSE: 0.02306114323437214 lossSSIM: 0.2892981171607971
===> Epoch[1](4600/25343): Loss: 0.2227
lossMSE: 0.0

lossMSE: 0.022350260987877846 lossSSIM: 0.33031177520751953
===> Epoch[1](8300/25343): Loss: 0.2533
lossMSE: 0.025174152106046677 lossSSIM: 0.25997185707092285
===> Epoch[1](8350/25343): Loss: 0.2013
lossMSE: 0.02550695836544037 lossSSIM: 0.33434903621673584
===> Epoch[1](8400/25343): Loss: 0.2571
lossMSE: 0.022521832957863808 lossSSIM: 0.3447178602218628
===> Epoch[1](8450/25343): Loss: 0.2642
lossMSE: 0.021339386701583862 lossSSIM: 0.2868611812591553
===> Epoch[1](8500/25343): Loss: 0.2205
lossMSE: 0.021064184606075287 lossSSIM: 0.3261330723762512
===> Epoch[1](8550/25343): Loss: 0.2499
lossMSE: 0.018116962164640427 lossSSIM: 0.24760395288467407
===> Epoch[1](8600/25343): Loss: 0.1902
lossMSE: 0.019244495779275894 lossSSIM: 0.32735586166381836
===> Epoch[1](8650/25343): Loss: 0.2503
lossMSE: 0.026032522320747375 lossSSIM: 0.2909386157989502
===> Epoch[1](8700/25343): Loss: 0.2247
lossMSE: 0.026958100497722626 lossSSIM: 0.3818133473396301
===> Epoch[1](8750/25343): Loss: 0.2931
lossMS

lossMSE: 0.020281128585338593 lossSSIM: 0.3193751573562622
===> Epoch[1](12450/25343): Loss: 0.2446
lossMSE: 0.025951402261853218 lossSSIM: 0.3548848628997803
===> Epoch[1](12500/25343): Loss: 0.2727
lossMSE: 0.026637772098183632 lossSSIM: 0.3554859757423401
===> Epoch[1](12550/25343): Loss: 0.2733
lossMSE: 0.023476287722587585 lossSSIM: 0.36641162633895874
===> Epoch[1](12600/25343): Loss: 0.2807
lossMSE: 0.019176581874489784 lossSSIM: 0.33264297246932983
===> Epoch[1](12650/25343): Loss: 0.2543
lossMSE: 0.022242549806833267 lossSSIM: 0.33678507804870605
===> Epoch[1](12700/25343): Loss: 0.2581
lossMSE: 0.019060621038079262 lossSSIM: 0.2588672637939453
===> Epoch[1](12750/25343): Loss: 0.1989
lossMSE: 0.033287741243839264 lossSSIM: 0.3963947296142578
===> Epoch[1](12800/25343): Loss: 0.3056
lossMSE: 0.02123422734439373 lossSSIM: 0.2436990737915039
===> Epoch[1](12850/25343): Loss: 0.1881
lossMSE: 0.026048021391034126 lossSSIM: 0.33928507566452026
===> Epoch[1](12900/25343): Loss: 0.26

lossMSE: 0.017977694049477577 lossSSIM: 0.2172207236289978
===> Epoch[1](16550/25343): Loss: 0.1674
lossMSE: 0.02191002480685711 lossSSIM: 0.3296220898628235
===> Epoch[1](16600/25343): Loss: 0.2527
lossMSE: 0.021988924592733383 lossSSIM: 0.32270359992980957
===> Epoch[1](16650/25343): Loss: 0.2475
lossMSE: 0.01715139113366604 lossSSIM: 0.2887430191040039
===> Epoch[1](16700/25343): Loss: 0.2208
lossMSE: 0.02068514935672283 lossSSIM: 0.298997700214386
===> Epoch[1](16750/25343): Loss: 0.2294
lossMSE: 0.019536729902029037 lossSSIM: 0.23890703916549683
===> Epoch[1](16800/25343): Loss: 0.1841
lossMSE: 0.039935436099767685 lossSSIM: 0.36018359661102295
===> Epoch[1](16850/25343): Loss: 0.2801
lossMSE: 0.023372549563646317 lossSSIM: 0.3380390405654907
===> Epoch[1](16900/25343): Loss: 0.2594
lossMSE: 0.024019844830036163 lossSSIM: 0.3186320662498474
===> Epoch[1](16950/25343): Loss: 0.2450
lossMSE: 0.021324466913938522 lossSSIM: 0.31030982732772827
===> Epoch[1](17000/25343): Loss: 0.2381


lossMSE: 0.017257705330848694 lossSSIM: 0.2542024850845337
===> Epoch[1](20650/25343): Loss: 0.1950
lossMSE: 0.022814802825450897 lossSSIM: 0.2988705635070801
===> Epoch[1](20700/25343): Loss: 0.2299
lossMSE: 0.022337324917316437 lossSSIM: 0.35408228635787964
===> Epoch[1](20750/25343): Loss: 0.2711
lossMSE: 0.018737737089395523 lossSSIM: 0.2747843265533447
===> Epoch[1](20800/25343): Loss: 0.2108
lossMSE: 0.01889728754758835 lossSSIM: 0.24900877475738525
===> Epoch[1](20850/25343): Loss: 0.1915
lossMSE: 0.02094864659011364 lossSSIM: 0.30934178829193115
===> Epoch[1](20900/25343): Loss: 0.2372
lossMSE: 0.019702332094311714 lossSSIM: 0.33069467544555664
===> Epoch[1](20950/25343): Loss: 0.2529
lossMSE: 0.021720150485634804 lossSSIM: 0.29041242599487305
===> Epoch[1](21000/25343): Loss: 0.2232
lossMSE: 0.022235214710235596 lossSSIM: 0.30588585138320923
===> Epoch[1](21050/25343): Loss: 0.2350
lossMSE: 0.019370589405298233 lossSSIM: 0.29644423723220825
===> Epoch[1](21100/25343): Loss: 0.

lossMSE: 0.022138036787509918 lossSSIM: 0.3001708984375
===> Epoch[1](24750/25343): Loss: 0.2307
lossMSE: 0.02235446311533451 lossSSIM: 0.31738269329071045
===> Epoch[1](24800/25343): Loss: 0.2436
lossMSE: 0.022167395800352097 lossSSIM: 0.35138756036758423
===> Epoch[1](24850/25343): Loss: 0.2691
lossMSE: 0.023731492459774017 lossSSIM: 0.3208343982696533
===> Epoch[1](24900/25343): Loss: 0.2466
lossMSE: 0.022115524858236313 lossSSIM: 0.2793368697166443
===> Epoch[1](24950/25343): Loss: 0.2150
lossMSE: 0.020496832206845284 lossSSIM: 0.2840423583984375
===> Epoch[1](25000/25343): Loss: 0.2182
lossMSE: 0.02212759107351303 lossSSIM: 0.30602383613586426
===> Epoch[1](25050/25343): Loss: 0.2350
lossMSE: 0.01947225071489811 lossSSIM: 0.30790871381759644
===> Epoch[1](25100/25343): Loss: 0.2358
lossMSE: 0.017379410564899445 lossSSIM: 0.2501024603843689
===> Epoch[1](25150/25343): Loss: 0.1919
lossMSE: 0.024171598255634308 lossSSIM: 0.32894378900527954
===> Epoch[1](25200/25343): Loss: 0.2528
l