In [1]:
import cv2
import numpy as np
from skimage import color
import os
import numpy as np
import re
import torch
from torch import nn
from torchvision import transforms
from sklearn.linear_model import LinearRegression
from torchvision.io import read_image
from skimage.io import imread
import threading
import time
import pandas as pd
from datetime import datetime

# Evaluation Functions

In [2]:
def calc_deltaE(source, target, color_chart_area):
  source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
  target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
  source = color.rgb2lab(source)
  target = color.rgb2lab(target)
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  delta_e = np.sqrt(np.sum(np.power(source - target, 2), 1))
  return sum(delta_e) / (np.shape(delta_e)[0] - color_chart_area)

def calc_deltaE2000(source, target, color_chart_area):
  source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
  target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
  source = color.rgb2lab(source)
  target = color.rgb2lab(target)
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  deltaE00 = deltaE2000(source, target)
  return sum(deltaE00) / (np.shape(deltaE00)[0] - color_chart_area)


def deltaE2000(Labstd, Labsample):
  kl = 1
  kc = 1
  kh = 1
  Lstd = np.transpose(Labstd[:, 0])
  astd = np.transpose(Labstd[:, 1])
  bstd = np.transpose(Labstd[:, 2])
  Cabstd = np.sqrt(np.power(astd, 2) + np.power(bstd, 2))
  Lsample = np.transpose(Labsample[:, 0])
  asample = np.transpose(Labsample[:, 1])
  bsample = np.transpose(Labsample[:, 2])
  Cabsample = np.sqrt(np.power(asample, 2) + np.power(bsample, 2))
  Cabarithmean = (Cabstd + Cabsample) / 2
  G = 0.5 * (1 - np.sqrt((np.power(Cabarithmean, 7)) / (np.power(
    Cabarithmean, 7) + np.power(25, 7))))
  apstd = (1 + G) * astd
  apsample = (1 + G) * asample
  Cpsample = np.sqrt(np.power(apsample, 2) + np.power(bsample, 2))
  Cpstd = np.sqrt(np.power(apstd, 2) + np.power(bstd, 2))
  Cpprod = (Cpsample * Cpstd)
  zcidx = np.argwhere(Cpprod == 0)
  hpstd = np.arctan2(bstd, apstd)
  hpstd[np.argwhere((np.abs(apstd) + np.abs(bstd)) == 0)] = 0
  hpsample = np.arctan2(bsample, apsample)
  hpsample = hpsample + 2 * np.pi * (hpsample < 0)
  hpsample[np.argwhere((np.abs(apsample) + np.abs(bsample)) == 0)] = 0
  dL = (Lsample - Lstd)
  dC = (Cpsample - Cpstd)
  dhp = (hpsample - hpstd)
  dhp = dhp - 2 * np.pi * (dhp > np.pi)
  dhp = dhp + 2 * np.pi * (dhp < (-np.pi))
  dhp[zcidx] = 0
  dH = 2 * np.sqrt(Cpprod) * np.sin(dhp / 2)
  Lp = (Lsample + Lstd) / 2
  Cp = (Cpstd + Cpsample) / 2
  hp = (hpstd + hpsample) / 2
  hp = hp - (np.abs(hpstd - hpsample) > np.pi) * np.pi
  hp = hp + (hp < 0) * 2 * np.pi
  hp[zcidx] = hpsample[zcidx] + hpstd[zcidx]
  Lpm502 = np.power((Lp - 50), 2)
  Sl = 1 + 0.015 * Lpm502 / np.sqrt(20 + Lpm502)
  Sc = 1 + 0.045 * Cp
  T = 1 - 0.17 * np.cos(hp - np.pi / 6) + 0.24 * np.cos(2 * hp) + \
      0.32 * np.cos(3 * hp + np.pi / 30) \
      - 0.20 * np.cos(4 * hp - 63 * np.pi / 180)
  Sh = 1 + 0.015 * Cp * T
  delthetarad = (30 * np.pi / 180) * np.exp(
    - np.power((180 / np.pi * hp - 275) / 25, 2))
  Rc = 2 * np.sqrt((np.power(Cp, 7)) / (np.power(Cp, 7) + np.power(25, 7)))
  RT = - np.sin(2 * delthetarad) * Rc
  klSl = kl * Sl
  kcSc = kc * Sc
  khSh = kh * Sh
  de00 = np.sqrt(np.power((dL / klSl), 2) + np.power((dC / kcSc), 2) +
                 np.power((dH / khSh), 2) + RT * (dC / kcSc) * (dH / khSh))
  return de00


