In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time
import tqdm
import random
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import models, transforms
import glob
import os
import seaborn as sb
import copy
import pandas as pd

import tifffile as tiff

# fft_conv: convolution via the convolution theorem (computed with the fast Fourier transform)
from python_files.fft_conv import fft_conv

# Changes in this version: dataset crop, dataset folder and output saving.

# Use the GPU when PyTorch can see one, otherwise the CPU. (Both branches used
# to be 'cpu', which silently disabled the GPU; assign device = 'cpu' here to
# force CPU execution deliberately.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Output directories for this run. exist_ok=True keeps the cell re-runnable:
# a plain os.makedirs raises FileExistsError once the folder exists.
test_dir = 'test_reconstruction'        # float reconstructions
test_lcd_dir = 'test_lcd'               # LCD-layer images
test_b_dir = 'test_b_reconstruction'    # backlight outputs
test_visual_dir = 'test_visual'         # 8-bit previews
os.makedirs(test_dir, exist_ok=True)
os.makedirs(test_b_dir, exist_ok=True)
os.makedirs(test_visual_dir, exist_ok=True)
os.makedirs(test_lcd_dir, exist_ok=True)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over TIFF image files.

    Each item is an ``(image_tensor, path)`` pair. The file is read with
    ``tifffile.imread``, cast to float32 and converted by ``TF.to_tensor``;
    no cropping is applied.

    Args:
        image_paths: sequence of file paths; items are served in this order.
        train: accepted for API compatibility; not used.
    """

    def __init__(self, image_paths, train=True):
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        path = self.image_paths[index]
        pixels = tiff.imread(path).astype('float32')
        return self.transform(pixels), path

    def transform(self, image):
        """Convert a decoded image array into a tensor via ``TF.to_tensor``."""
        return TF.to_tensor(image)

In [None]:
# Evaluation images: float32 TIFFs from srgb_linearized_f32/ (an earlier
# variant read 'train/data/*.tiff'). Sorted for a deterministic order.
test_data = glob.glob('srgb_linearized_f32/*.tif')
if not test_data:
    # Fail here with a clear message rather than later on an empty dataloader.
    raise FileNotFoundError("no images matched 'srgb_linearized_f32/*.tif'")

test_dataset = Dataset(sorted(test_data))

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=3, drop_last=True
)

In [None]:
# (number of batches, number of images); equal here because batch_size=1.
tuple(map(len, (test_dataloader, test_dataset)))

In [None]:
import io
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sb

# LED point-spread function (PSF), stored in the CSV as a 1025 x 1025 grid.
PSF_h = 1025
PSF_w = 1025

# np.loadtxt opens and closes the file itself; the previous bare open() left
# the handle open for the lifetime of the kernel.
PSF = torch.tensor(np.loadtxt("zerman_led_psf.csv", delimiter=",")).reshape(1, 1, PSF_h, PSF_w).to(device)
print(f"Mean of PSF:{PSF.mean()}")
# Distribution of PSF values, with the x-axis zoomed to [-0.05, 0.15].
sb.displot(PSF.cpu().numpy().flatten())
plt.xlim(-0.05, 0.15);

In [None]:
import json
 
# LED grid positions. The code subtracts 1 before indexing, so these are
# presumably 1-based (row, col) pixel coordinates — TODO confirm.
# The context manager closes the file (the bare open() was never closed).
with open('Zermans_LEDS_POSITION_2202.json') as f:
    data = json.load(f)
b_mask_coords = np.array(data['grid_coords'])

# Binary mask over the 1080 x 1920 panel: 1.0 at each LED position, 0 elsewhere.
# tuple(...) makes the (rows, cols) advanced index explicit instead of relying
# on PyTorch treating a 2 x N numpy array as an index tuple.
b_mask = torch.zeros(1, 1, 1080, 1920).to(device)
b_mask[0, 0][tuple((b_mask_coords - 1).T)] = 1.0

In [None]:
#here we have all leds on
# Backlight with every LED fully on: an all-ones float64 field kept only at LED sites.
b_psf_full = torch.ones(1, 1, 1080, 1920, dtype=torch.float64, device=device)
b_psf_proj = b_psf_full * b_mask
# Sanity check: mean * pixel_count equals the sum, so the first value is
# sum / 2202 (1.0 when exactly 2202 distinct LED sites are set).
b_psf_proj.mean() * (1920 * 1080) / 2202, b_psf_proj.sum()

In [None]:
# Rotate the PSF by 180 degrees (flip both spatial axes) before using it with
# fft_conv — presumably because fft_conv follows the F.conv2d
# cross-correlation convention (TODO confirm in python_files.fft_conv).
PSF_flipped = torch.flip(PSF, dims=[-2, -1])

# Simulated luminance with every LED on at drive level 180. padding=512 is
# (PSF_h - 1) / 2 for the 1025-wide PSF, presumably preserving the 1080 x 1920
# size — TODO confirm fft_conv's padding semantics.
b_psf_output = fft_conv(b_psf_proj * 180, PSF_flipped, padding=512)

# The loss only counts pixels where the all-on backlight exceeds 2700 (older
# notes mentioned thresholds of 24.5 and 18.5). The single-channel mask is
# broadcast to the 3 colour channels.
loss_function_mask = (b_psf_output > 2700).expand(1, 3, 1080, 1920)

In [None]:
import json

# Pixel offsets of the hexagonal pooling neighbourhood around an LED.
# The context manager closes the file (the bare open() was never closed).
with open('hex_coords.json') as f:
    data = json.load(f)
hex_coords = np.array(data['hex_coords'])

In [None]:
from python_files.utils import get_transmittance, tensor_masking

In [None]:
def max_reconstruct_original_image(I_or: torch.Tensor) -> tuple:
    """Simulate local dimming where each LED takes the *maximum* of its neighbourhood.

    Each LED's level is the maximum of the channel-averaged target over the
    hexagonal neighbourhood ``hex_coords`` around that LED, divided by 4000.
    The LED map is spread with the flipped PSF. The LCD layer is set to
    ``I_or / backlight`` (clipped to [0, 1]), and the displayed image is
    simulated as ``transmittance * backlight``.

    Args:
        I_or: target image. It is indexed as ``I_or[0]`` and averaged over dim 0,
            so presumably (1, C, 1080, 1920) — TODO confirm. ``test_model``
            passes inputs scaled by 4000.

    Returns:
        ``(I_re, LCD, B_lbd, B)``:
            I_re  -- simulated image ``T * B_lbd``. It is on B_lbd's scale
                     (like ``I_or``), not confined to [0, 1].
            LCD   -- ``I_or / B_lbd`` clipped to [0, 1].
            B_lbd -- backlight after PSF spreading, clamped to >= 0, with
                     exact zeros replaced by 1e-15.
            B     -- (1, 1, 1080, 1920) LED level map, set only at LED sites.
    """
    B = torch.zeros(1, 1, 1080, 1920).to(device)
    eps = 1e-15

    # The channel mean does not depend on the LED: compute it once, not per LED.
    gray = I_or[0].mean(dim=0)

    for row, col in b_mask_coords:
        # Neighbourhood offsets placed around this LED and clamped to the panel.
        rows = np.clip(hex_coords[:, 0] - 1 + row, 0, 1079)
        cols = np.clip(hex_coords[:, 1] - 1 + col, 0, 1919)
        # LED coordinates appear 1-based, hence the -1 when writing into B.
        B[0, 0, row - 1, col - 1] = gray[rows, cols].max() / 4000

    B_lbd = fft_conv(180 * B, PSF_flipped, padding=512)
    # Luminance cannot be negative: clamp FFT round-off below zero. The old
    # clip(..., 0, B_lbd.max()) had a no-op upper bound and, if that max were
    # negative, would have set every pixel to it.
    B_lbd = torch.clamp(B_lbd, min=0.0)
    # Avoid division by zero where no light reaches the panel.
    B_lbd[B_lbd == 0] = eps
    LCD = torch.clip(I_or / B_lbd, 0.0, 1.0)
    T = get_transmittance(LCD, 0.005)

    # Simulated displayed image, on the same scale as B_lbd.
    I_re = T * B_lbd

    return I_re, LCD, B_lbd, B

In [None]:
def avg_reconstruct_original_image(I_or: torch.Tensor) -> tuple:
    """Simulate local dimming where each LED takes the *mean* of its neighbourhood.

    Each LED's level is the mean of the channel-averaged target over the
    hexagonal neighbourhood ``hex_coords`` around that LED, divided by 4000.
    The LED map is spread with the flipped PSF. The LCD layer is set to
    ``I_or / backlight`` (clipped to [0, 1]), and the displayed image is
    simulated as ``transmittance * backlight``.

    Args:
        I_or: target image. It is indexed as ``I_or[0]`` and averaged over dim 0,
            so presumably (1, C, 1080, 1920) — TODO confirm.

    Returns:
        ``(I_re, LCD, B_lbd, B)``:
            I_re  -- simulated image ``T * B_lbd``. It is on B_lbd's scale
                     (like ``I_or``), not confined to [0, 1].
            LCD   -- ``I_or / B_lbd`` clipped to [0, 1].
            B_lbd -- backlight after PSF spreading, clamped to >= 0, with
                     exact zeros replaced by 1e-15.
            B     -- (1, 1, 1080, 1920) LED level map, set only at LED sites.
    """
    B = torch.zeros(1, 1, 1080, 1920).to(device)
    eps = 1e-15

    # The channel mean does not depend on the LED: compute it once, not per LED.
    gray = I_or[0].mean(dim=0)

    for row, col in b_mask_coords:
        # Neighbourhood offsets placed around this LED and clamped to the panel.
        rows = np.clip(hex_coords[:, 0] - 1 + row, 0, 1079)
        cols = np.clip(hex_coords[:, 1] - 1 + col, 0, 1919)
        # LED coordinates appear 1-based, hence the -1 when writing into B.
        B[0, 0, row - 1, col - 1] = gray[rows, cols].mean() / 4000

    B_lbd = fft_conv(180 * B, PSF_flipped, padding=512)
    # Luminance cannot be negative: clamp FFT round-off below zero. The old
    # clip(..., 0, B_lbd.max()) had a no-op upper bound and, if that max were
    # negative, would have set every pixel to it.
    B_lbd = torch.clamp(B_lbd, min=0.0)
    # Avoid division by zero where no light reaches the panel.
    B_lbd[B_lbd == 0] = eps
    LCD = torch.clip(I_or / B_lbd, 0.0, 1.0)
    T = get_transmittance(LCD, 0.005)

    # Simulated displayed image, on the same scale as B_lbd.
    I_re = T * B_lbd

    return I_re, LCD, B_lbd, B

In [None]:
# Shape of the transposed LED coordinate array; (2, number_of_LEDs) expected.
tuple(reversed(b_mask_coords.shape))

In [None]:
def gd_reconstruct_original_image(I_or: torch.Tensor) -> tuple:
    """Simulate a backlight with every LED at the same, global-peak level.

    ("gd" is presumably "global dimming" — TODO confirm.) All LEDs are set to
    ``I_or.max() / 4000``, taken over all channels and pixels. The LED map is
    spread with the flipped PSF. The LCD layer is set to ``I_or / backlight``
    (clipped to [0, 1]), and the displayed image is simulated as
    ``transmittance * backlight``.

    Args:
        I_or: target image, presumably (1, C, 1080, 1920) — TODO confirm.

    Returns:
        ``(I_re, LCD, B_lbd, B)``:
            I_re  -- simulated image ``T * B_lbd``. It is on B_lbd's scale
                     (like ``I_or``), not confined to [0, 1].
            LCD   -- ``I_or / B_lbd`` clipped to [0, 1].
            B_lbd -- backlight after PSF spreading, clamped to >= 0, with
                     exact zeros replaced by 1e-15.
            B     -- (1, 1, 1080, 1920) LED level map, set only at LED sites.
    """
    B = torch.zeros(1, 1, 1080, 1920).to(device)
    eps = 1e-15

    # Same level for every LED. The explicit tuple turns the 2 x N coordinate
    # array into (rows, cols); coordinates appear 1-based, hence the -1.
    B[0, 0][tuple((b_mask_coords - 1).T)] = I_or.max() / 4000

    B_lbd = fft_conv(180 * B, PSF_flipped, padding=512)
    # Luminance cannot be negative: clamp FFT round-off below zero. The old
    # clip(..., 0, B_lbd.max()) had a no-op upper bound and, if that max were
    # negative, would have set every pixel to it.
    B_lbd = torch.clamp(B_lbd, min=0.0)
    # Avoid division by zero where no light reaches the panel.
    B_lbd[B_lbd == 0] = eps
    LCD = torch.clip(I_or / B_lbd, 0.0, 1.0)
    T = get_transmittance(LCD, 0.005)

    # Simulated displayed image, on the same scale as B_lbd.
    I_re = T * B_lbd

    return I_re, LCD, B_lbd, B

In [None]:
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint must be a dict with keys ``'state_dict'``, ``'optimizer'``
    and ``'epoch'``. Both objects are updated in place and also returned.

    Args:
        checkpoint_fpath: path to the file written by ``torch.save``.
        model: module whose ``load_state_dict`` receives ``'state_dict'``.
        optimizer: optimizer whose ``load_state_dict`` receives ``'optimizer'``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'`` loads
            a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            torch.load's default.

    Returns:
        ``(model, optimizer, epoch)``.

    Note:
        ``torch.load`` unpickles data; only load checkpoints from trusted sources.
    """
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

