# Correct color of the STL-10 dataset

In [1]:
import cv2
import numpy as np
from skimage import color
import os
import numpy as np
import re
import torch
from torch import nn
from torchvision import transforms
from sklearn.linear_model import LinearRegression
from torchvision.io import read_image
from skimage.io import imread
import threading
import time
import pandas as pd

# Evaluation Functions

In [2]:
def calc_deltaE(source, target, color_chart_area):
  """Mean delta E 76 (Euclidean distance in CIELAB) between two images.

  :param source: first image; it is passed through cv2.COLOR_BGR2RGB before
    the Lab conversion.
  :param target: second image, converted the same way as source.
  :param color_chart_area: number of pixels removed from the averaging
    denominator (masked-out color chart).
  :return: sum of the per-pixel distances divided by
    (number of pixels - color_chart_area).
  """
  source = color.rgb2lab(cv2.cvtColor(source, cv2.COLOR_BGR2RGB))
  target = color.rgb2lab(cv2.cvtColor(target, cv2.COLOR_BGR2RGB))
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  delta_e = np.sqrt(np.sum(np.power(source - target, 2), 1))
  # np.sum reduces in native code; the builtin sum() iterated over the
  # per-pixel array one element at a time in the interpreter.
  return np.sum(delta_e) / (np.shape(delta_e)[0] - color_chart_area)

def calc_deltaE2000(source, target, color_chart_area):
  """Mean delta E 2000 between two images.

  :param source: first image; it is passed through cv2.COLOR_BGR2RGB before
    the Lab conversion.
  :param target: second image, converted the same way as source.
  :param color_chart_area: number of pixels removed from the averaging
    denominator (masked-out color chart).
  :return: sum of the per-pixel deltaE2000 values divided by
    (number of pixels - color_chart_area).
  """
  source = color.rgb2lab(cv2.cvtColor(source, cv2.COLOR_BGR2RGB))
  target = color.rgb2lab(cv2.cvtColor(target, cv2.COLOR_BGR2RGB))
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  deltaE00 = deltaE2000(source, target)
  # np.sum reduces in native code; the builtin sum() iterated over the
  # per-pixel array one element at a time in the interpreter.
  return np.sum(deltaE00) / (np.shape(deltaE00)[0] - color_chart_area)


