# Bianco's CNN color constancy model recreated

#### CS 7180 - Advanced Perception
#### Di Zhang & Prakriti Pritmani
#### Oct 14, 2023

## imported packages

In [2]:
import numpy as np
import cv2
import scipy.io
import progressbar as pb

## Helper Functions

In [3]:
#define progress timer class
class progress_timer:
    """Console progress bar that keeps its own step counter.

    Wraps a ``progressbar.ProgressBar`` (imported in this file as ``pb``) and
    tracks how many steps have been reported so callers only need to call
    ``update()`` once per finished unit of work.

    Attributes are set in ``__init__``:
        n_iter: total number of steps, used as the bar's ``maxval``.
        iter: number of steps reported so far.
        description: label shown in front of the bar (``': '`` is appended).
        timer: the underlying ``pb.ProgressBar`` instance.
    """

    def __init__(self, n_iter, description="Something"):
        self.n_iter         = n_iter
        self.iter           = 0
        self.description    = description + ': '
        self.timer          = None
        self.initialize()

    def initialize(self):
        """Build the widget list and start the underlying progress bar."""
        widgets = [self.description, pb.Percentage(), ' ',
                   pb.Bar('=', '[', ']'), ' ', pb.ETA()]
        self.timer = pb.ProgressBar(widgets=widgets, maxval=self.n_iter).start()

    def update(self, q=1):
        """Record ``q`` finished steps and redraw the bar.

        The counter is advanced *before* the bar is redrawn so the display
        reflects the work that has just completed rather than lagging one
        step behind. The value handed to the bar is capped at ``n_iter`` so an
        extra call cannot push it past the bar's ``maxval``.
        """
        self.iter += q
        self.timer.update(min(self.iter, self.n_iter))

    def finish(self):
        """Mark the bar as complete."""
        self.timer.finish()
        

## Data Preprocessing

