# Comparing implemented methods
## Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
from sklearn.model_selection import KFold

from color_correction.errors import CIELABDE, CIEDE2000
from color_correction.regressions import LCC, PCC, RPCC
from color_correction.nn import MLP, MLPExtendedTrain, MLPExposureInvariant

## Loading [SFU's surface reflectance dataset](https://www2.cs.sfu.ca/~colour/data/colour_constancy_synthetic_test_data/)
Sorry for magic numbers. We need equal ranges of data that's why we skip some fields
sometimes.

In [2]:
sfu_spd = np.loadtxt("data/reflect_db.reflect", dtype=np.float64).reshape([-1,101])[:, 5:86:]
data_size = sfu_spd.shape[0]

interp_wavelengths = np.arange(400, 721)
interpolator = make_interp_spline(interp_wavelengths[::4], sfu_spd, k=1, axis=1)
sfu_spd = interpolator(interp_wavelengths).reshape(data_size, -1, 1)

sfu_spd.shape

(1995, 321, 1)

## Loading [Nikon D5100 spectral sensivity function](https://github.com/butcherg/ssf-data/blob/master/Nikon/D5100/camspec/Nikon_D5100.csv)

In [3]:
nikon_ssf = np.loadtxt("data/Nikon_D5100.csv", dtype=np.float64, delimiter=",", usecols=(1,2,3))

interpolator = make_interp_spline(interp_wavelengths[::10], nikon_ssf, k=1, axis=0)
nikon_ssf = interpolator(interp_wavelengths)

nikon_ssf.shape

(321, 3)

## Loading [CIE XYZ color matching function](https://cie.co.at/datatable/cie-1931-colour-matching-functions-2-degree-observer)

In [4]:
xyz_cmf = np.loadtxt("data/CIE_xyz_1931_2deg.csv", dtype=np.float64, delimiter=",", usecols=(1,2,3))[40:361]
xyz_cmf.shape

(321, 3)

## Loading [D65 spectral poower distribution](https://cie.co.at/datatable/cie-standard-illuminant-d65)

In [5]:
d65_spd = np.loadtxt("data/CIE_std_illum_D65.csv", dtype=np.float64, delimiter=",", usecols=(1))[100:421].reshape(-1, 1)
d65_spd.shape

(321, 1)

## Calculating RGB RAW and CIE XYZ coordinates

In [6]:
rgb = np.trapz(sfu_spd * nikon_ssf * d65_spd, dx=1, axis=1)
xyz = np.trapz(sfu_spd * xyz_cmf * d65_spd, dx=1, axis=1)
print(rgb.shape, xyz.shape)

(1995, 3) (1995, 3)


## Calculating white point

In [7]:
white_point = np.trapz(xyz_cmf * d65_spd, dx=1, axis=0)
white_point

array([10033.22810699, 10565.8810844 , 11469.06421026])

## Performance on fixed exposure

In [8]:
methods = {"linear": LCC(),
           "poly 2 deg.": PCC(degree=2, loss="mse"),
           "root poly 2 deg.": RPCC(degree=2, loss="mse"),
           "poly opt. 2 deg.": PCC(degree=2, loss="cielabde"),
           "root poly opt. 2 deg.": RPCC(degree=2, loss="cielabde"),
           "nn": MLP(),
           "nn aug.": MLPExtendedTrain(batch_size=256),
           "nn el": MLPExposureInvariant()}
    
exposures = [0.2, 0.5, 1, 2, 5]
def exposure_test(rgb, xyz, white_point, model, exposure):
    res = model.predict(rgb * exposure)
    return CIELABDE(res, xyz * exposure, white_point * exposure)

n_splits = 5
error_statistics_cielab = np.zeros((len(methods), 4), dtype=np.float64)
error_statistics_ciede2000 = np.zeros((len(methods), 4), dtype=np.float64)
error_statistics_exps = np.zeros((len(methods), len(exposures)), dtype=np.float64)

