In [None]:
#!/usr/bin/env python
from __future__ import print_function, division
import os
import time
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from os.path import join
from utils.ssim import SSIM, MSSSIM
from parameters import Parameters
from utils.utils import adjust_learning_rate
import matplotlib.pyplot as plt

from utils.dataset import create_my_data_1ch, BasicDataset, create_my_data_3ch
from loupe.models import loupe_1ch

from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import normalized_root_mse
from skimage.metrics import structural_similarity
from skimage.metrics import mean_squared_error

# Fix every RNG this notebook touches so that Restart & Run All is repeatable.
seed_num = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    _seed_fn(seed_num)
del _seed_fn

# Tell Intel's OpenMP runtime to tolerate being loaded more than once
# (presumably needed because several bundled libraries ship their own copy).
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

In [None]:
# Global run configuration plus device selection: prefer the GPU when one is visible.
params = Parameters()
cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

In [None]:
# Override the Parameters() defaults for this run.
# Previous values (per the original trailing comments): epochs=40, batch_size=128, lr=0.01.
params.epochs, params.batch_size, params.lr = 5, 32, 0.01

In [None]:
# Project config
model_name = params.model_name
print("model_name:",model_name)
num_epoch = int(params.epochs)
batch_size = int(params.batch_size)
ssimCriterion = SSIM()
msssimCriterion = MSSSIM()

In [None]:
# Configure directory info
run_name = model_name+"_bs_"+str(params.batch_size) + "_ep_"+str(params.epochs) + "_lr_" + str(params.lr)
save_dir = join(params.save_weights, run_name)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
global_step = 0

In [None]:
# Load the single-channel train/val split; train_mean / train_std are presumably
# normalisation statistics (TODO confirm in utils.dataset).
(trainx, trainy,
 valx, valy,
 train_mean, train_std) = create_my_data_1ch(params)

In [None]:
# Wrap the arrays in datasets/loaders; only the training loader is shuffled.
n_train, n_val = len(trainx), len(valx)
dataset_train, dataset_val = BasicDataset(trainx, trainy), BasicDataset(valx, valy)
loader_kwargs = dict(batch_size=params.batch_size, num_workers=0)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset_val, shuffle=False, **loader_kwargs)

In [None]:
# Build the reconstruction network, loss and optimiser on the selected device.
# loupe_1ch(2, 2): argument meaning is defined in loupe.models (TODO confirm; presumably
# input/output channel counts).
rec_net = loupe_1ch(2, 2).to(device)
# .to(device) instead of an unconditional .cuda() so the cell also runs on CPU-only hosts.
criterion = torch.nn.L1Loss().to(device)
# The model is moved first and its parameters are then handed to Adam, the order
# recommended by the torch.optim documentation.
optimizer = optim.Adam(rec_net.parameters(), lr=float(params.lr), betas=(0.5, 0.999))
rec_net.train()
# Best-checkpoint bookkeeping (not read or updated by the cells shown here).
best_loss = 9999
best_model_name = ''

In [None]:
# Train rec_net for num_epoch epochs with an L1 loss on output/target channel 0.
plot_every = 10  # show a sample reconstruction every `plot_every` epochs (starting at epoch 0)
for epoch in range(num_epoch):
    rec_net.train()
    optimizer = adjust_learning_rate(epoch, optimizer)
    out_pred = None  # stays None if train_loader yields no batches
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{num_epoch}', unit='img') as pbar:
        for batch in train_loader:
            optimizer.zero_grad()

            # Move to the active device and cast to float32 in one step
            # (the original unconditional .cuda() failed on CPU-only hosts).
            img_un = batch['img_un'].to(device, dtype=torch.float32)
            img_full_real = batch["img_full_real"].to(device, dtype=torch.float32)

            out_pred, out_mask, out_prob = rec_net(img_un)

            # Only channel 0 of the prediction and of the target enters the loss.
            loss = criterion(out_pred[:, 0, :, :], img_full_real[:, 0, :, :])

            loss.backward()
            optimizer.step()

            # .item() hands tqdm a plain float instead of a graph-carrying tensor repr.
            pbar.set_postfix({'loss': loss.item()})
            pbar.update(img_un.shape[0])
            global_step += 1

    if epoch % plot_every == 0 and out_pred is not None:
        # Channel 0 of the first image in the epoch's last batch.
        pred_img = out_pred.detach().cpu().numpy()
        fig, ax = plt.subplots()
        ax.imshow(pred_img[0, 0, :, :], cmap='gray')
        ax.set_title(f'Epoch {epoch + 1}: prediction (channel 0)')
        plt.show()

In [None]:
# Evaluate rec_net on the validation set: mean PSNR / SSIM / NRMSE / RMSE computed on
# channel 0 of every validation image.
rec_net.eval()
# Shared by PSNR and SSIM. Assumes images are scaled to [0, 1]; the SSIM call already
# hard-coded this, and PSNR previously inferred it from the dtype instead.
metric_data_range = 1.0
psnr_total = 0
nmse_total = 0  # accumulates normalized_root_mse, i.e. NRMSE (variable name kept)
rmse_total = 0
ssim_total = 0
num = 0
# no_grad: pure inference, so no autograd graph is built.
# The progress label no longer reads `epoch` leaked from the training cell.
with torch.no_grad(), tqdm(total=n_val, desc='Validation', unit='img') as pbar:
    for batch in val_loader:
        img_un = batch['img_un'].to(device, dtype=torch.float32)
        img_full_real = batch["img_full_real"].to(device, dtype=torch.float32)
        out_pred, out_mask, out_prob = rec_net(img_un)

        pred_img = out_pred.detach().cpu().numpy()
        real_img = img_full_real.detach().cpu().numpy()

        for real, pred in zip(real_img[:, 0], pred_img[:, 0]):
            psnr_total += peak_signal_noise_ratio(real, pred, data_range=metric_data_range)
            nmse_total += normalized_root_mse(real, pred)
            rmse_total += np.sqrt(mean_squared_error(real, pred))
            ssim_total += structural_similarity(real, pred, data_range=metric_data_range)
            num += 1

        pbar.update(img_un.shape[0])