In [None]:
from python_files.pu21_encoder import pu21_encoder
from skimage.metrics import peak_signal_noise_ratio as PSNR

def PU_PSNR(true, pred):
    """PSNR between two luminance images after PU21 encoding.

    Both tensors are moved to the CPU and encoded with
    ``pu21_encoder(0.005, 4000)``. The constructor arguments are presumably
    the min/max luminance range — TODO confirm. PSNR is then computed with
    ``data_range=256``.

    Returns:
        The PSNR value from ``skimage.metrics.peak_signal_noise_ratio``.
    """
    encode = pu21_encoder(0.005, 4000).forward
    true_pu = encode(true.cpu()).cpu().data.numpy()
    pred_pu = encode(pred.cpu()).cpu().data.numpy()
    return PSNR(true_pu, pred_pu, data_range=256)

In [None]:
def loss_function(I_, I, mask=None):
    """Smooth-L1 loss between two images, restricted to masked elements.

    Args:
        I_: prediction tensor. The loss is symmetric, so argument order does
            not affect the value.
        I: target tensor, same shape as ``I_``.
        mask: boolean tensor used to index ``I_`` and ``I``. Defaults to the
            global ``loss_function_mask``, so existing two-argument calls
            behave as before.

    Returns:
        Scalar tensor: mean Smooth-L1 loss (beta = 1.0) over the selected elements.
    """
    if mask is None:
        mask = loss_function_mask

    loss_f = torch.nn.SmoothL1Loss(beta=1.0)
    return loss_f(I_[mask], I[mask])

