In [1]:
from __future__ import print_function, division
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import datanew as dt
import model_4_op as md
import copy
from pytorch_ssim import ssim
from torch.utils.data import DataLoader
# e/d + i

In [2]:
# Report the visible GPU(s) and the PyTorch version. Guarded so the notebook
# still runs on a CPU-only machine (the device cell falls back to the CPU).
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    currentDevice = torch.cuda.current_device()
    print(torch.cuda.get_device_name(0))
    print(currentDevice)
    print("Current GPU: " + str(currentDevice))
    print(str(torch.cuda.device_count()))
    print(str(torch.cuda.get_device_capability(currentDevice)))
else:
    currentDevice = None
    print("CUDA not available; running on CPU")
print(torch.__version__)

TITAN V
0
Current GPU: 0
8
(7, 0)
1.0.0


In [3]:
# Pick the training device: the first GPU when enabled and available,
# otherwise the CPU. Both branches yield a torch.device so downstream code
# sees one consistent type.
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Free PyTorch's unused cached GPU memory (only meaningful on CUDA).
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [None]:
# generate csv file, run only for the first time
# dt.generate_csv()

In [5]:
# Two views of the same CSV-indexed HE/SHG dataset, both passed through
# dt.Rescale(96) and dt.ToTensor():
#   - transformed_dataset additionally applies dt.Normalize() (used for training);
#   - transformed_dataset_raw skips it, so test() normalizes whole batches
#     itself via dt.normalizebatch.
csvFilePath = dt.get_csv_path()

train_transform = dt.Compose([dt.Rescale(96), dt.Normalize(), dt.ToTensor()])
transformed_dataset = dt.HE_SHG_Dataset(csv_file=csvFilePath,
                                        transform=train_transform)

# for testing: same pipeline without per-sample normalization
raw_transform = dt.Compose([dt.Rescale(96), dt.ToTensor()])
transformed_dataset_raw = dt.HE_SHG_Dataset(csv_file=csvFilePath,
                                            transform=raw_transform)

In [6]:
# Both loaders draw shuffled batches of 32 samples, loading in the main
# process (num_workers=0).
loader_options = dict(batch_size=32, shuffle=True, num_workers=0)
dataloader = DataLoader(transformed_dataset, **loader_options)

# for testing: un-normalized samples, normalized per batch inside test()
dataloader_raw = DataLoader(transformed_dataset_raw, **loader_options)

In [None]:
# Visual sanity check of the normalized training loader via dt.show_patch
# (presumably plots one batch — see datanew).
# TODO: insert back mean and variance to plot the image appropriately
dt.show_patch(dataloader) 

In [None]:
# Same visual check on the un-normalized loader, for comparison.
dt.show_patch(dataloader_raw) 

In [None]:
# Grab a single raw (un-normalized) batch for the inspection cells below.
sample_test = dt.get_one_batch(dataloader_raw)

In [None]:
# Normalize the batch; the returned statistics are presumably per-modality
# (HE input / SHG output) mean and std — TODO confirm in datanew.normalizebatch.
sample, meanHE, stdHE, meanSHG, stdSHG = dt.normalizebatch(sample_test)

In [None]:
# Display the normalized batch, passing the statistics so it can be rescaled for viewing.
dt.show_one_batch(sample, meanHE, stdHE, meanSHG, stdSHG)

In [7]:
print('===> Building model')
# Mean-squared-error criterion shared by train() and test().
criterionMSE = nn.MSELoss()
model = md.Net().to(device)

# Initial optimizer; the training-schedule cells below build their own.
optimizer = optim.Adam(model.parameters(), lr=0.001)


===> Building model