def deltaE2000(Labstd, Labsample):
  """Per-row CIEDE2000 colour difference between two sets of Lab colours.

  :param Labstd: 2-D array whose columns 0, 1, 2 are L, a, b of the
    reference colours.
  :param Labsample: array of the same layout holding the sample colours.
  :return: 1-D array with one delta E 2000 value per row.
  """
  # parametric weighting factors, all left at 1
  kl = 1
  kc = 1
  kh = 1
  # 25**7 = 6103515625 does not fit in a 32-bit integer, so the constant is
  # computed in floating point instead of with integer np.power(25, 7).
  pow25_7 = np.power(25.0, 7)

  Lstd, astd, bstd = Labstd[:, 0], Labstd[:, 1], Labstd[:, 2]
  Cabstd = np.sqrt(np.power(astd, 2) + np.power(bstd, 2))
  Lsample, asample, bsample = Labsample[:, 0], Labsample[:, 1], Labsample[:, 2]
  Cabsample = np.sqrt(np.power(asample, 2) + np.power(bsample, 2))

  # rescale the a axis with G, derived from the mean chroma of each pair
  Cabarithmean = (Cabstd + Cabsample) / 2
  Cab7 = np.power(Cabarithmean, 7)
  G = 0.5 * (1 - np.sqrt(Cab7 / (Cab7 + pow25_7)))
  apstd = (1 + G) * astd
  apsample = (1 + G) * asample
  Cpsample = np.sqrt(np.power(apsample, 2) + np.power(bsample, 2))
  Cpstd = np.sqrt(np.power(apstd, 2) + np.power(bstd, 2))
  Cpprod = (Cpsample * Cpstd)
  # pairs in which at least one colour has zero chroma
  zero_chroma = Cpprod == 0

  # hue angles; exactly neutral colours get hue 0
  hpstd = np.arctan2(bstd, apstd)
  hpstd[(np.abs(apstd) + np.abs(bstd)) == 0] = 0
  hpsample = np.arctan2(bsample, apsample)
  hpsample = hpsample + 2 * np.pi * (hpsample < 0)
  hpsample[(np.abs(apsample) + np.abs(bsample)) == 0] = 0

  dL = (Lsample - Lstd)
  dC = (Cpsample - Cpstd)
  # hue difference wrapped back into [-pi, pi]
  dhp = (hpsample - hpstd)
  dhp = dhp - 2 * np.pi * (dhp > np.pi)
  dhp = dhp + 2 * np.pi * (dhp < (-np.pi))
  dhp[zero_chroma] = 0
  dH = 2 * np.sqrt(Cpprod) * np.sin(dhp / 2)

  # pairwise means used by the weighting functions
  Lp = (Lsample + Lstd) / 2
  Cp = (Cpstd + Cpsample) / 2
  hp = (hpstd + hpsample) / 2
  hp = hp - (np.abs(hpstd - hpsample) > np.pi) * np.pi
  hp = hp + (hp < 0) * 2 * np.pi
  hp[zero_chroma] = hpsample[zero_chroma] + hpstd[zero_chroma]

  Lpm502 = np.power((Lp - 50), 2)
  Sl = 1 + 0.015 * Lpm502 / np.sqrt(20 + Lpm502)
  Sc = 1 + 0.045 * Cp
  T = 1 - 0.17 * np.cos(hp - np.pi / 6) + 0.24 * np.cos(2 * hp) + \
      0.32 * np.cos(3 * hp + np.pi / 30) \
      - 0.20 * np.cos(4 * hp - 63 * np.pi / 180)
  Sh = 1 + 0.015 * Cp * T
  # rotation term coupling the chroma and hue differences
  delthetarad = (30 * np.pi / 180) * np.exp(
    - np.power((180 / np.pi * hp - 275) / 25, 2))
  Cp7 = np.power(Cp, 7)
  Rc = 2 * np.sqrt(Cp7 / (Cp7 + pow25_7))
  RT = - np.sin(2 * delthetarad) * Rc

  klSl = kl * Sl
  kcSc = kc * Sc
  khSh = kh * Sh
  de00 = np.sqrt(np.power((dL / klSl), 2) + np.power((dC / kcSc), 2) +
                 np.power((dH / khSh), 2) + RT * (dC / kcSc) * (dH / khSh))
  return de00


def calc_mae(source, target, color_chart_area):
  """Mean angular error, in degrees, between the pixel vectors of two images.

  :param source: first image; flattened to rows of 3 values.
  :param target: second image, flattened the same way.
  :param color_chart_area: number of pixels removed from the averaging
    denominator (masked-out color chart).
  :return: sum of the per-pixel angles divided by
    (number of pixels - color_chart_area). Pixels where either vector has
    zero norm contribute no angle but are still counted in the denominator.
  """
  source = np.reshape(source, [-1, 3]).astype(np.float32)
  target = np.reshape(target, [-1, 3]).astype(np.float32)
  source_norm = np.sqrt(np.sum(np.power(source, 2), 1))
  target_norm = np.sqrt(np.sum(np.power(target, 2), 1))
  norm = source_norm * target_norm
  pixel_count = np.shape(norm)[0]
  valid = norm != 0
  cosines = np.sum(source[valid, :] * target[valid, :], 1) / norm[valid]
  # Rounding can push a cosine just outside [-1, 1]. Clamp both ends so that
  # antiparallel vectors give 180 degrees instead of NaN (previously only the
  # upper end was clamped and the NaN was then zeroed below).
  cosines = np.clip(cosines, -1, 1)
  angles = np.arccos(cosines)
  angles[np.isnan(angles)] = 0  # NaN inputs still count as zero error
  angles = angles * 180 / np.pi
  return np.sum(angles) / (pixel_count - color_chart_area)



