In [None]:
import cv2
import numpy as np
from skimage import color
import os
import numpy as np
import re
import torch
from torch import nn
from torchvision import transforms
from sklearn.linear_model import LinearRegression
from torchvision.io import read_image
from skimage.io import imread
import threading
import time
import pandas as pd

# Evaluation Functions

In [None]:
def calc_deltaE(source, target, color_chart_area):
  """Mean Delta E 76 (Euclidean distance in CIELAB) between two images.

  :param source: image in OpenCV's BGR channel order (it is converted with
    cv2.COLOR_BGR2RGB before the Lab conversion)
  :param target: reference image, same layout and pixel count as source
  :param color_chart_area: number of color-chart pixels to exclude from the
    averaging denominator
  :return: sum of the per-pixel Delta E 76 values divided by
    (number of pixels - color_chart_area)
  """
  # rgb2lab expects RGB channel order, so undo OpenCV's BGR ordering first
  source = color.rgb2lab(cv2.cvtColor(source, cv2.COLOR_BGR2RGB))
  target = color.rgb2lab(cv2.cvtColor(target, cv2.COLOR_BGR2RGB))
  # one row per pixel, columns are L*, a*, b*
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  delta_e = np.sqrt(np.sum(np.power(source - target, 2), 1))
  # np.sum replaces the builtin sum(): the builtin iterates pixel by pixel in
  # Python; the float64 accumulator avoids float32 round-off on big images.
  total = np.sum(delta_e, dtype=np.float64)
  return total / (np.shape(delta_e)[0] - color_chart_area)

def calc_deltaE2000(source, target, color_chart_area):
  """Mean Delta E 2000 between two images.

  :param source: image in OpenCV's BGR channel order (it is converted with
    cv2.COLOR_BGR2RGB before the Lab conversion)
  :param target: reference image, same layout and pixel count as source
  :param color_chart_area: number of color-chart pixels to exclude from the
    averaging denominator
  :return: sum of the per-pixel Delta E 2000 values divided by
    (number of pixels - color_chart_area)
  """
  # rgb2lab expects RGB channel order, so undo OpenCV's BGR ordering first
  source = color.rgb2lab(cv2.cvtColor(source, cv2.COLOR_BGR2RGB))
  target = color.rgb2lab(cv2.cvtColor(target, cv2.COLOR_BGR2RGB))
  # one row per pixel, columns are L*, a*, b*
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  deltaE00 = deltaE2000(source, target)
  # np.sum replaces the builtin sum(): the builtin iterates pixel by pixel in
  # Python; the float64 accumulator avoids float32 round-off on big images.
  total = np.sum(deltaE00, dtype=np.float64)
  return total / (np.shape(deltaE00)[0] - color_chart_area)