def calc_mae(source, target, color_chart_area):
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  source_norm = np.sqrt(np.sum(np.power(source, 2), 1))
  target_norm = np.sqrt(np.sum(np.power(target, 2), 1))
  norm = source_norm * target_norm
  L = np.shape(norm)[0]
  inds = norm != 0
  angles = np.sum(source[inds, :] * target[inds, :], 1) / norm[inds]
  angles[angles > 1] = 1
  f = np.arccos(angles)
  f[np.isnan(f)] = 0
  f = f * 180 / np.pi
  return sum(f) / (L - color_chart_area)



def calc_mse(source, target, color_chart_area):
  source = np.reshape(source, [-1, 1]).astype(np.float64)
  target = np.reshape(target, [-1, 1]).astype(np.float64)
  mse = sum(np.power((source - target), 2))
  return mse / ((np.shape(source)[
    0]) - color_chart_area)


def evaluate_cc(corrected, gt, color_chart_area, opt=1):
  """
    Color constancy (white-balance correction) evaluation of a given corrected
    image.
    :param corrected: corrected image
    :param gt: ground-truth image
    :param color_chart_area: If there is a color chart in the image, that is
     masked out from both images, this variable represents the number of pixels
     of the color chart.
    :param opt: determines the required error metric(s) to be reported.
         Options:
           opt = 1 delta E 2000 (default).
           opt = 2 delta E 2000 and mean squared error (MSE)
           opt = 3 delta E 2000, MSE, and mean angular eror (MAE)
           opt = 4 delta E 2000, MSE, MAE, and delta E 76
    :return: error(s) between corrected and gt images
    """

  if opt == 1:
    return calc_deltaE2000(corrected, gt, color_chart_area)
  elif opt == 2:
    return calc_deltaE2000(corrected, gt, color_chart_area), calc_mse(
      corrected, gt, color_chart_area)
  elif opt == 3:
    return calc_deltaE2000(corrected, gt, color_chart_area), calc_mse(
      corrected, gt, color_chart_area), calc_mae(corrected, gt,
                                                 color_chart_area)
  elif opt == 4:
    return calc_deltaE2000(corrected, gt, color_chart_area), calc_mse(
      corrected, gt, color_chart_area), calc_mae(
      corrected, gt, color_chart_area), calc_deltaE(corrected, gt,
                                                    color_chart_area)
  else:
    raise Exception('Error in evaluate_cc function')




def get_metadata(fileName, set, metadata_baseDir=''):
    """
    Gets metadata (e.g., ground-truth file name, chart coordinates and area).
    :param fileName: input filename
    :param set: which dataset?--options includes: 'RenderedWB_Set1',
      'RenderedWB_Set2', 'Rendered_Cube+'
    :param metadata_baseDir: metadata directory (required for Set1 only)
    :return: metadata for a given image
    evaluation_examples.py provides some examples of how to use it
    """

    fname, file_extension = os.path.splitext(fileName)  # get file parts
    name = os.path.basename(fname)  # get only filename without the directory

    if set == 'RenderedWB_Set1': # Rendered WB dataset (Set1)
        metadatafile_color = name + '_color.txt' # chart's colors info.
        metadatafile_mask = name + '_mask.txt' # chart's coordinate info.
        # get color info.
        f = open(os.path.join(metadata_baseDir, metadatafile_color), 'r')
        C = f.read()
        colors = np.zeros((3, 24))  # color chart colors
        temp = re.split(',|\n', C)
        # 3 x 24 colors in the color chart
        colors = np.reshape(np.asfarray(temp[:-1], float), (24, 3)).transpose()
        # get coordinate info
        f = open(os.path.join(metadata_baseDir, metadatafile_mask), 'r')
        C = f.read()
        temp = re.split(',|\n', C)
        # take only the first 4 elements (i.e., the color chart coordinates)
        temp = temp[0:4]
        mask = np.asfarray(temp, float)  # color chart mask coordinates
        # get ground-truth file name
        seperator = '_'
        temp = name.split(seperator)
        gt_file = seperator.join(temp[:-2])
        gt_file = gt_file + '_G_AS.png'
        # compute mask area
        mask_area = mask[2] * mask[3]
        # final metadata
        data = {"gt_filename": gt_file, "cc_colors": colors, "cc_mask": mask,
                "cc_mask_area": mask_area}

    elif set == 'RenderedWB_Set2': # Rendered WB dataset (Set2)
        data = {"gt_filename": name + file_extension, "cc_colors": None,
                "cc_mask": None, "cc_mask_area": 0}

    elif set == 'Rendered_Cube+': # Rendered Cube+
        # get ground-truth filename
        temp = name.split('_')
        gt_file = temp[0] + file_extension
        mask_area = 58373  # calibration obj's area is fixed over all images
        data = {"gt_filename": gt_file, "cc_colors": None, "cc_mask": None,
                "cc_mask_area": mask_area}
    else:
        raise Exception(
            "Invalid value for set variable. " +
            "Please use: 'RenderedWB_Set1', 'RenderedWB_Set2', 'Rendered_Cube+'")

    return data