def calc_mse(source, target, color_chart_area):
  """Mean squared error over all values of two images.

  :param source: first image; flattened to a single column of float64.
  :param target: second image, flattened the same way.
  :param color_chart_area: value subtracted from the element count in the
    denominator.
  :return: array of shape (1,) holding the squared-error sum divided by
    (number of elements - color_chart_area).
  """
  source = np.reshape(source, [-1, 1]).astype(np.float64)
  target = np.reshape(target, [-1, 1]).astype(np.float64)
  # Reducing along axis 0 keeps the (1,)-shaped result that the builtin sum()
  # over the rows of the column vector produced, without a Python-level loop.
  squared_error = np.sum(np.power(source - target, 2), axis=0)
  # NOTE(review): the denominator counts individual channel values while the
  # other metrics subtract color_chart_area from a pixel count - confirm the
  # intended unit of color_chart_area here.
  return squared_error / (np.shape(source)[0] - color_chart_area)


def evaluate_cc(corrected, gt, color_chart_area, opt=1):
  """
    Evaluates a corrected image against its ground truth with one or more
    color-constancy (white-balance correction) error metrics.
    :param corrected: corrected image
    :param gt: ground-truth image
    :param color_chart_area: number of pixels of the color chart, when a
     chart is present and has been masked out of both images.
    :param opt: selects which error metric(s) are reported.
         Options:
           opt = 1 delta E 2000 (default).
           opt = 2 delta E 2000 and mean squared error (MSE)
           opt = 3 delta E 2000, MSE, and mean angular error (MAE)
           opt = 4 delta E 2000, MSE, MAE, and delta E 76
    :return: a single value for opt = 1, otherwise a tuple of the selected
     errors in the order listed above
    """

  # reject unknown options before any metric is computed
  if opt not in (1, 2, 3, 4):
    raise Exception('Error in evaluate_cc function')

  de2000 = calc_deltaE2000(corrected, gt, color_chart_area)
  if opt == 1:
    return de2000

  mse = calc_mse(corrected, gt, color_chart_area)
  if opt == 2:
    return de2000, mse

  mae = calc_mae(corrected, gt, color_chart_area)
  if opt == 3:
    return de2000, mse, mae

  de76 = calc_deltaE(corrected, gt, color_chart_area)
  return de2000, mse, mae, de76




def get_metadata(fileName, set, metadata_baseDir=''):
    """
    Gets metadata (e.g., ground-truth file name, chart coordinates and area).
    :param fileName: input filename; directory and extension are stripped to
      obtain the base name used below
    :param set: which dataset?--options are 'RenderedWB_Set1',
      'RenderedWB_Set2', 'Rendered_Cube+'
    :param metadata_baseDir: directory holding the '<name>_color.txt' and
      '<name>_mask.txt' files (read for Set1 only)
    :return: dict with keys "gt_filename", "cc_colors", "cc_mask" and
      "cc_mask_area"
    :raises ValueError: if set is not one of the three supported names
    """

    fname, file_extension = os.path.splitext(fileName)  # get file parts
    name = os.path.basename(fname)  # get only filename without the directory

    if set == 'RenderedWB_Set1':  # Rendered WB dataset (Set1)
        color_path = os.path.join(metadata_baseDir, name + '_color.txt')
        mask_path = os.path.join(metadata_baseDir, name + '_mask.txt')

        # Chart colours: comma/newline separated values. The last token is
        # dropped before reshaping into 24 rows of 3, then transposed to 3 x 24.
        with open(color_path, 'r') as f:
            tokens = re.split(',|\n', f.read())
        # np.asarray(..., dtype=float) replaces np.asfarray, which no longer
        # exists in NumPy 2.0.
        colors = np.reshape(np.asarray(tokens[:-1], dtype=float),
                            (24, 3)).transpose()

        # Chart coordinates: only the first 4 values of the mask file are used.
        with open(mask_path, 'r') as f:
            tokens = re.split(',|\n', f.read())
        mask = np.asarray(tokens[0:4], dtype=float)

        # Ground-truth name: drop the last two '_'-separated fields of the
        # input name and append the fixed '_G_AS.png' suffix.
        seperator = '_'
        gt_file = seperator.join(name.split(seperator)[:-2]) + '_G_AS.png'

        # mask area is the product of the 3rd and 4th mask values
        mask_area = mask[2] * mask[3]
        data = {"gt_filename": gt_file, "cc_colors": colors, "cc_mask": mask,
                "cc_mask_area": mask_area}

    elif set == 'RenderedWB_Set2':  # Rendered WB dataset (Set2)
        # ground truth shares the input's base name; no chart information
        data = {"gt_filename": name + file_extension, "cc_colors": None,
                "cc_mask": None, "cc_mask_area": 0}

    elif set == 'Rendered_Cube+':  # Rendered Cube+
        # ground truth is named after the part before the first '_'
        gt_file = name.split('_')[0] + file_extension
        mask_area = 58373  # calibration obj's area is fixed over all images
        data = {"gt_filename": gt_file, "cc_colors": None, "cc_mask": None,
                "cc_mask_area": mask_area}
    else:
        raise ValueError(
            "Invalid value for set variable. " +
            "Please use: 'RenderedWB_Set1', 'RenderedWB_Set2', 'Rendered_Cube+'")

    return data