def deltaE2000(Labstd, Labsample):
  """Per-row CIEDE2000 color difference between two sets of CIELAB colors.

  :param Labstd: array-like of shape (N, 3) with reference L*, a*, b* values
  :param Labsample: array-like of shape (N, 3) with sample L*, a*, b* values
  :return: float64 array of length N holding the Delta E 2000 of each row pair
  """
  # parametric weighting factors (all 1 here)
  kl = 1
  kc = 1
  kh = 1
  # 25^7 as a float: np.power(25, 7) is integer arithmetic and does not fit
  # into a 32-bit integer, which is NumPy's default int on some platforms.
  pow25_7 = 25.0 ** 7
  # Work in float64 regardless of the caller's dtype.
  Labstd = np.asarray(Labstd, dtype=np.float64)
  Labsample = np.asarray(Labsample, dtype=np.float64)
  Lstd = Labstd[:, 0]
  astd = Labstd[:, 1]
  bstd = Labstd[:, 2]
  Cabstd = np.sqrt(np.power(astd, 2) + np.power(bstd, 2))
  Lsample = Labsample[:, 0]
  asample = Labsample[:, 1]
  bsample = Labsample[:, 2]
  Cabsample = np.sqrt(np.power(asample, 2) + np.power(bsample, 2))
  Cabarithmean = (Cabstd + Cabsample) / 2
  Cab7 = np.power(Cabarithmean, 7)
  G = 0.5 * (1 - np.sqrt(Cab7 / (Cab7 + pow25_7)))
  # a* axis rescaled by (1 + G)
  apstd = (1 + G) * astd
  apsample = (1 + G) * asample
  Cpsample = np.sqrt(np.power(apsample, 2) + np.power(bsample, 2))
  Cpstd = np.sqrt(np.power(apstd, 2) + np.power(bstd, 2))
  Cpprod = (Cpsample * Cpstd)
  # rows where at least one of the two colors has zero chroma
  zcidx = Cpprod == 0
  # hue angles, wrapped into [0, 2*pi); zero-chroma colors get hue 0
  hpstd = np.arctan2(bstd, apstd)
  hpstd = hpstd + 2 * np.pi * (hpstd < 0)  # same wrap as applied to hpsample
  hpstd[(np.abs(apstd) + np.abs(bstd)) == 0] = 0
  hpsample = np.arctan2(bsample, apsample)
  hpsample = hpsample + 2 * np.pi * (hpsample < 0)
  hpsample[(np.abs(apsample) + np.abs(bsample)) == 0] = 0
  dL = (Lsample - Lstd)
  dC = (Cpsample - Cpstd)
  # hue difference, brought back into [-pi, pi]
  dhp = (hpsample - hpstd)
  dhp = dhp - 2 * np.pi * (dhp > np.pi)
  dhp = dhp + 2 * np.pi * (dhp < (-np.pi))
  dhp[zcidx] = 0
  dH = 2 * np.sqrt(Cpprod) * np.sin(dhp / 2)
  Lp = (Lsample + Lstd) / 2
  Cp = (Cpstd + Cpsample) / 2
  # mean hue: plain average, shifted by pi when the two hues are more than pi
  # apart, then wrapped to be non-negative
  hp = (hpstd + hpsample) / 2
  hp = hp - (np.abs(hpstd - hpsample) > np.pi) * np.pi
  hp = hp + (hp < 0) * 2 * np.pi
  # with a zero-chroma color involved, use the sum of the hues (one is 0)
  hp[zcidx] = hpsample[zcidx] + hpstd[zcidx]
  Lpm502 = np.power((Lp - 50), 2)
  Sl = 1 + 0.015 * Lpm502 / np.sqrt(20 + Lpm502)
  Sc = 1 + 0.045 * Cp
  T = 1 - 0.17 * np.cos(hp - np.pi / 6) + 0.24 * np.cos(2 * hp) + \
      0.32 * np.cos(3 * hp + np.pi / 30) \
      - 0.20 * np.cos(4 * hp - 63 * np.pi / 180)
  Sh = 1 + 0.015 * Cp * T
  delthetarad = (30 * np.pi / 180) * np.exp(
    - np.power((180 / np.pi * hp - 275) / 25, 2))
  Cp7 = np.power(Cp, 7)
  Rc = 2 * np.sqrt(Cp7 / (Cp7 + pow25_7))
  RT = - np.sin(2 * delthetarad) * Rc
  klSl = kl * Sl
  kcSc = kc * Sc
  khSh = kh * Sh
  de00 = np.sqrt(np.power((dL / klSl), 2) + np.power((dC / kcSc), 2) +
                 np.power((dH / khSh), 2) + RT * (dC / kcSc) * (dH / khSh))
  return de00