In [8]:
def train(epoch, p, windowsize, loader=None, log_interval=50):
    """Run one training epoch with a blended MSE / SSIM objective.

    The per-batch loss is ``p * MSE + (1 - p) * (1 - SSIM)`` between the model
    output and the target. Relies on the notebook globals ``model``,
    ``optimizer``, ``criterionMSE``, ``device``, ``ssim`` and ``dataloader``.
    ``optimizer`` is looked up at call time, so a training cell may rebind or
    reconfigure it between epochs.

    Args:
        epoch: epoch number, used only for logging.
        p: weight of the MSE term; ``1 - p`` weights the SSIM term.
        windowsize: ``window_size`` passed to ``ssim``.
        loader: iterable of batch dicts with ``'input'`` and ``'output'``
            tensors; defaults to the global ``dataloader``.
        log_interval: print running losses every this many iterations.

    Returns:
        The mean combined loss over the epoch (0.0 if the loader is empty).
    """
    if loader is None:
        loader = dataloader
    # Ensure dropout/batch-norm layers (if any) are in training mode, in case
    # an evaluation pass switched the model to eval().
    model.train()

    num_batches = len(loader)
    epoch_loss = 0.0
    for iteration, batch in enumerate(loader):
        inputs = batch['input'].to(device)
        target = batch['output'].to(device)

        optimizer.zero_grad()
        output = model(inputs)
        # Insert a singleton axis at dim 1 so the target lines up with the
        # model output (presumably the channel axis — TODO confirm).
        targetf = target.float()[:, None]

        lossMSE = criterionMSE(output, targetf)
        lossSSIM = 1 - ssim(output, targetf, window_size=windowsize)
        loss = p * lossMSE + (1 - p) * lossSSIM

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        if iteration % log_interval == 0:
            print("lossMSE: " + str(lossMSE.item()) +
                  " " + "lossSSIM: " + str(lossSSIM.item()))
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch, iteration, num_batches, loss.item()))

    avg_loss = epoch_loss / num_batches if num_batches else 0.0
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))
    return avg_loss


In [None]:
# def test():
#     avg_psnr = 0
#     with torch.no_grad():
#         for batch in testing_data_loader:
#             input, target = batch[0].to(device), batch[1].to(device)

#             prediction = model(input)
#             mse = criterion(prediction, target)
#             psnr = 10 * log10(1 / mse.item())
#             avg_psnr += psnr
#     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))


In [None]:
# def checkpoint(epoch):
#     model_out_path = "model_epoch_{}.pth".format(epoch)
#     torch.save(model, model_out_path)
#     print("Checkpoint saved to {}".format(model_out_path))

In [9]:
import matplotlib.pyplot as plt
from torchvision import transforms, utils
def _show_prediction_grids(prediction, target):
    """Plot the prediction batch and the target batch as two image grids, then close them."""
    figures = []
    for images in (prediction, target):
        fig, ax = plt.subplots()
        grid = utils.make_grid(images).cpu()
        # make_grid returns channels-first; imshow wants channels-last.
        ax.imshow(grid.numpy().transpose((1, 2, 0)))
        ax.axis('off')
        figures.append(fig)
    plt.show()
    # Close explicitly so figures from many batches do not accumulate in memory.
    for fig in figures:
        plt.close(fig)


def test(max_batches=17, show=True, loader=None):
    """Evaluate the model on raw batches, optionally plot them, and report PSNR.

    Each raw batch is normalized with ``dt.normalizebatch`` before being fed to
    the model. PSNR is computed per batch from the MSE alone as
    ``10 * log10(1 / MSE)``. This assumes a peak signal value of 1 — TODO
    confirm this matches the normalization. The PSNR values are averaged over
    the batches actually evaluated.

    Args:
        max_batches: evaluate at most this many batches (the original loop
            stopped after iteration 16, i.e. 17 batches); ``None`` means all.
        show: if True, plot prediction and target grids for every batch.
        loader: iterable of raw batch dicts; defaults to ``dataloader_raw``.

    Returns:
        The average PSNR in dB, or ``None`` if no batch was evaluated.
    """
    import math

    if loader is None:
        loader = dataloader_raw

    was_training = model.training
    model.eval()
    psnr_values = []
    try:
        with torch.no_grad():
            for iteration, raw_batch in enumerate(loader):
                if max_batches is not None and iteration >= max_batches:
                    break
                # The raw loader skips dt.Normalize(); normalize per batch here.
                # The returned statistics are currently unused.
                batch, meanHE, stdHE, meanSHG, stdSHG = dt.normalizebatch(raw_batch)
                inputs = batch['input'].to(device)
                target = batch['output'].to(device).float()

                prediction = model(inputs)
                # Insert a singleton axis at dim 1, as train() does, so the
                # target lines up with the prediction (works for any batch size).
                targetf = target[:, None]

                if show:
                    _show_prediction_grids(prediction, targetf)

                # PSNR is defined on MSE. The previous blend with -SSIM could be
                # negative, which made log10 return NaN.
                mse = criterionMSE(prediction, targetf).item()
                psnr_values.append(10 * math.log10(1.0 / mse) if mse > 0 else float('inf'))
    finally:
        # Restore the mode the model was in so training can continue unaffected.
        model.train(was_training)

    if not psnr_values:
        print("===> No batches evaluated")
        return None
    avg_psnr = sum(psnr_values) / len(psnr_values)
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
    return avg_psnr