# Prepare data

In [5]:
# Directory holding trained model checkpoints; it is joined with the
# checkpoint file name when the network weights are loaded.
CHECKPOINTS = './checkpoints/'
# NOTE(review): the four paths below are not referenced anywhere in the
# visible part of this notebook. Judging by their names they point at the
# Set1 input images, ground-truth images, fold definitions and per-image
# metadata on a Colab runtime - confirm before relying on them.
IN_DIR = '/content/Set1_input_images_wo_CC_JPG'
GT_DIR = '/content/Set1_ground_truth_images_wo_CC'
FOLD_DIR = '/content/drive/MyDrive/CC/datasets/Set1_folds'
METADATA_DIR = '/content/Set1_input_images_metadata'

# Attention U-Net model

In [3]:
class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels produced by both stages
        """
        super().__init__()

        def stage(stage_in):
            # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged
            return [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # layers keep positions 0..5 inside the Sequential
        self.conv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x):
        return self.conv(x)


class UpConv(nn.Module):
    """Upsamples by a factor of 2, then applies 3x3 conv -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels after the convolution
        """
        super().__init__()

        layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        return self.up(x)


class AttentionBlock(nn.Module):
    """Attention gate that rescales a skip connection with learned weights."""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of channels of the gating signal (previous layer)
        :param F_l: number of channels of the encoder feature map that
            arrives via the skip connection
        :param n_coefficients: number of channels both inputs are projected
            to before they are combined
        """
        super().__init__()

        def project(channels_in, channels_out):
            # 1x1 convolution followed by batch norm
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            ]

        self.W_gate = nn.Sequential(*project(F_g, n_coefficients))
        self.W_x = nn.Sequential(*project(F_l, n_coefficients))
        # collapses the combined projection to one sigmoid-activated channel
        self.psi = nn.Sequential(*project(n_coefficients, 1), nn.Sigmoid())

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal from the previous layer
        :param skip_connection: activation from the matching encoder layer
        :return: skip_connection multiplied by the attention coefficients
        """
        combined = self.relu(self.W_gate(gate) + self.W_x(skip_connection))
        coefficients = self.psi(combined)
        return skip_connection * coefficients


class AttentionUNet(nn.Module):
    """U-Net with five encoder levels and attention-gated skip connections."""

    def __init__(self, img_ch=3, output_ch=3):
        """
        :param img_ch: number of channels of the input image
        :param output_ch: number of channels of the network output
        """
        super().__init__()

        # channel widths of the five encoder levels
        c1, c2, c3, c4, c5 = 24, 48, 96, 192, 384

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder
        self.Conv1 = ConvBlock(img_ch, c1)
        self.Conv2 = ConvBlock(c1, c2)
        self.Conv3 = ConvBlock(c2, c3)
        self.Conv4 = ConvBlock(c3, c4)
        self.Conv5 = ConvBlock(c4, c5)

        # decoder: each level upsamples, gates the skip connection and fuses
        # the concatenation (twice the level width) back to the level width
        self.Up5 = UpConv(c5, c4)
        self.Att5 = AttentionBlock(F_g=c4, F_l=c4, n_coefficients=c3)
        self.UpConv5 = ConvBlock(2 * c4, c4)

        self.Up4 = UpConv(c4, c3)
        self.Att4 = AttentionBlock(F_g=c3, F_l=c3, n_coefficients=c2)
        self.UpConv4 = ConvBlock(2 * c3, c3)

        self.Up3 = UpConv(c3, c2)
        self.Att3 = AttentionBlock(F_g=c2, F_l=c2, n_coefficients=c1)
        self.UpConv3 = ConvBlock(2 * c2, c2)

        self.Up2 = UpConv(c2, c1)
        self.Att2 = AttentionBlock(F_g=c1, F_l=c1, n_coefficients=c1 // 2)
        self.UpConv2 = ConvBlock(2 * c1, c1)

        # 1x1 convolution producing the output channels
        self.Conv = nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        enc* : encoder activations
        up*  : upsampled decoder activations (also used as gating signals)
        dec* : decoder activations after fusing with the gated skip connection
        """
        # encoder path: pool between consecutive levels
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.MaxPool(enc1))
        enc3 = self.Conv3(self.MaxPool(enc2))
        enc4 = self.Conv4(self.MaxPool(enc3))
        enc5 = self.Conv5(self.MaxPool(enc4))

        # decoder path: the attention-weighted skip connection is
        # concatenated in front of the upsampled features
        up5 = self.Up5(enc5)
        skip4 = self.Att5(gate=up5, skip_connection=enc4)
        dec5 = self.UpConv5(torch.cat((skip4, up5), dim=1))

        up4 = self.Up4(dec5)
        skip3 = self.Att4(gate=up4, skip_connection=enc3)
        dec4 = self.UpConv4(torch.cat((skip3, up4), dim=1))

        up3 = self.Up3(dec4)
        skip2 = self.Att3(gate=up3, skip_connection=enc2)
        dec3 = self.UpConv3(torch.cat((skip2, up3), dim=1))

        up2 = self.Up2(dec3)
        skip1 = self.Att2(gate=up2, skip_connection=enc1)
        dec2 = self.UpConv2(torch.cat((skip1, up2), dim=1))

        return self.Conv(dec2)


