In [None]:
import numpy as np
import time
import tqdm
import random
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import torch.nn.functional as F
import tifffile as tiff
from torchvision import models, transforms
import glob
import os
import seaborn as sb
import copy
import pandas as pd
import cv2

#here we have convolution theorem(using fast fourier transform)
from python_files.fft_conv import fft_conv

from IPython.display import clear_output
import matplotlib.pyplot as plt
#%config InlineBackend.figure_format = 'svg'

# A fresh seed is drawn on every run; it is shown as the cell output so that
# a run can be repeated by hard-coding the displayed value here.
seed = int(1000 * random.random())

# Seed every RNG used by the notebook with that value.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Deterministic cuDNN kernels: reproducible results at some GPU speed cost.
torch.backends.cudnn.deterministic = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seed

In [None]:
def save_ckp(state, index, is_best, directory=None):
    """Save a training checkpoint as ``<directory>/<index>.pt``.

    Args:
        state: any object accepted by ``torch.save`` (train_model passes a dict
            with the epoch and the model/optimizer/scheduler state dicts).
        index: used as the checkpoint file stem.
        is_best: when True, also copy the checkpoint to ``best_model.pt`` in
            the same directory.
        directory: target directory; defaults to the notebook global ``curr_dir``.
    """
    # shutil is not among the notebook's imports; without this the
    # is_best=True branch raised NameError.
    import shutil

    if directory is None:
        directory = curr_dir
    f_path = os.path.join(directory, f"{index}.pt")
    torch.save(state, f_path)
    if is_best:
        shutil.copyfile(f_path, os.path.join(directory, 'best_model.pt'))

In [None]:
import random
# Draw one random integer intensity per validation image (102 draws;
# presumably the source of the fixed 102-entry list in ValDataset -- confirm).
# int(u * 3999 + 1) with u ~ U[0, 1) yields integers in [1, 3999], so no
# clipping to the 4000-nit display maximum is ever needed.
intensity_lst = [int(random.random() * 3999 + 1) for _ in range(102)]

