In [1]:
from __future__ import print_function, division
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import data as dt
import model as md
import copy
from pytorch_ssim import ssim
from torch.utils.data import DataLoader

In [2]:
# print gpu
torch.cuda.set_device(3)
currentDevice = torch.cuda.current_device()
print("Current GPU: " + str(currentDevice))
print(str(torch.cuda.device_count()))
print(str(torch.cuda.get_device_capability(currentDevice)))
print(torch.__version__)

Current GPU: 3
8
(6, 1)
1.0.0


In [3]:
USE_GPU = 1
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = "cpu"
print(device)

cuda:3


In [None]:
# generate csv file, run only for the first time
# dt.generate_csv()

In [4]:
csvFilePath = dt.get_csv_path()
transformed_dataset = dt.HE_SHG_Dataset(csv_file=csvFilePath,
                                               transform=dt.Compose([                                              
                                               dt.Rescale(96),
                                               dt.ToTensor(),
                                               dt.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])                                           
                                           ]))
# TODO: change the normalization parameters

In [5]:
# batchsize 32->16
dataloader = DataLoader(transformed_dataset, batch_size=32,
                        shuffle=True, num_workers=0)

In [None]:
# TODO: insert back mean and variance to plot the image appropriately
dt.show_patch(dataloader) 

In [6]:
print('===> Building model')
model = md.Net().to(device)
criterionMSE = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.1)


===> Building model


In [7]:
def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(dataloader):
        input, target = batch['input'].to(device), batch['output'].to(device)

        optimizer.zero_grad()
        output = model(input)
        targetf = target.float()
        targetf = targetf[:, None]
        
        lossMSE = criterionMSE(output, targetf)      
        lossSSIM = -(ssim(output, targetf)-1)
        
        p = 0.25
        loss = p*lossMSE + (1-p)*lossSSIM
        combineLoss = p*lossMSE.item() + (1-p)*lossSSIM.item()
        
        epoch_loss = epoch_loss + combineLoss
        loss.backward()
        optimizer.step()
    
        if iteration%50 == 0:
            print("lossMSE: " + str(lossMSE.item()) +
                  " " + "lossSSIM: " + str(lossSSIM.item()))
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(dataloader), loss.item()))

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(dataloader)))


In [None]:
# def test():
#     avg_psnr = 0
#     with torch.no_grad():
#         for batch in testing_data_loader:
#             input, target = batch[0].to(device), batch[1].to(device)

#             prediction = model(input)
#             mse = criterion(prediction, target)
#             psnr = 10 * log10(1 / mse.item())
#             avg_psnr += psnr
#     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))


In [None]:
# def checkpoint(epoch):
#     model_out_path = "model_epoch_{}.pth".format(epoch)
#     torch.save(model, model_out_path)
#     print("Checkpoint saved to {}".format(model_out_path))

In [None]:
for epoch in range(1, 10 + 1):
    train(epoch)
#     test()
#     checkpoint(epoch)

lossMSE: 2.2096481323242188 lossSSIM: 1.0789344310760498
===> Epoch[1](0/25343): Loss: 1.3616
lossMSE: 2.1858043670654297 lossSSIM: 1.0360506772994995
===> Epoch[1](50/25343): Loss: 1.3235
lossMSE: 2.1437644958496094 lossSSIM: 1.034879207611084
===> Epoch[1](100/25343): Loss: 1.3121
lossMSE: 2.0976638793945312 lossSSIM: 1.0340828895568848
===> Epoch[1](150/25343): Loss: 1.3000
lossMSE: 2.1018261909484863 lossSSIM: 1.0319061279296875
===> Epoch[1](200/25343): Loss: 1.2994
lossMSE: 2.0222904682159424 lossSSIM: 1.0291587114334106
===> Epoch[1](250/25343): Loss: 1.2774
lossMSE: 2.0070016384124756 lossSSIM: 1.0280711650848389
===> Epoch[1](300/25343): Loss: 1.2728
lossMSE: 1.9654513597488403 lossSSIM: 1.0300029516220093
===> Epoch[1](350/25343): Loss: 1.2639
lossMSE: 1.9446958303451538 lossSSIM: 1.0274171829223633
===> Epoch[1](400/25343): Loss: 1.2567
lossMSE: 1.7757505178451538 lossSSIM: 1.0261132717132568
===> Epoch[1](450/25343): Loss: 1.2135
lossMSE: 1.8746812343597412 lossSSIM: 1.0288