# Inference model

## Post process

In [4]:
def kernelP(I):
    """ Polynomial kernel: (r, g, b) -> (r,g,b,rg,rb,gb,r^2,g^2,b^2,rgb,1)
        Ref: Hong, et al., "A study of digital camera colorimetric characterization
         based on polynomial modeling." Color Research & Application, 2001.
        Takes an array with one colour per row and returns one row of 11
        terms per input row. """
    r, g, b = I[:, 0], I[:, 1], I[:, 2]
    constant = np.repeat(1, np.shape(I)[0])
    terms = (r, g, b,
             r * g, r * b, g * b,
             r * r, g * g, b * b,
             r * g * b,
             constant)
    # the terms are stacked as rows, so transpose to get one row per colour
    return np.transpose(terms)


def get_mapping_func(image1, image2):
    """ Fits the polynomial mapping from the colours of image1 to the
        colours of image2 and returns the fitted LinearRegression. """
    source_pixels = np.reshape(image1, [-1, 3])
    target_pixels = np.reshape(image2, [-1, 3])
    regressor = LinearRegression()
    regressor.fit(kernelP(source_pixels), target_pixels)
    return regressor


def apply_mapping_func(image, m):
    """ Runs every pixel of image through the fitted polynomial mapping m
        and returns the result with the image's first three dimensions. """
    original_shape = image.shape
    features = kernelP(np.reshape(image, [-1, 3]))
    mapped = m.predict(features)
    return np.reshape(
        mapped, [original_shape[0], original_shape[1], original_shape[2]])