In [None]:
from pytorch_msssim import MS_SSIM

def PU_MSSSIM(true, pred):
    """MS-SSIM between two 3-channel luminance images after PU21 encoding.

    Both tensors are moved to the CPU and encoded with
    ``pu21_encoder(0.005, 4000)``. They are then compared with
    ``MS_SSIM(data_range=256, size_average=True, channel=3)``.

    Returns:
        The averaged MS-SSIM tensor returned by ``pytorch_msssim.MS_SSIM``.
    """
    encode = pu21_encoder(0.005, 4000).forward
    msssim = MS_SSIM(data_range=256, size_average=True, channel=3)
    true_pu = encode(true.cpu()).cpu().data
    pred_pu = encode(pred.cpu()).cpu().data
    return msssim(true_pu, pred_pu)

In [None]:
from skimage import io
import time

def test_model(loss_func):
    """Evaluate the max-rule reconstruction on ``test_dataloader`` and save outputs.

    For each image the target luminance is ``max_luminance * inputs``, with
    ``max_luminance = 4000``. The target is reconstructed with
    ``max_reconstruct_original_image`` and scored with ``loss_func``,
    PU-MS-SSIM and PU-PSNR. A per-image line is appended to
    ``test_evaluation.txt``. Three files are written per image, where ``name``
    is the input file name:

    * ``<test_lcd_dir>/<name>.lcd.dld.tif``          -- LCD layer (8-bit via ToPILImage)
    * ``<test_visual_dir>/<name>.out_uint8.dld.tif`` -- 8-bit preview of the reconstruction
    * ``<test_dir>/<name>.out_sim.dld.tif``          -- full-precision reconstruction

    Dataset-wide means are printed at the end.

    Args:
        loss_func: callable ``loss_func(target, prediction)`` returning a
            scalar tensor.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches. Without this
            check, the means would divide by zero.
    """
    dataloader = test_dataloader
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("test_dataloader is empty; nothing to evaluate")

    max_luminance = 4000
    pil_trans = transforms.ToPILImage()

    running_loss = 0.
    running_pu_msssim = 0.
    running_pu_psnr = 0.
    running_perf_time = 0.
    running_backlight = 0.

    for inputs, pname in dataloader:
        inputs = inputs.to(device)
        target = max_luminance * inputs

        # CUDA launches are asynchronous, so synchronize around the timed call
        # to measure the real work. torch.cuda.synchronize() raises when CUDA
        # is unavailable, hence the guard.
        if inputs.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        constructed_preds, LCD, B_lbd, preds = max_reconstruct_original_image(target)

        if inputs.is_cuda:
            torch.cuda.synchronize()
        running_perf_time += time.perf_counter() - start_time

        # Loss and perceptual metrics against the target luminance.
        loss_value = loss_func(target, constructed_preds)
        pu_msssim = PU_MSSSIM(target, constructed_preds)
        pu_psnr = PU_PSNR(target, constructed_preds)
        # Sum of the LED levels (preds is the LED map B) at the LED sites.
        backlight_sum = (b_mask * preds).sum().item()

        running_loss += loss_value.item()
        running_pu_msssim += pu_msssim.item()
        running_pu_psnr += pu_psnr.item()
        running_backlight += backlight_sum

        # File name without directories. Unlike the old rfind('/'),
        # os.path.basename also handles the Windows separator.
        name = os.path.basename(pname[0])

        output_str = "{} loss: {:.2f}, PU_MSSSIM: {:.2f}, PU_PSNR {:.1f}, BACKLIGH_SUM: {:.2f}, MIN_VALUE_RESULT {:.6f}, MAX_VALUE_RESULT {:.2f}, MIN_VALUE_REAL {:.6f}, MAX_VALUE_REAL {:.2f}" \
                  .format(name, loss_value.item(), pu_msssim.item(), pu_psnr.item(), backlight_sum,
                          B_lbd.min().item(), B_lbd.max().item(), inputs.min().item(), inputs.max().item())
        with open("test_evaluation.txt", "a+") as f:
            f.write(output_str + '\n')

        # LCD layer as an image (ToPILImage converts float tensors to 8 bit).
        pil_lcd = pil_trans(LCD[0])
        pil_lcd.save(os.path.join(test_lcd_dir, name + '.lcd.dld.tif'), compression=None, quality=100)

        # 8-bit preview. constructed_preds is on the luminance scale of
        # `target` (up to ~max_luminance), so normalize and clip to [0, 1]
        # first; multiplying the raw values by 255 overflowed uint8.
        preview = torch.clamp(constructed_preds / max_luminance, 0.0, 1.0) * 255
        tiff.imwrite(os.path.join(test_visual_dir, name + '.out_uint8.dld.tif'),
                     preview.detach().cpu().numpy().astype(np.uint8))

        # Full-precision reconstruction (imwrite replaces the deprecated imsave).
        tiff.imwrite(os.path.join(test_dir, name + '.out_sim.dld.tif'),
                     constructed_preds.detach().cpu().numpy())

    # Means over the test set.
    epoch_loss = running_loss / n_batches
    epoch_pu_msssim = running_pu_msssim / n_batches
    epoch_pu_psnr = running_pu_psnr / n_batches
    epoch_perf_time = running_perf_time / n_batches
    epoch_backlight = running_backlight / n_batches

    print('{}. {} Loss: {:.4f}, PU_MSSSIM: {:.4f}, PU_PSNR: {:.2f}, PERF_TIME: {:.4f}, BACKLIGHT_MEAN_SUM: {:.2f}'.format(
        0, 'test', epoch_loss, epoch_pu_msssim, epoch_pu_psnr, epoch_perf_time, epoch_backlight) + 70*' ')