# Prepare data

In [3]:
CHECKPOINTS = './checkpoints'
IN_DIR = 'E:/datasets/WB/Set2/Set2_input_images'
GT_DIR = 'E:/datasets/WB/Set2/Set2_ground_truth_images/'
FOLD_DIR = ''
METADATA_DIR = ''

# Deep learning model

## Simple Color Mapping Network 

In [4]:
class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        # number of input channels is a number of filters in the previous layer
        # number of output channels is a number of filters in the current layer
        # "same" convolutions
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class Kernel(nn.Module):
  def forward(self, x):
    r = x[:,0,:,:]
    g = x[:,1,:,:]
    b = x[:,2,:,:]
    return torch.stack([r,g,b,r**2,g**2,b**2,r*g,g*b,r*b,r*g*b,torch.ones_like(r)], dim=1)

class SCMN(nn.Module):
  def __init__(self, in_channels=3):
    super().__init__()
    self.Conv1 = ConvBlock(in_channels, 24)
    self.Conv2 = ConvBlock(24, 48)
    self.Conv3 = ConvBlock(48, 96)
    self.Conv4 = ConvBlock(96, 192)
    self.Conv5 = ConvBlock(192, 384)
    self.MaxPool = nn.MaxPool2d(2, 2)
    self.Flatten = nn.Flatten()
    self.Linear = nn.Linear(24576, 33)
    self.Kernel = Kernel()

  # Infer mode
  def forward(self, x):
    x2 = self.MaxPool(self.Conv1(x))
    x2 = self.MaxPool(self.Conv2(x2))
    x2 = self.MaxPool(self.Conv3(x2))
    x2 = self.MaxPool(self.Conv4(x2))
    x2 = self.MaxPool(self.Conv5(x2))
    x2 = self.Flatten(x2)
    x2 = self.Linear(x2)
    x2 = torch.reshape(x2, (-1, 3, 11)) # B x 3 x 11

    # x3 = self.Kernel(x) # B x 11 x 256 x 256
    # x3 = torch.reshape(x3, (-1, 11, 256 * 256)) # B x 11 x 65536
    
    # x4 = torch.bmm(x2, x3) # B x 3 x 65536
    # x5 = torch.reshape(x4, (-1, 3, 256, 256))

    return x2

# Inference model

## DL model

In [5]:
mytransform = transforms.Compose([
    transforms.Resize((256, 256))
])

device = 'cuda'

model_checkpoint_path = os.path.join(CHECKPOINTS, 'scmn_20220925_111650_150.pth')
net = SCMN()
net.load_state_dict(torch.load(model_checkpoint_path))
net.eval()
net = net.to(device)

def outOfGamutClipping(I):
    """ Clips out-of-gamut pixels. """
    I[I > 1] = 1  # any pixel is higher than 1, clip it to 1
    I[I < 0] = 0  # any pixel is below 0, clip it to 0
    return I

def get_mapping_matrix(img):
  img = torch.tensor(img, dtype=torch.float32)
  img = torch.permute(img, (2, 0, 1))
  img = torch.unsqueeze(img, dim=0)
  img = img.to('cuda')
  with torch.no_grad():
    img = mytransform(img)
    out = net(img) # 1 x 3 x 11
  out = torch.squeeze(out) # 3 x 11
  mapping_matrix = out.cpu().numpy()
  return mapping_matrix