def calc_mae(source, target, color_chart_area):
  """Mean angular error, in degrees, between the pixel vectors of two images.

  :param source: image whose values can be reshaped to (N, 3)
  :param target: reference image with the same number of pixels
  :param color_chart_area: number of color-chart pixels to exclude from the
    averaging denominator
  :return: sum of the per-pixel angles divided by (N - color_chart_area);
    pixels where either vector has zero length contribute no angle but are
    still counted in N
  """
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  source_norm = np.sqrt(np.sum(np.power(source, 2), 1))
  target_norm = np.sqrt(np.sum(np.power(target, 2), 1))
  norm = source_norm * target_norm
  L = np.shape(norm)[0]
  # the angle is undefined when one of the two vectors is all zeros
  inds = norm != 0
  cosines = np.sum(source[inds, :] * target[inds, :], 1) / norm[inds]
  # Rounding can push the cosine slightly outside [-1, 1] on either side;
  # clamp both ends so arccos does not turn a ~0 or ~180 degree angle into NaN.
  cosines = np.clip(cosines, -1, 1)
  f = np.arccos(cosines)
  f[np.isnan(f)] = 0  # only reachable for NaN inputs now
  f = f * 180 / np.pi
  # vectorised float64 sum instead of the builtin sum() Python loop
  return np.sum(f, dtype=np.float64) / (L - color_chart_area)



def calc_mse(source, target, color_chart_area):
  """Mean squared error between two images.

  :param source: image, flattened to one value per row
  :param target: reference image with the same number of values
  :param color_chart_area: value subtracted from the element count to form
    the averaging denominator
  :return: array of shape (1,) holding the summed squared difference divided
    by (number of values - color_chart_area)
  """
  # float64 before subtracting so unsigned integer images cannot wrap around
  source = np.reshape(source, [-1, 1]).astype(np.float64)
  target = np.reshape(target, [-1, 1]).astype(np.float64)
  # Sum along axis 0 keeps the shape-(1,) result the builtin sum() produced
  # (callers use .item()), without iterating over every value in Python.
  squared_error = np.sum(np.power(source - target, 2), axis=0)
  # NOTE(review): the count below is over all channel values while
  # color_chart_area is described elsewhere as a pixel count - confirm intent.
  return squared_error / (np.shape(source)[0] - color_chart_area)


def evaluate_cc(corrected, gt, color_chart_area, opt=1):
  """
    Evaluates a white-balance (color constancy) corrected image against its
    ground truth.
    :param corrected: corrected image
    :param gt: ground-truth image
    :param color_chart_area: number of pixels covered by a color chart that
     was masked out of both images (0 if there is none).
    :param opt: selects how many error metrics are reported.
         Options:
           opt = 1 delta E 2000 (default).
           opt = 2 delta E 2000 and mean squared error (MSE)
           opt = 3 delta E 2000, MSE, and mean angular error (MAE)
           opt = 4 delta E 2000, MSE, MAE, and delta E 76
    :return: a single value for opt = 1, otherwise a tuple with the requested
     errors in the order listed above
    """
  if opt not in (1, 2, 3, 4):
    raise Exception('Error in evaluate_cc function')

  # metrics in reporting order; opt says how many of them to evaluate
  metrics = (calc_deltaE2000, calc_mse, calc_mae, calc_deltaE)
  errors = tuple(metric(corrected, gt, color_chart_area)
                 for metric in metrics[:int(opt)])
  if len(errors) == 1:
    return errors[0]
  return errors