lossMSE: 1.1003761291503906 lossSSIM: 1.0047268867492676
===> Epoch[1](4300/25343): Loss: 1.0286
lossMSE: 1.049142599105835 lossSSIM: 1.004579782485962
===> Epoch[1](4350/25343): Loss: 1.0157
lossMSE: 1.0749120712280273 lossSSIM: 1.004716157913208
===> Epoch[1](4400/25343): Loss: 1.0223
lossMSE: 1.0845999717712402 lossSSIM: 1.0047023296356201
===> Epoch[1](4450/25343): Loss: 1.0247
lossMSE: 1.1119210720062256 lossSSIM: 1.004279613494873
===> Epoch[1](4500/25343): Loss: 1.0312
lossMSE: 1.0710504055023193 lossSSIM: 1.0045636892318726
===> Epoch[1](4550/25343): Loss: 1.0212
lossMSE: 1.100809097290039 lossSSIM: 1.0044628381729126
===> Epoch[1](4600/25343): Loss: 1.0285
lossMSE: 1.05734384059906 lossSSIM: 1.0042668581008911
===> Epoch[1](4650/25343): Loss: 1.0175
lossMSE: 1.1011220216751099 lossSSIM: 1.0044922828674316
===> Epoch[1](4700/25343): Loss: 1.0286
lossMSE: 1.0729526281356812 lossSSIM: 1.0040820837020874
===> Epoch[1](4750/25343): Loss: 1.0213
lossMSE: 1.1010494232177734 lossSSIM:

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms, utils
def test():
    avg_psnr = 0
    with torch.no_grad():
        for iteration, batch in enumerate(dataloader):
            input, target = batch['input'].to(device), batch['output'].to(device)
                
            prediction = model(input)

            target = target.float()

            outdataloader = {'input':prediction,'output':target}
            
            print(outdataloader['input'].size(), 
                      outdataloader['output'].size())

            plt.figure()
            input_batch, label_batch = outdataloader['input'], outdataloader['output']
            batch_size = 32
            im_size = input_batch.size(2)
            label_batch=label_batch.reshape([batch_size,1,im_size,im_size])
            print(label_batch.size())
            for img in input_batch:
                for t, m, s in zip(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):
                    t.mul_(s).add_(m)
                            
            for img in label_batch:
                for t, m, s in zip(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):
                    t.mul_(s).add_(m)                           

            grid = utils.make_grid(input_batch).cpu()
            plt.imshow(grid.numpy().transpose((1, 2, 0)))
            plt.figure()

            grid = utils.make_grid(label_batch).cpu()
            plt.imshow(grid.numpy().transpose((1, 2, 0)))

            plt.axis('off')
            plt.ioff()
            plt.show()
            
            targetf = target[:, None]
            
            lossMSE = criterionMSE(prediction, targetf)      
            lossSSIM = -ssim(prediction, targetf)
        
            p = 0.25
            loss = p*lossMSE + (1-p)*lossSSIM
            combineLoss = p*lossMSE.item() + (1-p)*lossSSIM.item()
#             mse = criterion(prediction, target.float())

            psnr = 10 * torch.log10(1 / loss)
            avg_psnr += psnr
            if iteration == 16:
                break
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(dataloader)))

In [None]:
test()

In [None]:
# restart
for epoch in range(1, 5 + 1):
    train(epoch)
#     test()
#     checkpoint(epoch)