def kernel(img):
  r = img[:,:,0]
  g = img[:,:,1]
  b = img[:,:,2]
  return np.stack([r,g,b,r**2,g**2,b**2,r*g,g*b,r*b,r*g*b, np.ones_like(r)], axis=0) # 11 x height x width

def infer(img):
  M = get_mapping_matrix(img) # 3 x 11
  height, width, _ = img.shape
  img = kernel(img) # 11 x height x width
  img = img.reshape(11, -1) # (11 x width*height)
  img = np.matmul(M, img) # 3 x width * height
  img = img.reshape(3, height, width) # 3 x height x width
  img = np.transpose(img, (1,2,0)) # height x width x 3
  img = outOfGamutClipping(img)
  return img

# Evaluate

In [6]:
def eval(name, imgin, gtimg):
    # read the image
    I_in = cv2.imread(imgin, cv2.IMREAD_COLOR)
    # read gt image
    gt = cv2.imread(gtimg, cv2.IMREAD_COLOR)
    # metadata
    metadata = get_metadata(name, 'RenderedWB_Set2', METADATA_DIR)  

    # white balance I_in
    I_corr = (infer(I_in[:,:,::-1] / 255)[:,:,::-1]*255).astype("uint8")  

    # import matplotlib.pyplot as plt
    # plt.figure(figsize=(20,10))
    # plt.subplot(1,3,1)
    # plt.imshow(I_in[:,:,::-1])
    # plt.title("Input")
    # plt.subplot(1,3,2)
    # plt.imshow(gt[:,:,::-1])
    # plt.title("Target")
    # plt.subplot(1,3,3)
    # plt.imshow(I_corr[:,:,::-1])
    # plt.title("Corrected")
    # plt.show()

    # Evaluation
    deltaE00, MSE, MAE, deltaE76 = evaluate_cc(I_corr, gt, metadata["cc_mask_area"],
                                            opt=4)

    return deltaE00, MSE.item(), MAE, deltaE76

In [7]:
eval_log_rows = []

class EvalThread(threading.Thread):
  def __init__(self, part, n_part, imgs):
    super().__init__()
    self.part = part
    self.n_part = n_part
    self.imgs = imgs

  def run(self):
    imgs = self.imgs
    for i in range(len(imgs)):
      if i % self.n_part == self.part:
        img_in = os.path.join(IN_DIR, imgs[i] + ".png")
        img_gt = os.path.join(GT_DIR, imgs[i] + ".png")
        deltaE00, MSE, MAE, deltaE76 = eval(imgs[i], img_in, img_gt)
        eval_result = {
          "img_in": img_in, 
          "img_gt": img_gt,
          "deltaE2000": deltaE00,
          "MSE": MSE,
          "MAE": MAE,
          "deltaE76": deltaE76
        }
        eval_log_rows.append(eval_result)
        print(f'[{len(eval_log_rows)}/{len(imgs)}] {eval_result}')


In [8]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

imgs = [f.split(".")[0] for f in os.listdir(IN_DIR)][:20]

n_part = 64
threads = []
for part in range(n_part):
  thread = EvalThread(part, n_part, imgs)
  threads.append(thread)

for thread in threads:
  thread.start()
  thread.join()


