# Apply color correction for datasets

In [32]:
import cv2
import numpy as np
from skimage import color
import os
import numpy as np
import re
import torch
from torch import nn
from torchvision import transforms
from torchvision.io import read_image
from skimage.io import imread, imsave
import threading
import time
import pandas as pd
import glob
import shutil

In [3]:
CHECKPOINTS = './checkpoints'

# Simple Color Mapping Network 

In [2]:
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1 ("same" convolutions), so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def stage(channels_in):
            # One stage: "same" 3x3 convolution -> batch norm -> in-place ReLU.
            return [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Everything lives in one flat Sequential called ``conv`` so that the
        # parameter names (conv.0, conv.1, conv.3, conv.4) match saved checkpoints.
        self.conv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv(x)

class Kernel(nn.Module):
  """Expand the first three channels of a 4-D batch into 11 polynomial terms.

  For an input indexed as (batch, channel, height, width) the output has 11
  channels stacked on dim 1, in this order:
  r, g, b, r^2, g^2, b^2, r*g, g*b, r*b, r*g*b, 1.
  """

  def forward(self, x):
    red, green, blue = (x[:, channel, :, :] for channel in range(3))
    terms = (
      red, green, blue,
      red ** 2, green ** 2, blue ** 2,
      red * green, green * blue, red * blue,
      red * green * blue,
      torch.ones_like(red),  # constant (bias) term
    )
    return torch.stack(terms, dim=1)

class SCMN(nn.Module):
  """Network that predicts a 3 x 11 color mapping matrix for each input image.

  Five ConvBlock stages (24, 48, 96, 192, 384 channels), each followed by a
  2x2 max pooling, feed one linear layer that outputs 33 coefficients.
  """

  def __init__(self, in_channels=3):
    super().__init__()
    # Encoder stages; attribute names are part of the checkpoint format.
    self.Conv1 = ConvBlock(in_channels, 24)
    self.Conv2 = ConvBlock(24, 48)
    self.Conv3 = ConvBlock(48, 96)
    self.Conv4 = ConvBlock(96, 192)
    self.Conv5 = ConvBlock(192, 384)
    self.MaxPool = nn.MaxPool2d(2, 2)
    self.Flatten = nn.Flatten()
    # 24576 = 384 channels * 8 * 8, i.e. the flattened feature size for a
    # 256 x 256 input after five halvings.
    self.Linear = nn.Linear(24576, 33)
    # Kept as an attribute but not used by forward().
    self.Kernel = Kernel()

  def forward(self, x):
    """Return the mapping coefficients, shaped B x 3 x 11.

    Only the coefficients are produced here; expanding the image into its
    11 polynomial terms and multiplying by this matrix is not done in
    forward().
    """
    features = x
    for conv in (self.Conv1, self.Conv2, self.Conv3, self.Conv4, self.Conv5):
      features = self.MaxPool(conv(features))
    coefficients = self.Linear(self.Flatten(features))
    return coefficients.reshape(-1, 3, 11)  # B x 3 x 11

In [4]:
# Inputs are resized to 256 x 256 before being fed to the network.
mytransform = transforms.Compose([
    transforms.Resize((256, 256))
])

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_checkpoint_path = os.path.join(CHECKPOINTS, 'scmn_20220925_111650_150.pth')
net = SCMN()
# map_location remaps the stored tensors onto the chosen device, so a
# checkpoint saved from a GPU can also be loaded on a CPU-only machine.
net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
net.eval()  # inference only: batch norm layers use their running statistics
net = net.to(device)

def outOfGamutClipping(I):
    """Clamp every value of ``I`` into [0, 1].

    The array is modified in place and the same object is returned.
    """
    above_range = I > 1
    below_range = I < 0
    I[above_range] = 1  # anything brighter than 1 becomes 1
    I[below_range] = 0  # anything darker than 0 becomes 0
    return I

def get_mapping_matrix(img):
  """Predict the 3 x 11 color mapping matrix for one image.

  :param img: image array indexed as (height, width, channel)
  :return: numpy array of shape (3, 11) with the mapping coefficients
  """
  # Follow the model to whatever device it lives on instead of assuming CUDA.
  model_device = next(net.parameters()).device
  batch = torch.tensor(img, dtype=torch.float32)
  batch = batch.permute(2, 0, 1).unsqueeze(0)  # H x W x C -> 1 x C x H x W
  batch = batch.to(model_device)
  with torch.no_grad():
    out = net(mytransform(batch))  # 1 x 3 x 11
  # Drop only the batch axis; a bare squeeze() would remove any size-1 axis.
  mapping_matrix = out.squeeze(0).cpu().numpy()  # 3 x 11
  return mapping_matrix

def kernel(img):
  """Expand an image into its 11 polynomial color terms.

  :param img: array indexed as (height, width, channel); channels 0, 1, 2
      are used as r, g, b
  :return: array of shape 11 x height x width holding, in order,
      r, g, b, r^2, g^2, b^2, r*g, g*b, r*b, r*g*b, 1
  """
  r, g, b = (img[:, :, channel] for channel in range(3))
  terms = [
    r, g, b,
    r ** 2, g ** 2, b ** 2,
    r * g, g * b, r * b,
    r * g * b,
    np.ones_like(r),  # constant (bias) term
  ]
  return np.stack(terms, axis=0)

def infer(img):
  """Color-correct one image with the mapping matrix predicted for it.

  :param img: array of shape height x width x 3
  :return: corrected array of shape height x width x 3, clipped to [0, 1]
  """
  mapping = get_mapping_matrix(img)  # 3 x 11
  rows, cols, _ = img.shape
  features = kernel(img).reshape(11, -1)  # 11 x (height * width)
  corrected = (mapping @ features).reshape(3, rows, cols)  # 3 x height x width
  corrected = np.moveaxis(corrected, 0, -1)  # height x width x 3
  return outOfGamutClipping(corrected)

# Apply color correction on STL-10

In [15]:
from __future__ import print_function

import sys
import os, sys, tarfile, errno
import numpy as np
import matplotlib.pyplot as plt
    
# urllib was reorganised in Python 3; expose the request module under the
# single name `urllib` on either major version.
if sys.version_info >= (3, 0, 0):
    import urllib.request as urllib # ugly but works
else:
    import urllib

# Prefer imageio's imsave and fall back to the scipy.misc one when imageio is
# not installed. Only ImportError triggers the fallback, so unrelated failures
# are not hidden.
# NOTE: either import rebinds the `imsave` name imported from skimage.io
# earlier in this file.
try:
    from imageio import imsave
except ImportError:
    from scipy.misc import imsave

print(sys.version_info) 

# STL-10 image shape: 96 x 96 pixels, 3 color channels
HEIGHT = 96
WIDTH = 96
DEPTH = 3

# size of a single image in bytes (one uint8 value per pixel and channel)
SIZE = HEIGHT * WIDTH * DEPTH

# path to the directory with the data
# (not referenced by the code visible here; the two file paths below spell
# the directory out again)
DATA_DIR = './data'

# url of the binary data
# (no download code is visible in this section; presumably the archive is
# fetched and unpacked into DATA_DIR beforehand - TODO confirm)
DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

# path to the binary train file with image data
DATA_PATH = './data/stl10_binary/train_X.bin'

# path to the binary train file with labels
LABEL_PATH = './data/stl10_binary/train_y.bin'

def read_labels(path_to_labels):
    """
    Load the label file as a flat array, one uint8 value per label.

    :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
    :return: an array containing the labels
    """
    # np.fromfile opens the path itself and reads it to the end.
    return np.fromfile(path_to_labels, dtype=np.uint8)


def read_all_images(path_to_data):
    """
    Load every image stored in an STL-10 binary image file.

    :param path_to_data: the file containing the binary images from the STL-10 dataset
    :return: a uint8 array of shape (N, 96, 96, 3) containing all the images
    """
    # The file is one flat run of uint8 values. Each image takes 3 * 96 * 96 of
    # them: the first 96*96 values are the red channel, the next 96*96 green,
    # the last blue. The leading -1 lets numpy work out how many images the
    # file holds.
    raw = np.fromfile(path_to_data, dtype=np.uint8)
    planes = raw.reshape(-1, 3, 96, 96)

    # Swap the channel axis with the last spatial axis so each image comes out
    # channels-last, the layout matplotlib.imshow expects. Skip or undo this
    # swap if a channels-first layout is wanted (e.g. for a CNN).
    return np.swapaxes(planes, 1, 3)


def read_single_image(image_file):
    """
    Read the next image from an already opened STL-10 binary file.

    CAREFUL! - this takes an open file rather than a path, so the read
    position advances and stays advanced after the call returns.
    :param image_file: the open file containing the images
    :return: a single image
    """
    # count limits the read to the SIZE uint8 values of exactly one image
    flat = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
    planes = flat.reshape(3, 96, 96)
    # Reverse the axis order so the channel axis comes last (standard image
    # format). Skip or undo this for channels-first consumers such as a CNN.
    return np.swapaxes(planes, 0, 2)


def plot_image(image):
    """
    Display a single image with matplotlib.

    :param image: the image to be plotted in a 3-D matrix format
    :return: None
    """
    # draw the image on the current axes, then open/refresh the figure
    plt.imshow(image)
    plt.show()

def save_image(image, name):
    """
    Write ``image`` as a PNG file named ``<name>.png``.

    :param image: the image data handed to imsave
    :param name: output path without the ".png" extension
    :return: None
    """
    target = "%s.png" % name
    imsave(target, image, format="png")


def save_images(images, labels):
    """
    Color-correct every image and save it as ``./img/<label>/<index>.png``.

    :param images: sequence of images with 8-bit values (0..255)
    :param labels: labels indexable by image position; each label names the
        sub-directory its image is written to
    :return: None
    """
    print("Saving images to disk")
    total = len(images)
    for index, image in enumerate(images):
        directory = './img/' + str(labels[index]) + '/'
        # exist_ok already tolerates an existing directory; any other OSError
        # (permissions, read-only disk, ...) is a real failure and must surface
        # here instead of later as a confusing write error.
        os.makedirs(directory, exist_ok=True)
        filename = directory + str(index)
        # infer is fed values scaled to [0, 1]; scale back to 8-bit for saving
        corrected = (infer(image / 255) * 255).astype("uint8")
        save_image(corrected, filename)
        # report progress after every 10th image that has actually been saved
        done = index + 1
        if done % 10 == 0:
            print(f'[{done}/{total}]')

sys.version_info(major=3, minor=9, micro=7, releaselevel='final', serial=0)


In [16]:
# Load the whole STL-10 training set and print the array shapes as a quick
# sanity check that the binary files were read correctly
images = read_all_images(DATA_PATH)
print(images.shape)

labels = read_labels(LABEL_PATH)
print(labels.shape)

# color-correct every image and save it to disk under ./img/<label>/
save_images(images, labels)

(5000, 96, 96, 3)
(5000,)
Saving images to disk
[10/5000]
[20/5000]
[30/5000]
[40/5000]
[50/5000]
[60/5000]
[70/5000]
[80/5000]
[90/5000]
[100/5000]
[110/5000]
[120/5000]
[130/5000]
[140/5000]
[150/5000]
[160/5000]
[170/5000]
[180/5000]
[190/5000]
[200/5000]
[210/5000]
[220/5000]
[230/5000]
[240/5000]
[250/5000]
[260/5000]
[270/5000]
[280/5000]
[290/5000]
[300/5000]
[310/5000]
[320/5000]
[330/5000]
[340/5000]
[350/5000]
[360/5000]
[370/5000]
[380/5000]
[390/5000]
[400/5000]
[410/5000]
[420/5000]
[430/5000]
[440/5000]
[450/5000]
[460/5000]
[470/5000]
[480/5000]
[490/5000]
[500/5000]
[510/5000]
[520/5000]
[530/5000]
[540/5000]
[550/5000]
[560/5000]
[570/5000]
[580/5000]
[590/5000]
[600/5000]
[610/5000]
[620/5000]
[630/5000]
[640/5000]
[650/5000]
[660/5000]
[670/5000]
[680/5000]
[690/5000]
[700/5000]
[710/5000]
[720/5000]
[730/5000]
[740/5000]
[750/5000]
[760/5000]
[770/5000]
[780/5000]
[790/5000]
[800/5000]
[810/5000]
[820/5000]
[830/5000]
[840/5000]
[850/5000]
[860/5000]
[870/5000]
[880

# Apply color correction on PubFig83


In [39]:
# Raw strings keep the Windows backslashes literal; "\d" and "\p" are not
# valid escape sequences and trigger warnings in ordinary string literals.
PUBFIC83 = r'E:\datasets\pubfig83'
PUBFIC83_OUT = r'E:\datasets\pubfig83-corr'

# Start from an empty output tree. A missing directory simply means there is
# nothing to clean up; any other failure (e.g. a locked file) is raised here
# rather than resurfacing as a FileExistsError from makedirs below.
try:
    shutil.rmtree(PUBFIC83_OUT)
    print("Remove directory", PUBFIC83_OUT)
except FileNotFoundError:
    pass

os.makedirs(PUBFIC83_OUT)
print("Make directory ", PUBFIC83_OUT)

Remove directory E:\datasets\pubfig83-corr
Make directory  E:\datasets\pubfig83-corr


In [40]:
# Each sub-directory of the dataset root holds the images of one person.
# Stray files (thumbnail caches, readme files, ...) are skipped so they are not
# treated as persons, and the list is sorted for a reproducible order.
persons = sorted(
    entry for entry in os.listdir(PUBFIC83)
    if os.path.isdir(os.path.join(PUBFIC83, entry))
)
print("Number of persons: ", len(persons))

Number of persons:  83


In [41]:
# Mirror the per-person directory layout in the output tree.
for person in persons:
    person_dir = os.path.join(PUBFIC83_OUT, person)
    # exist_ok lets this cell be re-run without failing on directories that
    # an earlier run already created
    os.makedirs(person_dir, exist_ok=True)
    print("Make directory ", person_dir)

Make directory  E:\datasets\pubfig83-corr\Adam Sandler
Make directory  E:\datasets\pubfig83-corr\Alec Baldwin
Make directory  E:\datasets\pubfig83-corr\Angelina Jolie
Make directory  E:\datasets\pubfig83-corr\Anna Kournikova
Make directory  E:\datasets\pubfig83-corr\Ashton Kutcher
Make directory  E:\datasets\pubfig83-corr\Avril Lavigne
Make directory  E:\datasets\pubfig83-corr\Barack Obama
Make directory  E:\datasets\pubfig83-corr\Ben Affleck
Make directory  E:\datasets\pubfig83-corr\Beyonce Knowles
Make directory  E:\datasets\pubfig83-corr\Brad Pitt
Make directory  E:\datasets\pubfig83-corr\Cameron Diaz
Make directory  E:\datasets\pubfig83-corr\Cate Blanchett
Make directory  E:\datasets\pubfig83-corr\Charlize Theron
Make directory  E:\datasets\pubfig83-corr\Christina Ricci
Make directory  E:\datasets\pubfig83-corr\Claudia Schiffer
Make directory  E:\datasets\pubfig83-corr\Clive Owen
Make directory  E:\datasets\pubfig83-corr\Colin Farrell
Make directory  E:\datasets\pubfig83-corr\Colin

In [42]:
# Color-correct every image of every person and write it to the same relative
# path under the output tree.
total_persons = len(persons)  # derived from the listing instead of a fixed 83
for i, person in enumerate(persons):
    print(f"Person [{i+1}/{total_persons}]: {person}")
    source_dir = os.path.join(PUBFIC83, person)
    target_dir = os.path.join(PUBFIC83_OUT, person)
    imgs = os.listdir(source_dir)
    for j, img in enumerate(imgs):
        image = imread(os.path.join(source_dir, img))
        # scale the 8-bit values to [0, 1] for infer, then back to 8-bit
        image = infer(image / 255)
        image = (image * 255).astype("uint8")
        imsave(os.path.join(target_dir, img), image)
        # progress line after every 10th image
        if j % 10 == 9:
            print(f"  image [{j+1}/{len(imgs)}]")

print("Done")   

Person [1/83]: Adam Sandler
  image [10/108
  image [20/108
  image [30/108
  image [40/108
  image [50/108
  image [60/108
  image [70/108
  image [80/108
  image [90/108
  image [100/108
Person [2/83]: Alec Baldwin
  image [10/103
  image [20/103
  image [30/103
  image [40/103
  image [50/103
  image [60/103
  image [70/103
  image [80/103
  image [90/103
  image [100/103
Person [3/83]: Angelina Jolie
  image [10/214
  image [20/214
  image [30/214
  image [40/214
  image [50/214
  image [60/214
  image [70/214
  image [80/214
  image [90/214
  image [100/214
  image [110/214
  image [120/214
  image [130/214
  image [140/214
  image [150/214
  image [160/214
  image [170/214
  image [180/214
  image [190/214
  image [200/214
  image [210/214
Person [4/83]: Anna Kournikova
  image [10/171
  image [20/171
  image [30/171
  image [40/171
  image [50/171
  image [60/171
  image [70/171
  image [80/171
  image [90/171
  image [100/171
  image [110/171
  image [120/171
  image [130/171
 