if num == 0:
    raise RuntimeError("validation set is empty - no metrics to report")

# Average over the images actually scored rather than the nominal dataset length.
psnr_mean = psnr_total / num
nmse_mean = nmse_total / num
rmse_mean = rmse_total / num
ssim_mean = ssim_total / num

print("BASELINE PSNR:%s SSIM:%s NRMSE:%s RMSE:%s" % (psnr_mean, ssim_mean, nmse_mean, rmse_mean))

In [None]:
# Release the 1-channel arrays and loaders before the 3-channel split is loaded.
del trainx, trainy, valx, valy
del dataset_train, dataset_val, train_loader, val_loader

In [None]:
# Load the 3-channel train/val split (same return layout as create_my_data_1ch).
(trainx, trainy,
 valx, valy,
 train_mean, train_std) = create_my_data_3ch(params)

In [None]:
# Rebuild datasets/loaders for the 3-channel split; only the training loader is shuffled.
n_train = len(trainx)
n_val = len(valx)
dataset_train = BasicDataset(trainx, trainy)
dataset_val = BasicDataset(valx, valy)
train_loader, val_loader = (
    DataLoader(ds, batch_size=params.batch_size, shuffle=shuffle, num_workers=0)
    for ds, shuffle in ((dataset_train, True), (dataset_val, False))
)

In [None]:
# Pull the last validation sample ONCE, so that the undersampled input and its reference
# come from the same __getitem__ call. The original indexed the dataset twice, which could
# pair mismatched data if BasicDataset applies any per-call randomness (TODO confirm) and
# did the load work twice.
last_val_sample = val_loader.dataset[-1]
input_data = last_val_sample['img_un']
ref_data = last_val_sample['img_full_real']

In [None]:
# Split the 3-channel sample into three 2-channel inputs for the 1-channel model:
# channel c is paired with channel c + 3 (presumably real/imaginary parts; TODO confirm
# against create_my_data_3ch). unsqueeze(0) adds a batch dimension of 1.
# .to(device, ...) replaces the unconditional .cuda() so this also runs on CPU.
input1, input2, input3 = (
    torch.stack((input_data[c], input_data[c + 3])).unsqueeze(0).to(device, dtype=torch.float32)
    for c in range(3)
)

# Matching per-channel references (no batch dimension added).
ref1, ref2, ref3 = (ref_data[c].to(device, dtype=torch.float32) for c in range(3))

In [None]:
# Reconstruct each channel separately with the 1-channel model.
# Inference only: no_grad avoids building (and holding) an autograd graph for each call.
# out_prob is overwritten each time, so only the third call's value survives.
with torch.no_grad():
    out_pred_1, out_mask_1, out_prob = rec_net(input1)
    out_pred_2, out_mask_2, out_prob = rec_net(input2)
    out_pred_3, out_mask_3, out_prob = rec_net(input3)

In [None]:
# Stack the three per-channel masks into one array and drop singleton axes.
# NOTE(review): built from out_mask_*, not out_prob, despite the name; confirm which is intended.
prob_mask = np.squeeze(np.array([
    mask.detach().cpu().numpy() for mask in (out_mask_1, out_mask_2, out_mask_3)
]))

In [None]:
# Stack the three per-channel reconstructions into one array and drop singleton axes.
pred_img = np.squeeze(np.array([
    pred.detach().cpu().numpy() for pred in (out_pred_1, out_pred_2, out_pred_3)
]))

In [None]:
# Save the masks and reconstructions. The epoch count and learning rate in the file name
# now come from params instead of being hand-typed, because a literal "40iters" goes stale
# whenever params.epochs / params.lr are overridden. For epochs=40, lr=0.01 the name is
# unchanged. slope / sample_slope / sparsity / lambda remain hard-coded: nothing visible
# here exposes them (TODO keep in sync with loupe_1ch).
np.save(
    f'1 channel, slope=200, sample_slope=200, {int(params.epochs)}iters, lr={params.lr}, '
    'sparsity=0.125, lambda=0.001.npy',
    [prob_mask, pred_img],
)

In [None]:
# Checkpoint the trained weights. The epoch count and learning rate in the file name come
# from params so the name cannot drift from the configuration actually used (a literal
# "40iters" was stale once params.epochs was overridden). For epochs=40, lr=0.01 the name
# is unchanged. The other hyper-parameters in the name remain hard-coded (TODO keep in sync
# with loupe_1ch).
torch.save(
    rec_net.state_dict(),
    f'1 channel, slope=200, sample_slope=200, {int(params.epochs)}iters, lr={params.lr}, '
    'sparsity=0.125, lambda=0.001.pth',
)