[1/20] {'img_in': 'E:/datasets/WB/Set2/Set2_input_images\\DSLR_00001.png', 'img_gt': 'E:/datasets/WB/Set2/Set2_ground_truth_images/DSLR_00001.png', 'deltaE2000': 3.0928288515443954, 'MSE': 41.8749584817752, 'MAE': 1.2412885141889556, 'deltaE76': 3.512461613777923}
[2/20] {'img_in': 'E:/datasets/WB/Set2/Set2_input_images\\DSLR_00002.png', 'img_gt': 'E:/datasets/WB/Set2/Set2_ground_truth_images/DSLR_00002.png', 'deltaE2000': 5.019537534141833, 'MSE': 74.7032992890219, 'MAE': 3.710484897398741, 'deltaE76': 6.0250875314010495}
[3/20] {'img_in': 'E:/datasets/WB/Set2/Set2_input_images\\DSLR_00003.png', 'img_gt': 'E:/datasets/WB/Set2/Set2_ground_truth_images/DSLR_00003.png', 'deltaE2000': 3.2750902914831665, 'MSE': 6.81635464538635, 'MAE': 1.1388481243872635, 'deltaE76': 2.7469896036068397}
[4/20] {'img_in': 'E:/datasets/WB/Set2/Set2_input_images\\DSLR_00004.png', 'img_gt': 'E:/datasets/WB/Set2/Set2_ground_truth_images/DSLR_00004.png', 'deltaE2000': 2.978803627246875, 'MSE': 52.51178134461053

In [9]:
EVALUATION_SAVE_PATH = './evals'
FILE_NAME = f'evaluation_results_{timestamp}.csv'

fp = os.path.join(EVALUATION_SAVE_PATH, FILE_NAME)
eval_log = pd.DataFrame(eval_log_rows)
eval_log.to_csv(fp)
print("Saved evaluation results to ", fp)

Saved evaluation results to  ./evals\evaluation_results_20220929_102326.csv


In [10]:
eval_log

Unnamed: 0,img_in,img_gt,deltaE2000,MSE,MAE,deltaE76
0,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,3.092829,41.874958,1.241289,3.512462
1,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,5.019538,74.703299,3.710485,6.025088
2,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,3.27509,6.816355,1.138848,2.74699
3,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,2.978804,52.511781,2.527451,5.170905
4,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,3.538915,15.715209,2.422054,4.27662
5,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,3.774003,106.356884,2.777092,6.397645
6,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,2.516018,45.31842,2.115884,4.581574
7,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,1.068647,4.164344,1.211112,1.346874
8,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,0.974998,4.033103,1.111978,1.238902
9,E:/datasets/WB/Set2/Set2_input_images\DSLR_000...,E:/datasets/WB/Set2/Set2_ground_truth_images/D...,2.323552,13.300489,2.978194,2.577173


# Calculate quantile

In [None]:
df = pd.read_csv('/content/drive/MyDrive/CC/evals/evaluation_results_20220928_014845.csv', index_col=0)

In [None]:
df

Unnamed: 0,img_in,img_gt,deltaE2000,MSE,MAE,deltaE76
0,/content/Set1_input_images_wo_CC_JPG/NikonD40_...,/content/Set1_ground_truth_images_wo_CC/NikonD...,5.836537,137.404310,8.516401,6.401685
1,/content/Set1_input_images_wo_CC_JPG/Canon600D...,/content/Set1_ground_truth_images_wo_CC/Canon6...,4.462629,108.881064,1.665026,5.035117
2,/content/Set1_input_images_wo_CC_JPG/8D5U5581_...,/content/Set1_ground_truth_images_wo_CC/8D5U55...,4.320146,83.158145,2.064598,5.044472
3,/content/Set1_input_images_wo_CC_JPG/FujifilmX...,/content/Set1_ground_truth_images_wo_CC/Fujifi...,2.897678,17.187312,5.256819,2.765034
4,/content/Set1_input_images_wo_CC_JPG/8D5U5597_...,/content/Set1_ground_truth_images_wo_CC/8D5U55...,14.370017,342.307826,13.440953,18.098631
...,...,...,...,...,...,...
21041,/content/Set1_input_images_wo_CC_JPG/NikonD40_...,/content/Set1_ground_truth_images_wo_CC/NikonD...,5.478602,122.109472,4.631952,6.484465
21042,/content/Set1_input_images_wo_CC_JPG/Canon1DsM...,/content/Set1_ground_truth_images_wo_CC/Canon1...,3.589185,54.841123,3.152034,4.946736
21043,/content/Set1_input_images_wo_CC_JPG/IMG_0485_...,/content/Set1_ground_truth_images_wo_CC/IMG_04...,6.856161,111.589565,7.745380,8.961823
21044,/content/Set1_input_images_wo_CC_JPG/NikonD40_...,/content/Set1_ground_truth_images_wo_CC/NikonD...,2.665912,18.645326,2.157867,2.925228


In [None]:
df.quantile([0.25, 0.5, 0.75])

Unnamed: 0,deltaE2000,MSE,MAE,deltaE76
0.25,3.074075,30.196393,2.663184,3.663297
0.5,3.861929,53.504668,3.478681,4.821954
0.75,4.81512,94.095673,4.419786,6.167367


In [None]:
df.mean()

  """Entry point for launching an IPython kernel.


deltaE2000     4.093920
MSE           75.580025
MAE            3.709035
deltaE76       5.125629
dtype: float64