def get_metadata(fileName, set, metadata_baseDir=''):
    """
    Gets metadata (ground-truth file name, color chart colors, coordinates
    and area) for an input image.
    :param fileName: input filename (directory and extension are stripped to
      obtain the image name)
    :param set: which dataset; one of 'RenderedWB_Set1', 'RenderedWB_Set2',
      'Rendered_Cube+'. (The name shadows the builtin but is kept so keyword
      callers keep working.)
    :param metadata_baseDir: directory holding the '<name>_color.txt' and
      '<name>_mask.txt' files (read for Set1 only)
    :return: dict with keys "gt_filename", "cc_colors", "cc_mask" and
      "cc_mask_area"
    :raises Exception: if set is not one of the supported dataset names
    """

    fname, file_extension = os.path.splitext(fileName)  # get file parts
    name = os.path.basename(fname)  # get only filename without the directory

    if set == 'RenderedWB_Set1':  # Rendered WB dataset (Set1)
        color_path = os.path.join(metadata_baseDir, name + '_color.txt')
        mask_path = os.path.join(metadata_baseDir, name + '_mask.txt')

        # Chart colors: comma/newline separated numbers. Empty tokens (e.g.
        # the one after a trailing newline) are skipped, so the file parses
        # with or without a final newline; the first 24 * 3 values are used.
        with open(color_path, 'r') as f:
            tokens = re.split(',|\n', f.read())
        values = [float(token) for token in tokens if token.strip()]
        # 3 x 24 colors in the color chart
        colors = np.reshape(np.array(values[:24 * 3], dtype=float),
                            (24, 3)).transpose()

        # Chart coordinates: only the first 4 numbers of the mask file
        with open(mask_path, 'r') as f:
            tokens = re.split(',|\n', f.read())
        mask = np.array([float(token) for token in tokens[0:4]], dtype=float)

        # Ground-truth name: drop the last two '_'-separated fields of the
        # input name and append the '_G_AS.png' suffix
        seperator = '_'
        gt_file = seperator.join(name.split(seperator)[:-2]) + '_G_AS.png'

        # mask area is the product of the 3rd and 4th mask values
        mask_area = mask[2] * mask[3]
        data = {"gt_filename": gt_file, "cc_colors": colors, "cc_mask": mask,
                "cc_mask_area": mask_area}

    elif set == 'RenderedWB_Set2':  # Rendered WB dataset (Set2)
        data = {"gt_filename": name + file_extension, "cc_colors": None,
                "cc_mask": None, "cc_mask_area": 0}

    elif set == 'Rendered_Cube+':  # Rendered Cube+
        # ground-truth filename is the part before the first '_'
        gt_file = name.split('_')[0] + file_extension
        mask_area = 58373  # calibration obj's area is fixed over all images
        data = {"gt_filename": gt_file, "cc_colors": None, "cc_mask": None,
                "cc_mask_area": mask_area}
    else:
        raise Exception(
            "Invalid value for set variable. " +
            "Please use: 'RenderedWB_Set1', 'RenderedWB_Set2', 'Rendered_Cube+'")

    return data


# Prepare data

In [None]:
# Colab-only: mount Google Drive; the dataset archives and checkpoints used
# below are read from paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Extract the Set1 ground-truth images, input JPEGs and per-image metadata
# from Drive into folders of the same name in the current working directory.
!unzip -q /content/drive/MyDrive/CC/datasets/Set1_ground_truth_images_wo_CC.zip -d Set1_ground_truth_images_wo_CC
!unzip -q /content/drive/MyDrive/CC/datasets/Set1_input_images_wo_CC_JPG.zip -d Set1_input_images_wo_CC_JPG
!unzip -q /content/drive/MyDrive/CC/datasets/Set1_input_images_metadata.zip -d Set1_input_images_metadata

In [None]:
# Locations used by the rest of the notebook.
CHECKPOINTS = '/content/drive/MyDrive/CC/checkpoints'  # saved model weights (.pth)
IN_DIR = '/content/Set1_input_images_wo_CC_JPG'  # input images (.jpg)
GT_DIR = '/content/Set1_ground_truth_images_wo_CC'  # ground-truth images (*_G_AS.png)
FOLD_DIR = '/content/drive/MyDrive/CC/datasets/Set1_folds'  # fold_*.txt image lists
METADATA_DIR = '/content/Set1_input_images_metadata'  # *_color.txt / *_mask.txt files

In [None]:
# List the checkpoint files available in CHECKPOINTS.
!ls $CHECKPOINTS