In [None]:
test_model(loss_func=loss_function)

# Reference results from earlier runs (means over the test set), one per
# backlight strategy; test_model uses max_reconstruct_original_image.
# AVG: PU_MSSSIM 0.9172, PU_PSNR 21.51, PERF_TIME 4.5909, BACKLIGHT_MEAN_SUM 1.69
# MAX: PU_MSSSIM 0.9813, PU_PSNR 27.74, PERF_TIME 4.6832, BACKLIGHT_MEAN_SUM 2.97
# GD:  PU_MSSSIM 0.9584, PU_PSNR 24.97, PERF_TIME 0.1358, BACKLIGHT_MEAN_SUM 9.37

In [None]:
import shutil

# Remove the output folders that are not archived; test_dir and
# test_visual_dir are kept for the zip archives. The existence checks keep
# the cell re-runnable, because shutil.rmtree raises on a missing directory.
if os.path.isdir(test_b_dir):
    shutil.rmtree(test_b_dir)
if os.path.isdir(test_lcd_dir):
    shutil.rmtree(test_lcd_dir)

In [None]:
#shutil.make_archive('tests', 'zip', 'test')

In [None]:
# Zip the float reconstructions into <test_dir>.zip. Using test_dir as the
# source folder as well keeps the two from drifting apart if the path changes.
shutil.make_archive(test_dir, 'zip', test_dir)

In [None]:
# Zip the 8-bit previews into <test_visual_dir>.zip. Using test_visual_dir as
# the source folder as well keeps the two from drifting apart if the path changes.
shutil.make_archive(test_visual_dir, 'zip', test_visual_dir)

In [None]:
#shutil.make_archive('tests_b_reconstruction', 'zip', 'test_b_reconstruction')