# In [None]:
from __future__ import absolute_import, division, print_function
# https://github.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/blob/main/train.py

import PIL.Image as pil
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn as nn
import torch.optim as optim

# from pytorch_msssim.pytorch_msssim import msssim
# from pytorch_ssim.pytorch_ssim import ssim
# import PerceptualSimilarity.lpips.lpips as lpips

from tqdm import tqdm

# Target size for every image loaded by the dataset. Note that PIL's resize
# takes (width, height). Earlier experiments used 768x512 and 512x384
# (width x height).
feed_height, feed_width = 576, 768

# Mini-batch size shared by the train and test loaders (previously 40).
batch_size = 4


class bokehDataset(Dataset):
    """Dataset of bokeh / original image pairs plus depth and lens-blur images.

    Each CSV row holds a bokeh image path in column 0 and an original image
    path in column 1. For a row, the dataset loads:

    * the bokeh and original images, converted to RGB;
    * a depth image, whose path is derived from the original path;
    * nine lens-blur images, whose paths are derived from the bokeh path.

    Every image is resized to ``(feed_width, feed_height)`` with Lanczos
    resampling, then passed through ``transform`` separately.

    Args:
        csv_file: CSV index read with ``pandas.read_csv`` (default settings,
            so the first row is treated as a header).
        root_dir: string prefix prepended to each CSV entry after that
            entry's first character has been dropped.
        transform: callable applied to every PIL image. It must return
            tensors so they can be stacked. If ``None``, torchvision's
            ``ToTensor`` is used; previously ``None`` crashed with
            ``UnboundLocalError``.

    NOTE: ``transform`` is called once per image. A transform with random
    parameters (e.g. a flip with ``p < 1``) would therefore not be applied
    consistently across the images of one sample.
    """

    # Directory names substituted for 'bokeh' in the bokeh-image path to
    # locate the lens-blur images, in the order they are stacked after the
    # original image.
    # NOTE(review): 'lens_blur_10_3_3' is listed twice (an "old" and a current
    # entry point at the same folder), so the same image is stacked twice.
    # Confirm whether another folder was intended; removing the entry would
    # change the number of stacked images.
    _BLUR_DIRS = (
        'lens_blur_6_3_2',
        'lens_blur_10_3_3',   # "old" entry
        'lens_blur_10_3_3',
        'lens_blur_15_3_1',
        'lens_blur_20_3_3',   # "old" entry
        'lens_blur_30_3_5',
        'lens_blur_50_2_4',   # "old" entry
        'lens_blur_65_3_3',
        'lens_blur_100_3_3',
    )

    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        return len(self.data)

    def _full_path(self, idx, col):
        """Return ``root_dir`` + the CSV entry at (idx, col) minus its first char."""
        return self.root_dir + self.data.iloc[idx, col][1:]

    @staticmethod
    def _depth_path(original_path):
        """Derive the depth-image path from an original-image path.

        ``str.replace`` rewrites *every* 'original' and 'jpg' substring,
        including any inside directory names.
        """
        return original_path.replace('original', 'depth').replace('jpg', 'png')

    @staticmethod
    def _load_resized(path, mode=None):
        """Open ``path``, optionally convert it to ``mode``, and Lanczos-resize it.

        The file is opened in a ``with`` block so its handle is closed. The
        returned image is the new image produced by ``resize``.
        """
        with pil.open(path) as img:
            if mode is not None:
                img = img.convert(mode)
            return img.resize((feed_width, feed_height), pil.LANCZOS)

    def __getitem__(self, idx):
        """Return ``(bokeh, original, depth, stacked_10)`` for row ``idx``.

        ``stacked_10`` is the transformed original image followed by the nine
        transformed lens-blur images, stacked along a new dimension 0.
        """
        bok_path = self._full_path(idx, 0)
        org_path = self._full_path(idx, 1)

        bok = self._load_resized(bok_path, 'RGB')
        org = self._load_resized(org_path, 'RGB')
        # The depth image is not converted, so it keeps the file's own mode.
        # NOTE(review): what the transform produces for it depends on that
        # mode — confirm it is the expected one.
        depth = self._load_resized(self._depth_path(org_path))

        # Each blur image sits at the bokeh path with every 'bokeh' replaced
        # by its directory name. Duplicate entries are read from disk once.
        cache = {}
        blurs = []
        for blur_dir in self._BLUR_DIRS:
            blur_path = bok_path.replace('bokeh', blur_dir)
            if blur_path not in cache:
                cache[blur_path] = self._load_resized(blur_path, 'RGB')
            blurs.append(cache[blur_path])

        # Without a transform the images would stay PIL objects and could not
        # be stacked, so fall back to a plain tensor conversion.
        transform = self.transform if self.transform is not None else transforms.ToTensor()

        # Fixed call order (bokeh, original, depth, then blurs) keeps any
        # RNG consumption by the transform deterministic per sample.
        bok_t = transform(bok)
        org_t = transform(org)
        depth_t = transform(depth)
        blur_ts = [transform(b) for b in blurs]

        stacked_10 = torch.stack([org_t, *blur_ts], dim=0)
        return (bok_t, org_t, depth_t, stacked_10)

def _to_tensor_after(*steps):
    """Compose the given PIL-level steps, followed by a final ToTensor."""
    return transforms.Compose([*steps, transforms.ToTensor()])


# Plain conversion: PIL image -> tensor.
transform1 = _to_tensor_after()

# p=1 makes each flip deterministic, so every image of a sample is flipped.
transform2 = _to_tensor_after(transforms.RandomHorizontalFlip(p=1))

transform3 = _to_tensor_after(transforms.RandomVerticalFlip(p=1))


_train_csv = '../MegaDepth/Bokeh_Data/train.csv'
_test_csv = '../MegaDepth/Bokeh_Data/test.csv'

# Three views of the training split: unchanged, horizontally flipped and
# vertically flipped. Concatenating them triples the epoch length.
trainset1, trainset2, trainset3 = (
    bokehDataset(csv_file=_train_csv, root_dir='.', transform=tf)
    for tf in (transform1, transform2, transform3)
)

trainloader = DataLoader(
    torch.utils.data.ConcatDataset([trainset1, trainset2, trainset3]),
    batch_size=batch_size,
    shuffle=True,
    num_workers=7,
)
# Lighter alternative without flip augmentation: put only trainset1 in the
# ConcatDataset and use num_workers=4.

testset = bokehDataset(csv_file=_test_csv, root_dir='.', transform=transform1)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=7,
)

# blur_25_transform = T.GaussianBlur((25, 25), sigma=1)
# blur_45_transform = T.GaussianBlur((45, 45), sigma=3)
# blur_75_transform = T.GaussianBlur((75, 75), sigma=7)