In [1]:
from __future__ import print_function, division
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import data as dt
import model as md
import copy
from pytorch_ssim import ssim
from torch.utils.data import DataLoader

In [2]:
# Select the GPU used by this notebook and report the CUDA environment.
# The CUDA calls are guarded so that the cell also runs on a machine without
# CUDA or with fewer GPUs than GPU_INDEX + 1, instead of raising before the
# CPU fallback further down is ever reached.
GPU_INDEX = 4
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
    currentDevice = torch.cuda.current_device()
    print("Current GPU: " + str(currentDevice))
    print(str(torch.cuda.device_count()))
    print(str(torch.cuda.get_device_capability(currentDevice)))
else:
    # No usable CUDA device with that index: leave the default device alone.
    currentDevice = None
    print("CUDA device {} is not available; running on CPU".format(GPU_INDEX))
print(torch.__version__)

Current GPU: 4
8
(6, 1)
1.0.0


In [3]:
# Pick the device that the model and the batches are moved to.
# Both branches yield a torch.device so that later code can rely on a single
# type (e.g. `device.type`) regardless of which branch ran.
USE_GPU = 1
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda:4')
else:
    device = torch.device('cpu')
print(device)

cuda:4


In [None]:
# Generate the CSV file: uncomment the call below and run it only the first time.
# dt.generate_csv()

In [4]:
# Build the dataset from the CSV path reported by dt.get_csv_path().
# The transform pipeline is applied in the order listed below.
csvFilePath = dt.get_csv_path()
patch_transform = dt.Compose([
    dt.Rescale(96),
    dt.ToTensor(),
    # Presumably (mean, std) per channel; test() undoes it as x * 0.5 + 0.5.
    dt.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
transformed_dataset = dt.HE_SHG_Dataset(
    csv_file=csvFilePath,
    transform=patch_transform,
)
# TODO: change the normalization parameters

In [5]:
# Shuffled mini-batches of 100 samples; num_workers=0 keeps loading in the
# main process.
dataloader = DataLoader(
    dataset=transformed_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Hand the dataloader to dt.show_patch (presumably plots a batch of patches;
# see the `data` module to confirm).
# TODO: insert back mean and variance to plot the image appropriately
dt.show_patch(dataloader) 

In [6]:
print("===> Building model")

# Pixel-wise criterion used alongside SSIM in train().
criterionMSE = nn.MSELoss()

# Instantiate the network and move it to the selected device.
model = md.Net()
model = model.to(device)

# NOTE(review): lr=0.1 is 100x Adam's default of 1e-3; the logged losses
# plateau early (SSIM term stays at ~1.0), so a smaller rate may be worth
# trying. Value kept as-is here.
optimizer = optim.Adam(params=model.parameters(), lr=0.1)


===> Building model


In [7]:
def train(epoch, mse_weight=0.25):
    """Run one optimisation pass over ``dataloader`` and print progress.

    Every batch is scored with a weighted sum of the MSE and the structural
    dissimilarity ``1 - SSIM``::

        loss = mse_weight * MSE + (1 - mse_weight) * (1 - SSIM)

    Uses the notebook-level ``model``, ``optimizer``, ``dataloader``,
    ``criterionMSE`` and ``device``.

    Args:
        epoch: Epoch number; only used in the progress messages.
        mse_weight: Weight of the MSE term. The SSIM term is weighted with
            ``1 - mse_weight``. Defaults to 0.25.
    """
    # Put the model in training mode explicitly so that training still
    # behaves correctly after an evaluation pass switched it to eval mode.
    model.train()

    num_batches = len(dataloader)
    epoch_loss = 0.0
    for iteration, batch in enumerate(dataloader):
        inputs = batch['input'].to(device)
        # Cast targets to float and insert an axis at dim 1 before comparing
        # them with the model output (assumes targets arrive without a
        # channel axis -- TODO confirm against the dataset).
        targets = batch['output'].to(device).float()[:, None]

        optimizer.zero_grad()
        output = model(inputs)

        lossMSE = criterionMSE(output, targets)
        lossSSIM = 1 - ssim(output, targets)
        loss = mse_weight * lossMSE + (1 - mse_weight) * lossSSIM

        loss.backward()
        optimizer.step()

        # .item() detaches the value so no autograd graph is kept alive.
        epoch_loss += loss.item()

        if iteration % 50 == 0:
            print("lossMSE: " + str(lossMSE.item()) +
                  " " + "lossSSIM: " + str(lossSSIM.item()))
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch, iteration, num_batches, loss.item()))

    # max(..., 1) avoids a ZeroDivisionError when the dataloader is empty.
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
        epoch, epoch_loss / max(num_batches, 1)))