def outOfGamutClipping(I):
    """ Clips out-of-gamut values to [0, 1] in place and returns the same array. """
    np.putmask(I, I > 1, 1)  # values above 1 become 1
    np.putmask(I, I < 0, 0)  # values below 0 become 0
    return I

## DL model

In [6]:
# Every image is resized to 256 x 256 before it is fed to the network.
mytransform = transforms.Compose([
    transforms.Resize((256, 256))
])

# Use the GPU when one is available; otherwise run on the CPU instead of
# failing on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_checkpoint_path = os.path.join(CHECKPOINTS, 'a-unet_aug_20220918_104239_done.pth')
net = AttentionUNet()
# map_location lets a checkpoint that was saved from GPU tensors be loaded
# on a CPU-only host as well.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
net.eval()  # inference mode: batch norm uses its stored running statistics
net = net.to(device)


def infer(image):
    """Colour-corrects one image.

    image: float np array RGB image, channels last.
    The network runs on a resized copy; a polynomial colour mapping is then
    fitted from that resized input to the network output and applied to the
    full-size image, and the result is clipped to [0, 1].
    """
    with torch.no_grad():
        channels_first = torch.from_numpy(image).permute((2, 0, 1))
        resized = mytransform(channels_first)
        # channels-last numpy copy of the resized input, used to fit the mapping
        image1 = resized.permute((1, 2, 0)).numpy()
        batch = resized.to(torch.float32).unsqueeze(0)
        prediction = net(batch.to(device))[0]
    prediction = prediction.permute((1, 2, 0)).cpu().numpy()
    mapping_func = get_mapping_func(image1, prediction)
    corrected = apply_mapping_func(image, mapping_func)
    return outOfGamutClipping(corrected)

# Correct images

In [15]:
from __future__ import print_function

import sys
import os, sys, tarfile, errno
import numpy as np
import matplotlib.pyplot as plt
    
if sys.version_info >= (3, 0, 0):
    import urllib.request as urllib # ugly but works
else:
    import urllib

# Prefer imageio's imsave and fall back to the scipy.misc one (only present
# in old SciPy releases). Only a failed import triggers the fallback; any
# other error is no longer swallowed by a bare except.
try:
    from imageio import imsave
except ImportError:
    from scipy.misc import imsave

print(sys.version_info)

# image shape
HEIGHT = 96
WIDTH = 96
DEPTH = 3

# size of a single image in bytes
SIZE = HEIGHT * WIDTH * DEPTH

# path to the directory with the data
DATA_DIR = './data'

# url of the binary data
DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

# path to the binary train file with image data
DATA_PATH = './data/stl10_binary/train_X.bin'

# path to the binary train file with labels
LABEL_PATH = './data/stl10_binary/train_y.bin'

def read_labels(path_to_labels):
    """
    Reads the whole label file as unsigned bytes.
    :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
    :return: a 1-D uint8 array with one entry per byte of the file
    """
    with open(path_to_labels, 'rb') as label_file:
        label_array = np.fromfile(label_file, dtype=np.uint8)
    return label_array


def read_all_images(path_to_data):
    """
    Reads every image stored in a binary STL-10 image file.
    :param path_to_data: the file containing the binary images from the STL-10 dataset
    :return: a uint8 array of shape (number of images, 96, 96, 3)
    """
    with open(path_to_data, 'rb') as data_file:
        # the whole file is read as raw unsigned bytes
        raw = np.fromfile(data_file, dtype=np.uint8)

    # Each image is stored as 3 x 96 x 96 values: the first 96*96 values are
    # the red channel, the next 96*96 are green and the last are blue. The
    # leading -1 lets numpy work out how many images the file holds.
    planar = raw.reshape(-1, 3, 96, 96)

    # Reverse the last three axes so the channel axis comes last, the layout
    # expected by e.g. matplotlib.imshow. Skip or undo this if a CNN should
    # receive the channels as separate planes.
    return planar.transpose(0, 3, 2, 1)