attention_unet_20220918_021432_10.pth  a-unet_aug_20220918_104239_30.pth
attention_unet_20220918_021432_15.pth  a-unet_aug_20220918_104239_35.pth
attention_unet_20220918_021432_20.pth  a-unet_aug_20220918_104239_40.pth
attention_unet_20220918_021432_25.pth  a-unet_aug_20220918_104239_45.pth
attention_unet_20220918_021432_30.pth  a-unet_aug_20220918_104239_50.pth
a-unet_aug_20220918_104239_100.pth     a-unet_aug_20220918_104239_55.pth
a-unet_aug_20220918_104239_105.pth     a-unet_aug_20220918_104239_5.pth
a-unet_aug_20220918_104239_10.pth      a-unet_aug_20220918_104239_60.pth
a-unet_aug_20220918_104239_110.pth     a-unet_aug_20220918_104239_65.pth
a-unet_aug_20220918_104239_115.pth     a-unet_aug_20220918_104239_70.pth
a-unet_aug_20220918_104239_120.pth     a-unet_aug_20220918_104239_75.pth
a-unet_aug_20220918_104239_125.pth     a-unet_aug_20220918_104239_80.pth
a-unet_aug_20220918_104239_130.pth     a-unet_aug_20220918_104239_85.pth
a-unet_aug_20220918_104239_135.pth     a-unet_aug_20

# Deep learning model