In [None]:
def generate_train_data(train_size, set_name, patch_size, image_dir=None, gt_dir=None):
    """Sample contrast-normalised RGB patches and their illuminant labels.

    For every ``*.png`` in the dataset's training directory the same number of
    random patches is cut out, contrast-normalised with CLAHE on the L channel
    of Lab, converted to RGB and paired with the image's ground-truth
    illuminant. Patches and labels are then scaled and shuffled together.

    Args:
        train_size: total number of patches wanted; each image contributes
            ``train_size // number_of_images`` patches.
        set_name: ``'Shi-Gehler'`` or ``'Canon'``.
        patch_size: ``(rows, cols)`` of a patch in pixels.
        image_dir: optional directory holding the training PNGs. Defaults to
            the dataset's built-in location.
        gt_dir: optional directory holding the ground-truth ``.mat`` file.
            Defaults to ``'GT_Illum_Mat'``.

    Returns:
        ``(X_train_origin, Y_train_origin, name_train)`` where
        ``X_train_origin`` holds the patches divided by 255,
        ``Y_train_origin`` holds each illuminant divided by its largest
        component, and ``name_train`` is the list of zero-padded ground-truth
        indices, one per image. ``X`` and ``Y`` are shuffled with the same
        permutation; ``name_train`` is per image and is NOT shuffled.

    Raises:
        ValueError: unknown ``set_name``, or an image smaller than one patch.
        FileNotFoundError: no PNG files found in the image directory.
        OSError: an image file could not be decoded.
    """
    # Imported locally so the function does not depend on names the
    # notebook's import cell does not provide.
    import os
    from glob import glob
    from random import randint

    # set name -> (ground-truth .mat file, key inside it, default image dir)
    datasets = {
        'Shi-Gehler': ('real_illum_568.mat', 'real_rgb',
                       'C:\\Users\\phamh\\Workspace\\Dataset\\Shi_Gehler\\Train_set\\'),
        'Canon': ('Canon600D_gt.mat', 'groundtruth_illuminants',
                  'C:\\Users\\phamh\\Workspace\\Dataset\\Canon_600D\\Train_set\\'),
    }
    if set_name not in datasets:
        raise ValueError(
            'Unknown set_name %r; expected one of %s' % (set_name, sorted(datasets)))
    mat_name, key, default_dir = datasets[set_name]
    path = default_dir if image_dir is None else image_dir
    gt_root = 'GT_Illum_Mat' if gt_dir is None else gt_dir

    # Load ground truth illum value
    illum_mat = scipy.io.loadmat(os.path.join(gt_root, mat_name),
                                 squeeze_me=True, struct_as_record=False)
    ground_truth_illum = illum_mat[key]

    # Sorted so that name_train has a reproducible order.
    flist = sorted(glob(os.path.join(path, '*.png')))
    if not flist:
        raise FileNotFoundError('No .png training images found in %r' % (path,))
    number_of_train_gt = len(flist)
    patches_per_image = int(train_size // number_of_train_gt)
    patch_r, patch_c = (int(v) for v in patch_size)

    # The CLAHE object does not depend on the image, so build it once.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    pt = progress_timer(n_iter=number_of_train_gt,
                        description='Generating Training Data :')

    X_train_origin, Y_train_origin, name_train = [], [], []

    for image_path in flist:
        # assumes file names are 1-based numeric indices into the ground-truth
        # array (e.g. "0001.png" -> row 0) — TODO confirm against the dataset
        stem = os.path.splitext(os.path.basename(image_path))[0]
        gt_index = int(stem) - 1

        image = cv2.imread(image_path)
        if image is None:
            raise OSError('Could not read image %r' % (image_path,))

        n_r, n_c = image.shape[:2]
        grid_rows, grid_cols = n_r // patch_r, n_c // patch_c
        total_patch = grid_rows * grid_cols
        if total_patch == 0:
            raise ValueError('Image %r (%dx%d) is smaller than one %dx%d patch'
                             % (image_path, n_r, n_c, patch_r, patch_c))

        # Shrink to a whole number of patches. cv2.resize expects dsize as
        # (width, height), i.e. (cols, rows).
        img_resize = cv2.resize(image, (grid_cols * patch_c, grid_rows * patch_r))

        for _ in range(patches_per_image):
            # Pick one cell of the patch grid and slice it out spatially, so
            # every patch is a contiguous patch_r x patch_c image region.
            rd = randint(0, total_patch - 1)
            r0 = (rd // grid_cols) * patch_r
            c0 = (rd % grid_cols) * patch_c
            img_patch = np.ascontiguousarray(
                img_resize[r0:r0 + patch_r, c0:c0 + patch_c])

            # Convert image to Lab to perform contrast normalizing
            lab = cv2.cvtColor(img_patch, cv2.COLOR_BGR2LAB)

            # Contrast normalizing (stretching) on the lightness channel only
            l, a, b = cv2.split(lab)
            clab = cv2.merge((clahe.apply(l), a, b))

            # Convert back to BGR, then to RGB
            img_patch = cv2.cvtColor(clab, cv2.COLOR_LAB2BGR)
            img_patch = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB)

            X_train_origin.append(img_patch)
            Y_train_origin.append(ground_truth_illum[gt_index])

        name_train.append('%04d' % gt_index)
        pt.update()

    pt.finish()

    X_train_origin = np.asarray(X_train_origin) / 255
    # Float dtype so the normalisation below can never be truncated.
    Y_train_origin = np.asarray(Y_train_origin, dtype=float)
    if Y_train_origin.size:
        # Scale each illuminant so that its largest component equals 1.
        Y_train_origin = Y_train_origin / Y_train_origin.max(axis=1, keepdims=True)

    # One shared permutation keeps patches and labels aligned without
    # reseeding the global NumPy RNG.
    order = np.random.permutation(len(X_train_origin))
    X_train_origin = X_train_origin[order]
    Y_train_origin = Y_train_origin[order]

    return X_train_origin, Y_train_origin, name_train

## Model

## Training

## Testing

## Result