In [None]:
# def test():
#     avg_psnr = 0
#     with torch.no_grad():
#         for batch in testing_data_loader:
#             input, target = batch[0].to(device), batch[1].to(device)

#             prediction = model(input)
#             mse = criterion(prediction, target)
#             psnr = 10 * log10(1 / mse.item())
#             avg_psnr += psnr
#     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))


In [None]:
# def checkpoint(epoch):
#     model_out_path = "model_epoch_{}.pth".format(epoch)
#     torch.save(model, model_out_path)
#     print("Checkpoint saved to {}".format(model_out_path))

In [None]:
# Main training run; epochs are numbered from 1 in the progress log.
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
#     test()
#     checkpoint(epoch)

lossMSE: 2.1877079010009766 lossSSIM: 1.1010581254959106
===> Epoch[1](0/8110): Loss: 1.3727
lossMSE: 1.953629970550537 lossSSIM: 1.0336592197418213
===> Epoch[1](50/8110): Loss: 1.2637
lossMSE: 1.7648597955703735 lossSSIM: 1.0232607126235962
===> Epoch[1](100/8110): Loss: 1.2087
lossMSE: 1.576941967010498 lossSSIM: 1.0178065299987793
===> Epoch[1](150/8110): Loss: 1.1576
lossMSE: 1.4379291534423828 lossSSIM: 1.0191327333450317
===> Epoch[1](200/8110): Loss: 1.1238
lossMSE: 1.3351436853408813 lossSSIM: 1.0160794258117676
===> Epoch[1](250/8110): Loss: 1.0958
lossMSE: 1.2499107122421265 lossSSIM: 1.0138956308364868
===> Epoch[1](300/8110): Loss: 1.0729
lossMSE: 1.1840966939926147 lossSSIM: 1.0122668743133545
===> Epoch[1](350/8110): Loss: 1.0552
lossMSE: 1.1722099781036377 lossSSIM: 1.011806607246399
===> Epoch[1](400/8110): Loss: 1.0519
lossMSE: 1.1165852546691895 lossSSIM: 1.0114742517471313
===> Epoch[1](450/8110): Loss: 1.0378
lossMSE: 1.1007307767868042 lossSSIM: 1.009819746017456


lossMSE: 0.9842938780784607 lossSSIM: 1.0019030570983887
===> Epoch[1](4300/8110): Loss: 0.9975
lossMSE: 0.9838666319847107 lossSSIM: 1.0018354654312134
===> Epoch[1](4350/8110): Loss: 0.9973
lossMSE: 0.9590764045715332 lossSSIM: 1.0018264055252075
===> Epoch[1](4400/8110): Loss: 0.9911
lossMSE: 0.9767499566078186 lossSSIM: 1.0018724203109741
===> Epoch[1](4450/8110): Loss: 0.9956
lossMSE: 0.9704901576042175 lossSSIM: 1.0018130540847778
===> Epoch[1](4500/8110): Loss: 0.9940
lossMSE: 0.9558037519454956 lossSSIM: 1.0018435716629028
===> Epoch[1](4550/8110): Loss: 0.9903
lossMSE: 0.9605242013931274 lossSSIM: 1.0017857551574707
===> Epoch[1](4600/8110): Loss: 0.9915
lossMSE: 0.9739366769790649 lossSSIM: 1.0018137693405151
===> Epoch[1](4650/8110): Loss: 0.9948
lossMSE: 0.9633535146713257 lossSSIM: 1.001857876777649
===> Epoch[1](4700/8110): Loss: 0.9922
lossMSE: 0.9727911949157715 lossSSIM: 1.0018080472946167
===> Epoch[1](4750/8110): Loss: 0.9946
lossMSE: 0.9705262780189514 lossSSIM: 1.0