## Define Attention U-Net model

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 "same" convolutions, each followed by batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels produced by both convolutions
        """
        super().__init__()

        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels; padding=1 with a 3x3 kernel preserves the spatial size.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3,
                                    stride=1, padding=1, bias=True))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UpConv(nn.Module):
    """Doubles the spatial size, then applies a 3x3 conv, batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels after the convolution
        """
        super().__init__()

        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                         padding=1, bias=True)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        return self.up(x)


class AttentionBlock(nn.Module):
    """Attention gate that rescales a skip connection with learned weights."""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of channels of the gating signal (previous layer)
        :param F_l: number of channels of the skip connection coming from the
            corresponding encoder layer
        :param n_coefficients: number of channels both inputs are projected to
            before they are combined
        """
        super().__init__()

        # 1x1 projections of gate and skip connection to a common width
        gate_conv = nn.Conv2d(F_g, n_coefficients, kernel_size=1, stride=1,
                              padding=0, bias=True)
        self.W_gate = nn.Sequential(gate_conv, nn.BatchNorm2d(n_coefficients))

        skip_conv = nn.Conv2d(F_l, n_coefficients, kernel_size=1, stride=1,
                              padding=0, bias=True)
        self.W_x = nn.Sequential(skip_conv, nn.BatchNorm2d(n_coefficients))

        # collapses the combined features to one sigmoid-activated channel
        psi_conv = nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1,
                             padding=0, bias=True)
        self.psi = nn.Sequential(psi_conv, nn.BatchNorm2d(1), nn.Sigmoid())

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal from the previous layer
        :param skip_connection: activation from the corresponding encoder layer
        :return: skip_connection multiplied by the single-channel attention map
        """
        combined = self.relu(self.W_gate(gate) + self.W_x(skip_connection))
        attention = self.psi(combined)
        return skip_connection * attention


class AttentionUNet(nn.Module):
    """U-Net with five encoder levels and attention-gated skip connections."""

    def __init__(self, img_ch=3, output_ch=3):
        """
        :param img_ch: number of channels of the input image
        :param output_ch: number of channels of the output
        """
        super().__init__()

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: the channel count doubles at every level.
        self.Conv1 = ConvBlock(img_ch, 24)
        self.Conv2 = ConvBlock(24, 48)
        self.Conv3 = ConvBlock(48, 96)
        self.Conv4 = ConvBlock(96, 192)
        self.Conv5 = ConvBlock(192, 384)

        # Decoder: each level upsamples, gates the matching encoder output
        # and fuses both (hence the doubled input width of UpConvN).
        self.Up5 = UpConv(384, 192)
        self.Att5 = AttentionBlock(F_g=192, F_l=192, n_coefficients=96)
        self.UpConv5 = ConvBlock(384, 192)

        self.Up4 = UpConv(192, 96)
        self.Att4 = AttentionBlock(F_g=96, F_l=96, n_coefficients=48)
        self.UpConv4 = ConvBlock(192, 96)

        self.Up3 = UpConv(96, 48)
        self.Att3 = AttentionBlock(F_g=48, F_l=48, n_coefficients=24)
        self.UpConv3 = ConvBlock(96, 48)

        self.Up2 = UpConv(48, 24)
        self.Att2 = AttentionBlock(F_g=24, F_l=24, n_coefficients=12)
        self.UpConv2 = ConvBlock(48, 24)

        # 1x1 convolution producing the output channels
        self.Conv = nn.Conv2d(24, output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _decode(up, attention, fuse, features, skip):
        """One decoder level: upsample, gate the skip connection, concatenate, fuse."""
        upsampled = up(features)
        gated_skip = attention(gate=upsampled, skip_connection=skip)
        # attention-weighted skip connection first, upsampled features second
        return fuse(torch.cat((gated_skip, upsampled), dim=1))

    def forward(self, x):
        """
        Runs the encoder, then decodes level by level using the encoder
        outputs as attention-gated skip connections.
        """
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.MaxPool(enc1))
        enc3 = self.Conv3(self.MaxPool(enc2))
        enc4 = self.Conv4(self.MaxPool(enc3))
        bottleneck = self.Conv5(self.MaxPool(enc4))

        dec = self._decode(self.Up5, self.Att5, self.UpConv5, bottleneck, enc4)
        dec = self._decode(self.Up4, self.Att4, self.UpConv4, dec, enc3)
        dec = self._decode(self.Up3, self.Att3, self.UpConv3, dec, enc2)
        dec = self._decode(self.Up2, self.Att2, self.UpConv2, dec, enc1)

        return self.Conv(dec)


## Simple Attention U-Net

In [None]:
class SimpleAttentionUNet(nn.Module):
    """Attention U-Net variant with four encoder levels instead of five."""

    def __init__(self, img_ch=3, output_ch=3):
        """
        :param img_ch: number of channels of the input image
        :param output_ch: number of channels of the output
        """
        super().__init__()

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: the channel count doubles at every level.
        self.Conv1 = ConvBlock(img_ch, 24)
        self.Conv2 = ConvBlock(24, 48)
        self.Conv3 = ConvBlock(48, 96)
        self.Conv4 = ConvBlock(96, 192)

        # Decoder: each level upsamples, gates the matching encoder output
        # and fuses both (hence the doubled input width of UpConvN).
        self.Up4 = UpConv(192, 96)
        self.Att4 = AttentionBlock(F_g=96, F_l=96, n_coefficients=48)
        self.UpConv4 = ConvBlock(192, 96)

        self.Up3 = UpConv(96, 48)
        self.Att3 = AttentionBlock(F_g=48, F_l=48, n_coefficients=24)
        self.UpConv3 = ConvBlock(96, 48)

        self.Up2 = UpConv(48, 24)
        self.Att2 = AttentionBlock(F_g=24, F_l=24, n_coefficients=12)
        self.UpConv2 = ConvBlock(48, 24)

        # 1x1 convolution producing the output channels
        self.Conv = nn.Conv2d(24, output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _decode(up, attention, fuse, features, skip):
        """One decoder level: upsample, gate the skip connection, concatenate, fuse."""
        upsampled = up(features)
        gated_skip = attention(gate=upsampled, skip_connection=skip)
        # attention-weighted skip connection first, upsampled features second
        return fuse(torch.cat((gated_skip, upsampled), dim=1))

    def forward(self, x):
        """
        Runs the four-level encoder, then decodes starting directly from the
        deepest encoder output, using the shallower encoder outputs as
        attention-gated skip connections.
        """
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.MaxPool(enc1))
        enc3 = self.Conv3(self.MaxPool(enc2))
        enc4 = self.Conv4(self.MaxPool(enc3))

        dec = self._decode(self.Up4, self.Att4, self.UpConv4, enc4, enc3)
        dec = self._decode(self.Up3, self.Att3, self.UpConv3, dec, enc2)
        dec = self._decode(self.Up2, self.Att2, self.UpConv2, dec, enc1)

        return self.Conv(dec)


# Inference model

## Post-processing

In [None]:
def kernelP(I):
    """ Polynomial kernel: (r, g, b) -> (r,g,b,rg,rb,gb,r^2,g^2,b^2,rgb,1)
        applied to every row of an (N, 3) array, giving an (N, 11) array.
        Ref: Hong, et al., "A study of digital camera colorimetric characterization
         based on polynomial modeling." Color Research & Application, 2001. """
    r, g, b = I[:, 0], I[:, 1], I[:, 2]
    ones = np.repeat(1, np.shape(I)[0])  # constant (bias) term
    features = (r, g, b,
                r * g, r * b, g * b,
                r * r, g * g, b * b,
                r * g * b,
                ones)
    return np.column_stack(features)