def read_single_image(image_file):
    """
    Reads the next image from an already opened STL-10 image file.
    CAREFUL! - the argument is an open file, not a path, so the read position
    keeps advancing across calls and stays changed after this method returns.
    :param image_file: the open file containing the images
    :return: a single image with the channel axis last
    """
    # count limits the read to the SIZE bytes that make up one image
    flat = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
    # channel planes first, then reverse the axes to the standard image
    # layout; skip or undo the transpose if a CNN should receive the
    # channels as separate planes
    return flat.reshape(3, 96, 96).transpose(2, 1, 0)


def plot_image(image):
    """
    Displays one image with matplotlib (imshow followed by show).
    :param image: the image to be plotted in a 3-D matrix format
    :return: None
    """
    plt.imshow(image)
    plt.show()

def save_image(image, name):
    """
    Writes an image to '<name>.png' in PNG format.
    :param image: the image to be saved
    :param name: output path without the '.png' extension
    :return: None
    """
    target_path = "%s.png" % name
    imsave(target_path, image, format="png")


def save_images(images, labels):
    """
    Colour-corrects every image and saves it as ./img/<label>/<index>.png.
    :param images: sequence of images with values in the 0-255 range; each
        one is scaled to [0, 1] before it is passed to infer
    :param labels: sequence indexed in parallel with images; the label is
        used as the sub-directory name
    :return: None
    """
    print("Saving images to disk")
    total = len(images)
    for index, image in enumerate(images):
        directory = './img/' + str(labels[index]) + '/'
        # exist_ok=True already tolerates an existing directory; any other
        # OSError now propagates instead of being silently ignored.
        os.makedirs(directory, exist_ok=True)
        filename = directory + str(index)
        # correct in [0, 1], then scale back to 8-bit for saving
        corrected = (infer(image / 255) * 255).astype("uint8")
        save_image(corrected, filename)
        # report progress once every 10 images have actually been saved
        done = index + 1
        if done % 10 == 0:
            print(f'[{done}/{total}]')

sys.version_info(major=3, minor=9, micro=7, releaselevel='final', serial=0)


In [16]:
# test to check if the whole dataset is read correctly
images = read_all_images(DATA_PATH)
print(images.shape)

labels = read_labels(LABEL_PATH)
print(labels.shape)

# save images to disk
save_images(images, labels)

(5000, 96, 96, 3)
(5000,)
Saving images to disk
[10/5000]
[20/5000]
[30/5000]
[40/5000]
[50/5000]
[60/5000]
[70/5000]
[80/5000]
[90/5000]
[100/5000]
[110/5000]
[120/5000]
[130/5000]
[140/5000]
[150/5000]
[160/5000]
[170/5000]
[180/5000]
[190/5000]
[200/5000]
[210/5000]
[220/5000]
[230/5000]
[240/5000]
[250/5000]
[260/5000]
[270/5000]
[280/5000]
[290/5000]
[300/5000]
[310/5000]
[320/5000]
[330/5000]
[340/5000]
[350/5000]
[360/5000]
[370/5000]
[380/5000]
[390/5000]
[400/5000]
[410/5000]
[420/5000]
[430/5000]
[440/5000]
[450/5000]
[460/5000]
[470/5000]
[480/5000]
[490/5000]
[500/5000]
[510/5000]
[520/5000]
[530/5000]
[540/5000]
[550/5000]
[560/5000]
[570/5000]
[580/5000]
[590/5000]
[600/5000]
[610/5000]
[620/5000]
[630/5000]
[640/5000]
[650/5000]
[660/5000]
[670/5000]
[680/5000]
[690/5000]
[700/5000]
[710/5000]
[720/5000]
[730/5000]
[740/5000]
[750/5000]
[760/5000]
[770/5000]
[780/5000]
[790/5000]
[800/5000]
[810/5000]
[820/5000]
[830/5000]
[840/5000]
[850/5000]
[860/5000]
[870/5000]
[880