lossMSE: 0.9641979932785034 lossSSIM: 1.001731514930725
===> Epoch[2](450/8110): Loss: 0.9923
lossMSE: 0.9624587893486023 lossSSIM: 1.001754879951477
===> Epoch[2](500/8110): Loss: 0.9919
lossMSE: 0.9803807139396667 lossSSIM: 1.0016776323318481
===> Epoch[2](550/8110): Loss: 0.9964
lossMSE: 0.9739828705787659 lossSSIM: 1.0017040967941284
===> Epoch[2](600/8110): Loss: 0.9948
lossMSE: 0.9747862815856934 lossSSIM: 1.0017465353012085
===> Epoch[2](650/8110): Loss: 0.9950
lossMSE: 0.9605003595352173 lossSSIM: 1.0016688108444214
===> Epoch[2](700/8110): Loss: 0.9914
lossMSE: 0.9735003113746643 lossSSIM: 1.0017071962356567
===> Epoch[2](750/8110): Loss: 0.9947
lossMSE: 0.9638087153434753 lossSSIM: 1.0017224550247192
===> Epoch[2](800/8110): Loss: 0.9922
lossMSE: 0.9630942344665527 lossSSIM: 1.001712441444397
===> Epoch[2](850/8110): Loss: 0.9921
lossMSE: 0.942817747592926 lossSSIM: 1.0017167329788208
===> Epoch[2](900/8110): Loss: 0.9870
lossMSE: 0.9611804485321045 lossSSIM: 1.00173461437225

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms, utils
def _show_batch_pair(pred_img, target_img):
    """Plot the prediction grid and the target grid in two separate figures."""
    for images in (pred_img, target_img):
        # clamp only affects what is displayed; imshow expects floats in [0, 1].
        grid = utils.make_grid(images.clamp(0, 1)).cpu()
        plt.figure()
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.axis('off')
    plt.show()


def test(max_batches=17, show=True):
    """Evaluate ``model`` on the first batches of ``dataloader``.

    For every evaluated batch the prediction and the target are mapped back
    from the normalised range to [0, 1], optionally plotted, and scored with
    ``PSNR = 10 * log10(1 / MSE)`` (peak value 1). The mean PSNR over the
    evaluated batches is printed.

    Uses the notebook-level ``model``, ``dataloader``, ``criterionMSE`` and
    ``device``. The model's train/eval mode is restored on exit.

    Args:
        max_batches: Number of batches to evaluate; ``None`` evaluates the
            whole dataloader. Defaults to 17, the number the original loop
            processed before breaking.
        show: When true, plot the prediction and target grids of each batch.
    """
    was_training = model.training
    model.eval()

    total_psnr = 0.0
    batches_seen = 0
    try:
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                if max_batches is not None and iteration >= max_batches:
                    break

                inputs = batch['input'].to(device)
                # Same target preparation as in train(): float cast plus an
                # extra axis at dim 1 (assumes targets arrive without a
                # channel axis -- TODO confirm against the dataset).
                targets = batch['output'].to(device).float()[:, None]

                prediction = model(inputs)
                print(prediction.size(), targets.size())

                # Undo Normalize(0.5, 0.5) out of place, so the tensors used
                # for the metric are not silently modified by the plotting
                # code. Assumes targets were normalised the same way as the
                # inputs -- TODO confirm.
                pred_img = prediction * 0.5 + 0.5
                target_img = targets * 0.5 + 0.5

                if show:
                    _show_batch_pair(pred_img, target_img)

                # PSNR is defined on the MSE alone; with images in [0, 1]
                # the peak signal value is 1.
                mse = criterionMSE(pred_img, target_img)
                total_psnr += 10 * torch.log10(1 / mse).item()
                batches_seen += 1
    finally:
        # Leave the model in the mode it was in, so a later train() call is
        # not affected by this evaluation.
        model.train(was_training)

    # Average over the batches actually evaluated, not the full loader.
    avg_psnr = total_psnr / batches_seen if batches_seen else float('nan')
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))

In [None]:
# Run the evaluation pass with the current model weights.
test()

In [None]:
# Restart: run five more epochs (numbered 1..5) on the current model state.
for epoch in range(1, 6):
    train(epoch)
#     test()
#     checkpoint(epoch)