kf = KFold(n_splits=n_splits, random_state=0, shuffle=True)
for train_index, test_index in kf.split(rgb):
    rgb_train, rgb_test = rgb[train_index], rgb[test_index]
    xyz_train, xyz_test = xyz[train_index], xyz[test_index]

    batch_statistics_cielab = np.zeros_like(error_statistics_cielab)
    batch_statistics_ciede2000 = np.zeros_like(error_statistics_ciede2000)
    batch_statistics_exps = np.zeros_like(error_statistics_exps)
    for j, model in enumerate(methods.values()):
        model.fit(rgb_train, xyz_train, white_point)
        errors_cielab = CIELABDE(model.predict(rgb_test), xyz_test, white_point)
        errors_ciede2000 = CIEDE2000(model.predict(rgb_test), xyz_test, white_point)

        batch_statistics_cielab[j] = np.array([np.mean(errors_cielab),
                                               np.max(errors_cielab),
                                               np.median(errors_cielab),
                                               np.percentile(errors_cielab, 95)])
        batch_statistics_ciede2000[j] = np.array([np.mean(errors_ciede2000),
                                                  np.max(errors_ciede2000),
                                                  np.median(errors_ciede2000),
                                                  np.percentile(errors_ciede2000, 95)])

        for k, exposure in enumerate(exposures):
            errors_exps = exposure_test(rgb_test, xyz_test, white_point, model, exposure)
            batch_statistics_exps[j][k] = np.mean(errors_exps)

    error_statistics_cielab += batch_statistics_cielab
    error_statistics_ciede2000 += batch_statistics_ciede2000
    error_statistics_exps += batch_statistics_exps

error_statistics_cielab /= n_splits
error_statistics_ciede2000 /= n_splits
error_statistics_exps /= n_splits
    
df_cielab = pd.DataFrame({"Method": methods.keys(),
                          "Mean": error_statistics_cielab[:, 0],
                          "Max": error_statistics_cielab[:, 1],
                          "Median": error_statistics_cielab[:, 2],
                          "95%": error_statistics_cielab[:, 3]}).style.hide()
df_ciede2000 = pd.DataFrame({"Method": methods.keys(),
                             "Mean": error_statistics_ciede2000[:, 0],
                             "Max": error_statistics_ciede2000[:, 1],
                             "Median": error_statistics_ciede2000[:, 2],
                             "95%": error_statistics_ciede2000[:, 3]}).style.hide()
df_exps = pd.DataFrame(columns=exposures,
                       data=error_statistics_exps)
df_exps.insert(0, "Methods", methods.keys())
df_exps = df_exps.style.hide()

### CIELAB Delta E

In [9]:
df_cielab

Method,Mean,Max,Median,95%
linear,1.589252,17.758626,0.889065,5.043187
poly 2 deg.,1.263581,13.12455,0.753671,3.683078
root poly 2 deg.,1.160634,15.185094,0.675933,3.683861
poly opt. 2 deg.,1.144871,7.28597,0.766822,3.400068
root poly opt. 2 deg.,1.086645,7.13743,0.68771,3.396306
nn,1.08989,11.127519,0.692959,3.323537
nn aug.,0.897829,7.268493,0.526111,2.920193
nn el,1.446469,13.740053,0.893253,4.55344


### CIE Delta E 2000

In [10]:
df_ciede2000

Method,Mean,Max,Median,95%
linear,0.905658,6.857275,0.651576,2.522678
poly 2 deg.,0.752311,4.148586,0.54309,2.092545
root poly 2 deg.,0.693781,6.433788,0.494083,2.031283
poly opt. 2 deg.,0.721213,3.618652,0.533472,1.907736
root poly opt. 2 deg.,0.660712,3.543607,0.487383,1.869125
nn,0.69429,3.854647,0.512192,1.809795
nn aug.,0.550765,3.374957,0.371317,1.542285
nn el,0.853429,4.237741,0.644838,2.274702


### Different exposures

In [11]:
df_exps

Methods,0.200000,0.500000,1.000000,2.000000,5.000000
linear,1.589252,1.589252,1.589252,1.589252,1.589252
poly 2 deg.,1.611825,1.425788,1.263581,1.841396,6.078506
root poly 2 deg.,1.160634,1.160634,1.160634,1.160634,1.160634
poly opt. 2 deg.,1.42573,1.279574,1.144871,1.511439,4.173929
root poly opt. 2 deg.,1.086645,1.086645,1.086645,1.086645,1.086645
nn,1.181151,1.095985,1.08989,1.092212,1.094958
nn aug.,0.878808,0.892346,0.897829,0.900925,0.902891
nn el,1.446469,1.446469,1.446469,1.446469,1.446469