def get_mapping_func(image1, image2):
    """ Fits the polynomial mapping from image1's pixels to image2's pixels.

    Both images are flattened to (N, 3); a linear regression is fitted from
    the kernelP features of image1 to the values of image2 and returned.
    """
    source_pixels = np.reshape(image1, [-1, 3])
    target_pixels = np.reshape(image2, [-1, 3])
    regressor = LinearRegression()
    regressor.fit(kernelP(source_pixels), target_pixels)
    return regressor


def apply_mapping_func(image, m):
    """ Applies a fitted polynomial mapping to every pixel of an image.

    :param image: 3-D image array whose values can be reshaped to (N, 3)
    :param m: fitted model exposing predict(), e.g. from get_mapping_func
    :return: mapped image with the same three dimensions as the input
    """
    dims = image.shape
    rows, cols, channels = dims[0], dims[1], dims[2]
    pixels = np.reshape(image, [-1, 3])
    mapped = m.predict(kernelP(pixels))
    return np.reshape(mapped, [rows, cols, channels])

def outOfGamutClipping(I):
    """ Clips out-of-gamut values to the [0, 1] range.

    The array is modified in place and the same object is returned.
    """
    np.clip(I, 0, 1, out=I)
    return I

## DL model

In [None]:
# Preprocessing applied before the network: resize to 256 x 256.
mytransform = transforms.Compose([
    transforms.Resize((256, 256))
])

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook also runs on CPU-only runtimes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_checkpoint_path = os.path.join(CHECKPOINTS, 'a-unet_aug_20220918_104239_110.pth')
# SimpleAttentionUNet() can be swapped in here; the checkpoint loaded below
# must then contain weights for that architecture.
net = AttentionUNet()
# map_location lets a checkpoint saved on a GPU be loaded on whatever device
# was selected above.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
net.eval()
net = net.to(device)


def infer(image):
    """White-balances an image with the network plus a polynomial mapping.

    image: float np array RGB image, channels last.

    The image is resized with mytransform and passed through net; a
    polynomial mapping is fitted from the resized input to the network
    output, applied to the full-size image, and the result is clipped to
    [0, 1].
    """
    with torch.no_grad():
        channels_first = torch.from_numpy(image).permute((2, 0, 1))
        resized = mytransform(channels_first)
        # channels-last copy of the resized input, used to fit the mapping
        image1 = resized.permute((1, 2, 0)).numpy()
        batch = resized.to(torch.float32).unsqueeze(0)
        output = net(batch.to(device))[0]
    prediction = output.permute((1, 2, 0)).cpu().numpy()
    mapping_func = get_mapping_func(image1, prediction)
    corrected = apply_mapping_func(image, mapping_func)
    return outOfGamutClipping(corrected)

# Evaluation

