In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time
import tqdm
import random
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import models, transforms
import glob
import os
import seaborn as sb
import copy
import pandas as pd

import tifffile as tiff

# fft_conv: convolution computed via the convolution theorem (using the fast Fourier transform)
from python_files.fft_conv import fft_conv

#changes crop in dataset and dataset folder and saving

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Show which device the notebook will run on.
device

In [None]:
# Output directories for evaluation artifacts.
# `exist_ok=True` keeps this cell idempotent: re-running it (or Restart & Run All)
# no longer raises FileExistsError when folders from a previous run still exist.
test_dir = 'test_reconstruction'       # float reconstructions written by test_model
test_lcd_dir = 'test_lcd'              # LCD-layer images (saving is commented out in test_model)
test_b_dir = 'test_b_reconstruction'   # backlight images (saving is commented out in test_model)
test_visual_dir = 'test_visual'        # 8-bit previews written by test_model
for output_dir in (test_dir, test_b_dir, test_visual_dir, test_lcd_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of TIFF images, yielding ``(tensor, path)`` pairs.

    Args:
        image_paths: indexable sequence of TIFF file paths.
        train: accepted for interface compatibility; not used.
    """

    def __init__(self, image_paths, train=True):
        self.image_paths = image_paths

    def transform(self, image):
        """Convert a decoded image array to a tensor with ``TF.to_tensor``."""
        # A (1080, 1920) center crop used to be applied here; it is disabled.
        return TF.to_tensor(image)

    def __getitem__(self, index):
        """Read the TIFF at ``index`` and return it with its path."""
        path = self.image_paths[index]
        return self.transform(tiff.imread(path)), path

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Evaluation images: every TIFF in the folder (name suggests linearized sRGB,
# float32), sorted so the evaluation order is deterministic.
# Previously used source, kept for reference:
#   glob.glob('small_train/*.tif') + glob.glob('DIV2K_valid_HR/*.png')
test_data = glob.glob('srgb_linearized_f32/*.tif')
test_dataset = Dataset(sorted(test_data))

# One image per batch, no shuffling: results line up with the sorted file order.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:
# Sanity check: number of batches vs. number of images (equal, since batch_size=1).
len(test_dataloader), len(test_dataset)

In [None]:
import io
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sb

# LED point-spread function, stored as a PSF_h x PSF_w comma-separated grid.
PSF_h = 1025
PSF_w = 1025

# Context manager so the CSV file handle is closed deterministically.
with open("zerman_led_psf.csv") as file:
    PSF = torch.tensor(np.loadtxt(file, delimiter=",")).reshape(1, 1, PSF_h, PSF_w).to(device)
print(f"Mean of PSF:{PSF.mean()}")
# Distribution of PSF values, zoomed to [-0.05, 0.15].
sb.displot(PSF.cpu().numpy().flatten())
plt.xlim(-0.05, 0.15)

In [None]:
import json

# LED layout: 'grid_coords' holds one (row, col) pair per LED; the -1 below
# suggests the coordinates are 1-based — TODO confirm against the JSON producer.
# Context manager so the file handle is closed.
with open('Zermans_LEDS_POSITION_2202.json') as f:
    data = json.load(f)
b_mask_coords = np.array(data['grid_coords'])

# Binary LED mask on the 1080x1920 grid: 1.0 at every LED position, 0 elsewhere.
# One vectorized index assignment replaces a per-LED Python loop (one device
# write per LED).
b_mask = torch.zeros(1, 1, 1080, 1920).to(device)
if b_mask_coords.size:
    led_rows = torch.as_tensor(b_mask_coords[:, 0] - 1, dtype=torch.long, device=b_mask.device)
    led_cols = torch.as_tensor(b_mask_coords[:, 1] - 1, dtype=torch.long, device=b_mask.device)
    b_mask[0, 0, led_rows, led_cols] = 1.0

In [None]:
# Backlight drive with every LED fully on: a float64 grid of ones, restricted to
# LED positions by the mask.
b_psf_full = torch.ones(1, 1, 1080, 1920, dtype=torch.float64, device=device)
b_psf_proj = b_psf_full * b_mask
# Sanity check: mean rescaled by pixels / 2202 (LED count per the JSON file name), and the raw sum.
b_psf_proj.mean()*(1920*1080)/(2202), b_psf_proj.sum()

In [None]:
# Rotate the PSF by 180 degrees (flip both spatial axes) before handing it to
# fft_conv — presumably because fft_conv correlates rather than convolves; TODO confirm.
PSF_flipped = PSF.flip(dims=[-2, -1])

# Simulated luminance with all LEDs on. padding=512 is the PSF half-width
# ((1025 - 1) / 2), presumably so the output keeps the 1080x1920 size.
b_psf_output = fft_conv(b_psf_proj, PSF_flipped, padding=512)

# Pixels that receive any light with all LEDs on (luminance > 0); the loss is
# evaluated only there. Broadcast to 3 colour channels.
loss_function_mask = (b_psf_output > 0.0).expand(1, 3, 1080, 1920)

In [None]:
from python_files.utils import get_transmittance, tensor_masking

In [None]:
def reconstruct_original_image(B_: torch.Tensor, I_or: torch.Tensor,
                               backlight_scale=180,
                               transmittance_param=0.005) -> tuple:
    """Simulate the displayed image for a predicted LED backlight.

    Args:
        B_: predicted backlight values; only positions selected by the global
            ``b_mask`` are kept.
        I_or: target image (already luminance-scaled by the caller).
        backlight_scale: multiplier applied to the masked backlight before it
            is spread by the PSF (previously a hard-coded 180).
        transmittance_param: second argument forwarded to ``get_transmittance``
            (previously a hard-coded 0.005).

    Returns:
        ``(I_re, LCD, B_lbd)``: reconstructed image, LCD values clipped to
        [0, 1], and the simulated backlight luminance field.
    """
    B = b_mask * B_  # keep LED positions only

    eps = 1e-15

    # Spread each LED through the (flipped) PSF to get the backlight luminance.
    B_lbd = fft_conv(backlight_scale * B, PSF_flipped, padding=512)
    # Non-negative luminance. Same result as clipping to [0, B_lbd.max()] but
    # without the GPU->CPU sync that `.max().item()` forces.
    B_lbd = B_lbd.clamp(min=0.0)
    # Avoid division by zero where no LED light reaches.
    B_lbd = B_lbd.masked_fill(B_lbd == 0, eps)
    LCD = torch.clip(I_or / B_lbd, 0.0, 1.0)
    transmittance = get_transmittance(LCD, transmittance_param)

    I_re = transmittance * B_lbd

    return I_re, LCD, B_lbd

In [None]:
def load_ckp(checkpoint_fpath, model, optimizer=None, map_location='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint must be a dict with keys 'state_dict', 'optimizer' and 'epoch'.

    Args:
        checkpoint_fpath: path to a file written by ``torch.save``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place; pass ``None``
            to skip it (e.g. evaluation-only loading).
        map_location: forwarded to ``torch.load``; 'cpu' lets GPU-saved
            checkpoints load on any machine.

    Returns:
        ``(model, optimizer, epoch)``: the objects passed in plus the stored epoch.
    """
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

In [None]:
from python_files.pu21_encoder import pu21_encoder
from skimage.metrics import peak_signal_noise_ratio as PSNR

def PU_PSNR(true, pred):
    """PSNR between two luminance images after PU21 encoding.

    Both inputs are moved to the CPU and encoded with ``pu21_encoder(0.005, 4000)``;
    the encoded arrays are compared with scikit-image's PSNR (``data_range=256``).
    """
    encoder = pu21_encoder(0.005, 4000)
    true_pu, pred_pu = (encoder.forward(t.cpu()) for t in (true, pred))
    return PSNR(true_pu.cpu().data.numpy(), pred_pu.cpu().data.numpy(), data_range=256)

In [None]:
def loss_function(I_, I, preds, mask=None):
    """Smooth-L1 loss between two images, restricted to a pixel mask.

    Args:
        I_: first image tensor. ``test_model`` passes the scaled target here
            and the reconstruction as ``I``; the loss is symmetric, so the
            order does not change the value.
        I: second image tensor, same shape as ``I_``.
        preds: raw backlight prediction. Not used by the current loss (the
            backlight-power penalty is disabled); kept for interface compatibility.
        mask: boolean tensor selecting the pixels to compare. Defaults to the
            module-level ``loss_function_mask``.

    Returns:
        Scalar tensor: mean SmoothL1 (beta=1.0) over the masked elements.
    """
    if mask is None:
        mask = loss_function_mask
    return torch.nn.SmoothL1Loss(beta=1.0)(I_[mask], I[mask])

In [None]:
from pytorch_msssim import MS_SSIM

def PU_MSSSIM(true, pred):
    """MS-SSIM between two 3-channel luminance images after PU21 encoding.

    Inputs are moved to the CPU, encoded with ``pu21_encoder(0.005, 4000)`` and
    scored with ``MS_SSIM(data_range=256, size_average=True, channel=3)``.
    """
    encoder = pu21_encoder(0.005, 4000)
    metric = MS_SSIM(data_range=256, size_average=True, channel=3)
    encoded_true = encoder.forward(true.cpu())
    encoded_pred = encoder.forward(pred.cpu())
    return metric(encoded_true.cpu().data, encoded_pred.cpu().data)

In [None]:
from skimage import io
import time

def _cuda_sync():
    """Wait for queued CUDA work to finish; no-op when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _plot_test_result(constructed_preds, LCD, B_lbd, loss, pu_msssim, pu_psnr,
                      backlight_sum, max_lum):
    """Show reconstruction, LCD layer and backlight heatmap for one test image."""
    fig, ax = plt.subplots(1, 3, figsize=(24, 4))
    ax[0].imshow((constructed_preds[0] / max_lum).permute(1, 2, 0).cpu().data.numpy())
    ax[1].imshow(LCD[0].permute(1, 2, 0).cpu().data.numpy())
    # Assumes a single 1080x1920 backlight plane (batch size 1, one channel).
    sb.heatmap(B_lbd.reshape(1080, 1920).cpu().data.numpy(), ax=ax[2],
               linewidths=0.00, cmap='viridis')
    for panel in ax:
        panel.grid(False)
        panel.axis('off')
    ax[0].set_title('LOSS:{:.4f}'.format(loss), fontsize=20)
    ax[1].set_title('PU_MSSSIM:{:.4f}'.format(pu_msssim), fontsize=20)
    ax[2].set_title('PU_PSNR:{:.2f}, B_SUM: {:.1f}'.format(pu_psnr, backlight_sum), fontsize=20)
    # Display now and release the figure; otherwise one open figure per image
    # accumulates in memory for the whole loop.
    plt.show()
    plt.close(fig)


def test_model(model, loss_func, dataloader=None):
    """Evaluate ``model`` on the test set, save reconstructions and report metrics.

    For each batch it:
    - predicts the backlight;
    - simulates the displayed image with ``reconstruct_original_image``;
    - computes loss, PU-MS-SSIM and PU-PSNR;
    - appends a line to ``test_evaluation.txt``;
    - writes the reconstruction as a float TIFF and an 8-bit preview TIFF;
    - shows a 3-panel figure.

    Averages are printed at the end.

    Args:
        model: network mapping an input batch to backlight values.
        loss_func: called as ``loss_func(target, reconstruction, preds)``.
        dataloader: yields ``(inputs, paths)``; defaults to ``test_dataloader``.
    """
    if dataloader is None:
        dataloader = test_dataloader
    model.eval()
    n_batches = len(dataloader)
    if n_batches == 0:
        print('test: dataloader is empty, nothing to evaluate')
        return

    max_lum = 4000  # scale applied to the inputs before reconstruction / metrics
    running_loss = 0.
    running_pu_msssim = 0.
    running_pu_psnr = 0.
    running_perf_time = 0.
    running_backlight = []

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for inputs, pname in dataloader:
            inputs = inputs.to(device)
            target = max_lum * inputs

            # CUDA is asynchronous: synchronize around the forward pass so the
            # measured time covers the actual GPU work.
            _cuda_sync()
            start_time = time.perf_counter()
            preds = model(inputs)
            _cuda_sync()
            running_perf_time += time.perf_counter() - start_time

            constructed_preds, LCD, B_lbd = reconstruct_original_image(preds, target)

            loss_value = loss_func(target, constructed_preds, preds)
            pu_msssim = PU_MSSSIM(target, constructed_preds)
            pu_psnr = PU_PSNR(target, constructed_preds)
            backlight_sum = (b_mask * preds).sum().item()

            running_loss += loss_value.item()
            running_pu_msssim += pu_msssim.item()
            running_pu_psnr += pu_psnr.item()
            running_backlight.append(backlight_sum)

            # os.path.basename also strips Windows separators, unlike rfind('/').
            image_name = os.path.basename(pname[0])

            # NOTE(review): *_REAL values come from the unscaled inputs while
            # *_RESULT values are in the max_lum-scaled range — confirm intended.
            output_str = (
                "{} loss: {:.2f}, PU_MSSSIM: {:.2f}, PU_PSNR {:.1f}, BACKLIGHT_SUM {:.2f}, "
                "MIN_VALUE_RESULT {:.6f}, MAX_VALUE_RESULT {:.2f}, "
                "MIN_VALUE_REAL {:.6f}, MAX_VALUE_REAL {:.2f}"
            ).format(image_name, loss_value.item(), pu_msssim.item(), pu_psnr.item(),
                     backlight_sum,
                     constructed_preds.min().item(), constructed_preds.max().item(),
                     inputs.min().item(), inputs.max().item())
            with open("test_evaluation.txt", "a") as f:
                f.write(output_str + '\n')

            # Float reconstruction, plus an 8-bit preview normalized to its own max.
            # tiff.imwrite is the supported name; tiff.imsave is a deprecated alias.
            tiff.imwrite(os.path.join(test_dir, image_name + '.out_sim.dld.tif'),
                         constructed_preds.detach().cpu().numpy())
            preview = (constructed_preds / constructed_preds.max()) * 255
            tiff.imwrite(os.path.join(test_visual_dir, image_name + '.out_uint8.dld.tif'),
                         preview.detach().cpu().numpy().astype(np.uint8))

            _plot_test_result(constructed_preds, LCD, B_lbd, loss_value.item(),
                              pu_msssim.item(), pu_psnr.item(), backlight_sum, max_lum)

    # Mean values over the test set.
    print('{}. {} Loss: {:.4f}, PU_MSSSIM: {:.4f}, PU_PSNR: {:.2f}, BACKLIGHT_MEAN_SUM: {:.2f}, PERF_TIME: {:.4f}'.format(
        0, 'test', running_loss / n_batches, running_pu_msssim / n_batches,
        running_pu_psnr / n_batches, sum(running_backlight) / n_batches,
        running_perf_time / n_batches) + 70 * ' ')

In [None]:
# Note from an earlier run: 1240_2(542)
# Pick the checkpoint to evaluate by file name inside the run directory.
curr_dir = 'seed43124_Zerman_1.0_small'
checkpoint_name = '224.pt'
model_list = os.listdir(curr_dir)
# list.index raises ValueError if the checkpoint is not in the directory.
model_number = model_list.index(checkpoint_name)
model_list[model_number]

In [None]:
from python_files.HDRsmall import HDRnet

# Restore the trained network from the selected checkpoint. load_ckp is also
# handed an RMSprop optimizer to restore its state into.
model = HDRnet()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
checkpoint_path = os.path.join(curr_dir, model_list[model_number])
model, optimizer, start_epoch = load_ckp(checkpoint_path, model, optimizer)
# Move to the evaluation device and switch to inference mode; the module is
# shown as the cell output.
model.to(device).eval()

In [None]:
# Run the full evaluation with the loaded checkpoint.
test_model(model, loss_function)
# Averages recorded from earlier evaluations, one pair of lines per run.
# The label line identifies the run (meaning of the numbers not recorded — TODO document).
#110 205
#17.5656, PU_MSSSIM: 0.9854, PU_PSNR: 32.53, BACKLIGHT_MEAN_SUM: 556.49, PERF_TIME: 3.2323
#150 210
#Loss: 8.4958, PU_MSSSIM: 0.9861, PU_PSNR: 32.76, BACKLIGHT_MEAN_SUM: 669.82, PERF_TIME: 3.1619
#90 202
#Loss: 7.8785, PU_MSSSIM: 0.9865, PU_PSNR: 32.92, BACKLIGHT_MEAN_SUM: 641.58, PERF_TIME: 3.1742
#small 224
#Loss: 9.7213, PU_MSSSIM: 0.9837, PU_PSNR: 32.40, BACKLIGHT_MEAN_SUM: 618.34, PERF_TIME: 1.7222 
#large 235
#Loss: 8.7452, PU_MSSSIM: 0.9855, PU_PSNR: 32.12, BACKLIGHT_MEAN_SUM: 628.00, PERF_TIME: 1.8941

In [None]:
import math  # not used by this cell

# Scratch calculation: 2000 ** (1 / 1.5), i.e. 2000^(2/3) ≈ 158.7.
2000**(1/1.5)

In [None]:
# Scratch calculation: 100 ** (1 / 1.5), i.e. 100^(2/3) ≈ 21.5.
100**(1/1.5)

In [None]:
import shutil

# Remove the evaluation output folders. Folders that are already gone are
# skipped, so re-running this cell does not raise FileNotFoundError.
# NOTE: the archive cell below zips `test_dir`; run it BEFORE this cleanup.
for output_dir in (test_dir, test_visual_dir, test_b_dir, test_lcd_dir):
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)

In [None]:
# Zip the float reconstructions into test_reconstruction.zip (the cell output is the archive path).
shutil.make_archive(test_dir, 'zip', root_dir=test_dir)

In [None]:
# Optional (disabled): also archive the 8-bit previews into test_visual.zip.
#shutil.make_archive(test_visual_dir, 'zip', 'test_visual')

In [None]:
import shutil

# WARNING: irreversibly deletes a specific earlier run directory.
# Guarded so the cell does not crash when the folder is already gone.
stale_run_dir = 'seed43124_Zerman_test_120'
if os.path.isdir(stale_run_dir):
    shutil.rmtree(stale_run_dir)