In [None]:
# Emit the intensities as "a, b, c, " on one line (no trailing newline).
print("".join(f"{value}, " for value in intensity_lst), end="")

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Training dataset of TIFF images with random augmentation.

    Each sample goes through the following steps:
    1. Scaled by 1/(2**16 - 1).
    2. Randomly flipped horizontally and vertically (each with p = 0.5).
    3. Raised to a random gamma exponent in [1, 2).
    4. Divided by the maximum of its channel mean.
    5. Multiplied by a random integer intensity in [1, max_nits - 1].
    6. Clipped to [0, max_nits] and divided by max_nits, so values end in [0, 1].

    Args:
        image_paths: sequence of TIFF file paths.
        train: accepted for signature compatibility with ``ValDataset``; unused.
        max_nits: maximum display intensity used for clipping/normalisation.
    """

    def __init__(self, image_paths, train=True, max_nits=4000):
        self.image_paths = image_paths
        self.max_nits = max_nits

    def _sample_intensity(self):
        """Draw a random integer intensity in ``[1, max_nits - 1]``."""
        # int(u * (max_nits - 1) + 1) with u ~ U[0, 1); for max_nits = 4000 this
        # is exactly int(random.random()*3999 + 1), i.e. values 1..3999.
        return int(random.random() * (self.max_nits - 1) + 1)

    @staticmethod
    def _normalize_peak(image):
        """Divide a (C, H, W) tensor by the max of its channel mean; all-zero images are left as-is."""
        peak = image.mean(dim=0).max()
        return image / (peak if peak > 0 else 1)

    def transform(self, image):
        """Turn an (H, W[, C]) numpy image into an augmented (C, H, W) float tensor in [0, 1]."""
        # NOTE(review): the 2**16 - 1 divisor assumes 16-bit TIFFs -- confirm.
        image = image.astype('float32') / (2**16 - 1)
        image = TF.to_tensor(image).to(torch.float32)

        # Random horizontal and vertical flips, each with probability 0.5.
        if random.random() > 0.5:
            image = TF.hflip(image)
        if random.random() > 0.5:
            image = TF.vflip(image)

        # Random gamma exponent in [1, 2).
        gamma = random.random() + 1
        image = image ** gamma

        image = self._normalize_peak(image)

        # Random intensity calibration, then clip to the display maximum.
        intensity_range = self._sample_intensity()
        return torch.clip(intensity_range * image, 0.0, self.max_nits) / self.max_nits

    def __getitem__(self, index):
        image = tiff.imread(self.image_paths[index])
        return self.transform(image)

    def __len__(self):
        return len(self.image_paths)

In [None]:
class ValDataset(torch.utils.data.Dataset):
    """Validation dataset with a fixed, per-index intensity calibration.

    No random augmentation is applied. Image ``index`` is always scaled by
    ``intensity_list[index]``, so validation losses are comparable across
    epochs. The mapping depends on ``image_paths`` being in a stable order
    (e.g. sorted).

    Args:
        image_paths: sequence of TIFF file paths.
        train: accepted for signature compatibility; unused.
        intensity_list: per-image intensities; defaults to the fixed
            102-entry ``DEFAULT_INTENSITIES``.
        max_nits: maximum display intensity used for clipping/normalisation.
    """

    DEFAULT_INTENSITIES = (
        3093, 492, 2085, 1804, 2985, 2407, 2966, 1376, 596, 3427,
        3177, 1176, 1360, 2935, 1325, 1401, 1251, 1157, 394, 2190,
        2943, 179, 2746, 2603, 3724, 3617, 3352, 986, 2739, 958,
        3428, 231, 2969, 3593, 2770, 3186, 1586, 3807, 1692, 1193,
        1966, 2743, 3579, 1360, 2684, 3768, 1705, 2697, 921, 905,
        600, 1199, 1871, 2931, 1694, 1442, 2963, 951, 1901, 3628,
        1140, 2205, 3405, 2036, 2029, 2135, 587, 178, 305, 2707,
        173, 942, 854, 2869, 2001, 890, 1730, 1181, 1218, 1973,
        2229, 3709, 3489, 1296, 255, 1720, 761, 610, 2580, 2776,
        2813, 1066, 992, 353, 1409, 3146, 964, 1918, 786, 3661,
        1491, 1367,
    )

    def __init__(self, image_paths, train=True, intensity_list=None, max_nits=4000):
        self.image_paths = image_paths
        # Per-instance list copy so callers' sequences are never mutated.
        self.intensity_list = list(
            self.DEFAULT_INTENSITIES if intensity_list is None else intensity_list)
        # maximum display intensity
        self.max_nits = max_nits

    def transform(self, image, intensity_range):
        """Turn an (H, W[, C]) numpy image into a (C, H, W) tensor in [0, 1] at the given intensity."""
        # NOTE(review): the 2**16 - 1 divisor assumes 16-bit TIFFs -- confirm.
        image = image.astype('float32') / (2**16 - 1)
        image = TF.to_tensor(image).to(torch.float32)

        # Normalise by the max of the channel mean; all-zero images stay as-is.
        division = image.mean(dim=0).max()
        division = division if division > 0 else 1
        image = image / division

        return torch.clip(intensity_range * image, 0.0, self.max_nits) / self.max_nits

    def __getitem__(self, index):
        # Fail with a clear message (and before touching the file) when the
        # validation set is larger than the calibration list.
        if index >= len(self.intensity_list):
            raise IndexError(
                f"ValDataset has only {len(self.intensity_list)} calibration "
                f"intensities; got index {index}")
        image = tiff.imread(self.image_paths[index])
        return self.transform(image, self.intensity_list[index])

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Sorted paths give a deterministic order; ValDataset relies on it to pair
# each validation image with its fixed intensity. (A shuffle before sorting
# would be discarded, and the train DataLoader shuffles anyway.)
train_data = sorted(glob.glob('train/*.tiff'))
val_data = sorted(glob.glob('val/*.tiff'))
assert train_data, "no training images found matching train/*.tiff"
assert val_data, "no validation images found matching val/*.tiff"

train_dataset = Dataset(train_data)
val_dataset = ValDataset(val_data)
assert len(val_dataset) <= len(val_dataset.intensity_list), \
    "more validation images than fixed calibration intensities"

batch_size = 1
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
)

# Validation order does not change the epoch-mean loss; keep it fixed.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True
)

In [None]:
tuple(map(len, (train_dataloader, train_dataset)))  # (batches, samples)

In [None]:
tuple(map(len, (val_dataloader, val_dataset)))  # (batches, samples)

In [None]:
def show_image(original, img_h=1080, img_w=1920, std_mean=False, mean=None, std=None):
    """Display one image tensor with matplotlib.

    Args:
        original: 3-D tensor laid out channel-first (permuted to H, W, C for
            imshow); detached and moved to the CPU first, so CUDA or
            grad-tracking tensors also work.
        img_h, img_w: unused; kept for backward compatibility.
        std_mean: if True, undo a normalisation via ``image * std + mean``.
        mean, std: statistics used when ``std_mean`` is True (broadcast
            against the H, W, C array).

    Raises:
        ValueError: if ``std_mean`` is True but ``mean`` or ``std`` is missing
            (previously this hit undefined globals and raised NameError).
    """
    if std_mean and (mean is None or std is None):
        raise ValueError("show_image(std_mean=True) requires both `mean` and `std`")

    fig, ax = plt.subplots(1, 1, figsize=(14, 14), constrained_layout=True)
    image = original.detach().cpu().permute(1, 2, 0).numpy()
    if std_mean:
        image = image * std + mean
    ax.imshow(image)
    ax.set_title('original image')
    ax.grid(False)

In [None]:
# Pull one training batch and sanity-check its shape and value range.
inputs = next(iter(train_dataloader))

print("InputsTensor", inputs.shape)
print(f"MinValue {inputs.min().item():.4f} MaxValue: {inputs.max().item():.4f} MeanValue: {inputs.mean().item():.2f}")
show_image(inputs[0])

In [None]:
import io
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sb

# Point-spread function (presumably of a single LED) stored as a 1025 x 1025 CSV grid.
PSF_h = 1025
PSF_w = 1025

# Passing the path lets np.loadtxt open and close the file itself, so no
# file handle is leaked.
PSF = torch.tensor(np.loadtxt("zerman_led_psf.csv", delimiter=",")).reshape(1, 1, PSF_h, PSF_w).to(device)
print(f"Mean of PSF:{PSF.mean()}")

# Distribution of PSF values, zoomed to the low range.
sb.displot(PSF.cpu().numpy().flatten())
plt.xlim(-0.05, 0.2);

In [None]:
import json
 
# LED grid positions (2202 LEDs according to the file name -- confirm).
# The -1 below converts the stored 1-based (row, col) coordinates to 0-based.
with open('Zermans_LEDS_POSITION_2202.json') as f:
    data = json.load(f)
b_mask_coords = np.array(data['grid_coords'])

# Binary 1080 x 1920 mask with 1.0 at every LED position, set in one
# vectorised indexing call instead of one GPU write per LED.
b_mask = torch.zeros(1, 1, 1080, 1920).to(device)
if len(b_mask_coords):
    rows = torch.as_tensor(b_mask_coords[:, 0] - 1, device=b_mask.device)
    cols = torch.as_tensor(b_mask_coords[:, 1] - 1, device=b_mask.device)
    b_mask[0, 0, rows, cols] = 1.0

In [None]:
#here we have all leds on
# Backlight drive with every LED fully on (1.0), restricted to LED positions.
b_psf_full = torch.ones(1, 1, 1080, 1920, dtype=torch.float64, device=device)
b_psf_proj = b_psf_full * b_mask

# Sanity check: mean * n_pixels / 2202 and the plain sum (expected ~1.0 and ~2202).
b_psf_proj.mean() * (1920 * 1080) / 2202, b_psf_proj.sum()

In [None]:
#flip PSF for 180 degrees
PSF_flipped = torch.flip(PSF, dims=[-1, -2])
#simulation of luminance with all leds on
#PSF_flipped = torch.nn.functional.pad(PSF_flipped, (0,1,0,1), mode='constant')
b_psf_output = fft_conv(b_psf_proj, PSF_flipped, padding=512)

#here we use 24 because b_psf_output have 24.5(beta)
#now 18.5
loss_function_mask = b_psf_output > 0.0
loss_function_mask = loss_function_mask.expand(batch_size, 3, 1080, 1920)

In [None]:
from python_files.utils import get_transmittance

In [None]:
def reconstruct_original_image(B_: torch.Tensor, I_or: torch.Tensor) -> torch.Tensor:
    """Simulate the displayed image for a predicted LED backlight.

    Steps:
    1. Keep only LED positions (multiply by the global ``b_mask``).
    2. Scale by 180 and spread light with ``PSF_flipped`` via ``fft_conv``,
       giving the backlight luminance ``B_lbd``.
    3. Compute LCD values ``I_or / B_lbd`` clipped to [0, 1].
    4. Map them through ``get_transmittance(LCD, 0.005)`` and multiply back by ``B_lbd``.

    Args:
        B_: backlight prediction, broadcastable against ``b_mask``.
        I_or: target image in the same luminance units as ``B_lbd``
            (train_model passes ``4000 * inputs``).

    Returns:
        Reconstructed image ``transmittance * B_lbd``.

    Relies on notebook globals ``b_mask``, ``PSF_flipped``, ``fft_conv`` and
    ``get_transmittance``.
    """
    eps = 1e-15  # substitutes exact zeros so the division below stays finite

    B = b_mask * B_
    # NOTE(review): 180 looks like a peak LED luminance scale -- confirm.
    B_lbd = fft_conv(180 * B, PSF_flipped, padding=512)
    # Only the lower bound matters. Clipping to B_lbd's own max is a no-op,
    # and B_lbd.max().item() would force a device->host sync every call.
    B_lbd = torch.clamp(B_lbd, min=0.0)
    # Out-of-place zero replacement (same values/gradients as the in-place fill).
    B_lbd = B_lbd.masked_fill(B_lbd == 0, eps)

    LCD = torch.clip(I_or / B_lbd, 0.0, 1.0)
    transmittance = get_transmittance(LCD, 0.005)

    return transmittance * B_lbd

In [None]:
def loss_function(I_, I, preds, mask=None, n_leds=2202, power_weight=90):
    """Smooth-L1 image loss plus a backlight-power penalty.

    Args:
        I_, I: the two images compared by Smooth-L1 (beta=1.0, mean reduction).
            train_model passes the target (``4000 * inputs``) as ``I_`` and the
            reconstruction as ``I``; Smooth-L1 is symmetric, so order is irrelevant.
        preds: backlight predictions; only values under ``mask`` contribute.
        mask: LED position mask; defaults to the notebook global ``b_mask``.
        n_leds: divisor turning the masked sum into a per-LED mean
            (2202 matches the LED-position file name).
        power_weight: weight of the mean-backlight penalty term.

    Returns:
        Scalar loss tensor: ``smooth_l1(I_, I) + power_weight * sum(mask * preds) / n_leds``.
    """
    if mask is None:
        mask = b_mask

    reconstruction_loss = F.smooth_l1_loss(I_, I, beta=1.0)
    # Penalise total LED drive, encouraging low-power backlights.
    mean_backlight = (mask * preds).sum() / n_leds

    return reconstruction_loss + mean_backlight * power_weight

In [None]:
#make dirs for model saving
save_name = f"{seed}_Zerman_1.0"
curr_dir = f'seed' + save_name 

if not os.path.exists(curr_dir):
    os.makedirs(curr_dir)
    
if not os.path.exists('logs'):
    os.makedirs('logs')

In [None]:
def train_model(model, loss_func, optimizer, scheduler, num_epochs):
    """Run the train/val loop and return the weights with the lowest val loss.

    Each epoch runs a training pass over the global ``train_dataloader``,
    then a validation pass over ``val_dataloader``. Each pass:
    - Converts predictions to an image with ``reconstruct_original_image``.
    - Scores them with ``loss_func(4000 * inputs, reconstruction, preds)``.

    After each validation pass:
    - A checkpoint is written with ``save_ckp``.
    - The loss curves are redrawn and saved to ``logs/<save_name>.png``.
    - A summary line is appended to ``logs/<save_name>.txt``.

    The scheduler steps once per epoch.

    The best epoch is chosen by validation loss. PU-MSSSIM / PU-PSNR are
    constant 1.0 placeholders and cannot discriminate between epochs;
    selecting on them previously froze the returned weights at epoch 0.
    Ctrl-C stops training early and still returns the best weights so far.

    Returns:
        (model, train_loss_history, val_loss_history, train_ms_ssim_history,
         val_ms_ssim_history, train_psnr_history, val_psnr_history)
    """
    # State lives outside the try so the KeyboardInterrupt path can always use it.
    best_model_wts = copy.deepcopy(model.state_dict())
    # inf instead of a finite guess: losses on 4000-scaled images can be large.
    best_loss = float('inf')
    psnr = 0.0
    msssim = 0.0

    train_loss_history = []
    train_ms_ssim_history = []
    train_psnr_history = []

    val_loss_history = []
    val_ms_ssim_history = []
    val_psnr_history = []

    try:
        for epoch in range(num_epochs):
            for phase in ['train', 'val']:
                if phase == 'train':
                    dataloader = train_dataloader
                    model.train()
                else:
                    dataloader = val_dataloader
                    model.eval()
                running_loss = 0.
                running_pu_msssim = 0.
                running_pu_psnr = 0.

                for i, inputs in enumerate(dataloader):
                    inputs = inputs.to(device)
                    optimizer.zero_grad()

                    # Gradients are tracked only in the training phase.
                    with torch.set_grad_enabled(phase == 'train'):
                        preds = model(inputs)
                        # Dataset outputs lie in [0, 1]; scale by the 4000-nit maximum.
                        constructed_preds = reconstruct_original_image(preds, 4000*inputs)
                        loss_value = loss_func(4000*inputs, constructed_preds, preds)
                        # PU-MSSSIM / PU-PSNR are too slow per batch; log constant placeholders.
                        pu_msssim = torch.ones(1)
                        pu_psnr = torch.ones(1)

                        if phase == 'train':
                            loss_value.backward()
                            optimizer.step()

                    running_loss += loss_value.item()
                    running_pu_msssim += pu_msssim.item()
                    running_pu_psnr += pu_psnr.item()
                    print("epoch: {}/{} nested {}/{} {}  - loss: {:.4f}, PU_MSSSIM: {:.2f}, PU_PSNR {:.1f}"
                          .format(epoch, num_epochs, (i+1), len(dataloader),
                                  phase, running_loss/(i+1), running_pu_msssim/(i+1), running_pu_psnr/(i+1)), end='\r')

                # Learning-rate schedule advances once per epoch.
                if phase == 'train':
                    scheduler.step()

                # Per-phase means over batches.
                n_batches = len(dataloader)
                epoch_loss = running_loss / n_batches
                epoch_pu_msssim = running_pu_msssim / n_batches
                epoch_pu_psnr = running_pu_psnr / n_batches

                if phase == 'train':
                    train_loss_history.append(epoch_loss)
                    train_ms_ssim_history.append(epoch_pu_msssim)
                    train_psnr_history.append(epoch_pu_psnr)
                else:
                    # Keep the weights with the lowest validation loss.
                    if epoch_loss < best_loss:
                        best_loss = epoch_loss
                        msssim = epoch_pu_msssim
                        psnr = epoch_pu_psnr
                        best_model_wts = copy.deepcopy(model.state_dict())

                    val_loss_history.append(epoch_loss)
                    val_ms_ssim_history.append(epoch_pu_msssim)
                    val_psnr_history.append(epoch_pu_psnr)

                    checkpoint = {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    }
                    save_ckp(checkpoint, epoch, is_best=False)

                    # Live loss curves, also saved under logs/.
                    clear_output(True)
                    plt.figure(figsize=(12, 5))
                    plt.plot(train_loss_history, label='train_loss')
                    plt.plot(val_loss_history, label='val_loss')
                    plt.legend()
                    plt.grid()
                    plt.savefig(f'logs/{save_name}.png', bbox_inches='tight', dpi=200)
                    plt.show()

                output_str = '{}. {} Loss: {:.4f}, PU_MSSSIM: {:.2f}, PU_PSNR: {:.2f}'.format(
                    epoch, phase, epoch_loss, epoch_pu_msssim, epoch_pu_psnr)
                print(output_str + 70*' ', flush=True)

                with open(f"logs/{save_name}.txt", "a+") as log_file:
                    log_file.write(output_str + '\n')
    except KeyboardInterrupt:
        # Deliberate early stop: fall through and return the best weights so far.
        print("\nTraining interrupted by user.")

    model.load_state_dict(best_model_wts)
    print(f"Returning model saved with best val loss: {best_loss} PU_MSSSIM: {msssim} PSNR: {psnr}")
    return model, train_loss_history, val_loss_history, \
           train_ms_ssim_history, val_ms_ssim_history, \
           train_psnr_history, val_psnr_history

In [None]:
from python_files.HDRnew import HDRnet

model = HDRnet()
model.to(device)

optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
# Cosine decay over 510 scheduler steps. train_model steps once per epoch,
# so T_max should match the num_epochs passed to train_model.
# `verbose` is not passed: it is deprecated since PyTorch 2.2 (passing it warns),
# and last_epoch=-1 / eta_min=0 are the defaults.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=510, eta_min=0)

In [None]:
# Only parameters the optimizer will update are counted.
pytorch_total_params = sum(map(torch.numel, filter(lambda param: param.requires_grad, model.parameters())))
print(f"Total trainable params:{pytorch_total_params}")

In [None]:
(model,
 train_loss_history, val_loss_history,
 train_ms_ssim_history, val_ms_ssim_history,
 train_psnr_history, val_psnr_history) = train_model(
    model, loss_function, optimizer, scheduler, num_epochs=510
)