In [None]:
# Stage 1: blended MSE/SSIM training for 20 epochs. On every 5th epoch the
# SSIM window grows by one, the MSE weight p halves and the learning rate is
# multiplied by 0.8.
lr = 0.001
p = 0.75
windowsize = 4
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, 20 + 1):
    if epoch % 5 == 0:
        windowsize += 1
        p *= 0.5
        lr *= 0.8
        # Adjust the rate in place so Adam keeps its moment estimates;
        # re-creating the optimizer every epoch would reset them.
        for group in optimizer.param_groups:
            group['lr'] = lr
    train(epoch, p, windowsize)

lossMSE: 0.309888631105423 lossSSIM: 0.9932149052619934
===> Epoch[1](0/25343): Loss: 0.4807
lossMSE: 0.21992065012454987 lossSSIM: 0.9389750361442566
===> Epoch[1](50/25343): Loss: 0.3997
lossMSE: 0.22309663891792297 lossSSIM: 0.9399890899658203
===> Epoch[1](100/25343): Loss: 0.4023
lossMSE: 0.21575605869293213 lossSSIM: 0.9338243007659912
===> Epoch[1](150/25343): Loss: 0.3953
lossMSE: 0.22703726589679718 lossSSIM: 0.9500573873519897
===> Epoch[1](200/25343): Loss: 0.4078
lossMSE: 0.2206663340330124 lossSSIM: 0.9433647990226746
===> Epoch[1](250/25343): Loss: 0.4013
lossMSE: 0.21714414656162262 lossSSIM: 0.9406930208206177
===> Epoch[1](300/25343): Loss: 0.3980
lossMSE: 0.2099798023700714 lossSSIM: 0.9368406534194946
===> Epoch[1](350/25343): Loss: 0.3917
lossMSE: 0.2133486568927765 lossSSIM: 0.9387649297714233
===> Epoch[1](400/25343): Loss: 0.3947
lossMSE: 0.21111668646335602 lossSSIM: 0.9344047904014587
===> Epoch[1](450/25343): Loss: 0.3919
lossMSE: 0.20733529329299927 lossSSIM:

In [None]:
import os
# Persist the trained weights (state_dict only) under ./Saved model/.
cwd = os.getcwd()
save_dir = os.path.join(cwd, 'Saved model')
# torch.save does not create missing directories, so make sure it exists.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
path = os.path.join(save_dir, 'encoderresinfo.pth')
torch.save(model.state_dict(), path)
print("Model saved to {}".format(path))

In [None]:
# Stage 2: fine-tune on the SSIM term alone (p = 0 zeroes the MSE weight) for
# 20 epochs. On every 5th epoch the SSIM window grows by one and the learning
# rate is multiplied by 0.2.
lr = 0.0001
p = 0
windowsize = 4
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, 20 + 1):
    if epoch % 5 == 0:
        windowsize += 1
        lr *= 0.2
        # Adjust the rate in place so Adam keeps its moment estimates;
        # re-creating the optimizer every epoch would reset them.
        for group in optimizer.param_groups:
            group['lr'] = lr
    train(epoch, p, windowsize)