In [None]:
def eval(name, imgin, gtimg):
    """White-balances one image, plots it and scores it against its ground truth.

    (The function name shadows the builtin eval; it is kept because other
    cells call it by this name.)

    :param name: image name passed to get_metadata for the Set1 metadata
    :param imgin: path of the input image
    :param gtimg: path of the ground-truth image
    :return: (deltaE2000, MSE, MAE, deltaE76), with MSE converted to a scalar
    :raises FileNotFoundError: if either image cannot be read
    """
    import matplotlib.pyplot as plt

    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly to report which file is missing or unreadable.
    I_in = cv2.imread(imgin, cv2.IMREAD_COLOR)
    if I_in is None:
        raise FileNotFoundError(f'Could not read input image: {imgin}')
    gt = cv2.imread(gtimg, cv2.IMREAD_COLOR)
    if gt is None:
        raise FileNotFoundError(f'Could not read ground-truth image: {gtimg}')
    # metadata
    metadata = get_metadata(name, 'RenderedWB_Set1', METADATA_DIR)

    # white balance I_in: BGR uint8 -> RGB in [0, 1] for infer, then back to
    # BGR uint8 for the OpenCV-based metrics
    I_corr = (infer(I_in[:,:,::-1] / 255)[:,:,::-1]*255).astype("uint8")

    # side-by-side preview; images are BGR, so flip channels for display
    fig, axes = plt.subplots(1, 3, figsize=(20, 10))
    panels = (("Input", I_in), ("Target", gt), ("Corrected", I_corr))
    for ax, (title, bgr) in zip(axes, panels):
        ax.imshow(bgr[:, :, ::-1])
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # release the figure; this runs once per evaluated image

    # Evaluation
    deltaE00, MSE, MAE, deltaE76 = evaluate_cc(I_corr, gt, metadata["cc_mask_area"],
                                            opt=4)

    return deltaE00, MSE.item(), MAE, deltaE76

In [None]:
class EvalThread(threading.Thread):
  """Worker thread that evaluates every n_part-th image of a list.

  Results of all workers are appended to the shared class-level eval_log.
  """

  # Shared by all instances: one row per evaluated image.
  eval_log = pd.DataFrame(columns=["img_in", "img_gt", "deltaE2000", "MSE", "MAE", "deltaE76"])
  # Guards eval_log: reading len(df) and writing the new row is not atomic,
  # so concurrent workers could otherwise overwrite each other's rows.
  _log_lock = threading.Lock()

  def __init__(self, part, n_part, imgs):
    """
    :param part: index of this worker's share
    :param n_part: total number of shares the image list is split into
    :param imgs: image names (without extension); this worker handles the
      entries whose position i satisfies i % n_part == part
    """
    super().__init__()
    self.part = part
    self.n_part = n_part
    self.imgs = imgs

  def run(self):
    """Evaluates this worker's images and logs the running averages."""
    imgs = self.imgs
    df = EvalThread.eval_log
    for i in range(len(imgs)):
      if i % self.n_part != self.part:
        continue
      img_in = os.path.join(IN_DIR, imgs[i] + '.jpg')
      # ground truth: drop the last two '_'-separated fields, add '_G_AS.png'
      img_gt = os.path.join(GT_DIR, imgs[i].rsplit('_', maxsplit=2)[0] + '_G_AS.png')
      # the expensive part runs outside the lock
      deltaE00, MSE, MAE, deltaE76 = eval(imgs[i], img_in, img_gt)
      with EvalThread._log_lock:
        df.loc[len(df)] = {
            "img_in": img_in,
            "img_gt": img_gt,
            "deltaE2000": deltaE00,
            "MSE": MSE,
            "MAE": MAE,
            "deltaE76": deltaE76
        }
        print(f'[AVG {len(df)}/{len(imgs)}] DeltaE 2000: {df["deltaE2000"].mean():.2f}, MSE={df["MSE"].mean():.2f}, MAE={df["MAE"].mean():.2f}, DeltaE 76={df["deltaE76"].mean():.2f}')
      # if len(df) % 10 == 9:


In [None]:
# Image names of fold 1: each line of the fold file, cut at the first ".".
with open(os.path.join(FOLD_DIR, 'fold_1.txt'), 'r') as f:
    imgs = [line.strip().split(".")[0] for line in f]

# Split the first 100 images across n_part workers.
n_part = 8
threads = []
for part in range(n_part):
    threads.append(EvalThread(part, n_part, imgs[:100]))

# NOTE(review): each worker is joined right after it is started, so the
# workers run one after another rather than concurrently - confirm whether
# that is intended before changing it (eval() plots with pyplot).
for thread in threads:
    thread.start()
    thread.join()
