# **chromaVive**

**NOTE:** 
As individuals with limited GPU power and computational resources, we cannot feasibly replicate the extensive work done by Richard Zhang et al. in their seminal paper *"Colorful Image Colorization"* (ECCV 2016). Instead, we provide a **proof of concept** that demonstrates the fundamental approach to tackling the image colorization problem.

## **Color Space Conversion:** RGB & LAB

#### System Variables

In [None]:
# Path of the sample image read by the colour-space conversion cells below.
IMG_PATH = 'image.jpeg'

#### Libraries Import

In [None]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

#### *Original Image:* BGR to RGB

In [None]:
# Load the sample image; OpenCV decodes into BGR channel order.
img = cv.imread(IMG_PATH)
if img is None:
    # cv.imread reports a missing/unreadable file by returning None rather than raising,
    # which would otherwise surface as an opaque cvtColor error on the next line.
    raise FileNotFoundError(f'Could not read image: {IMG_PATH}')

# Matplotlib expects RGB, so reorder the channels before displaying.
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
plt.imshow(img)
plt.title('Original Image')
plt.axis('off')
plt.show();

#### *Original Image:* BGR to LAB

In [None]:
# Reload the sample image (BGR) and convert it to the LAB colour space.
img = cv.imread(IMG_PATH)
if img is None:
    # cv.imread reports a missing/unreadable file by returning None rather than raising.
    raise FileNotFoundError(f'Could not read image: {IMG_PATH}')

img = cv.cvtColor(img, cv.COLOR_BGR2LAB)
# For 8-bit input OpenCV stores L rescaled to 0..255 and a/b shifted by +128,
# so all three channels here are uint8 planes rather than signed Lab values.
L, a, b = cv.split(img)

#### *LAB Color Space:* Lightness (L), Green-Red (a), Blue-Yellow (b)

In [None]:
# One panel per LAB channel.
fig, ax = plt.subplots(1, 3, figsize=(12, 6))

# 'gray' maps low lightness to black. The previous 'binary' map is reversed
# (0 -> white), which displayed the L channel as a negative and disagreed with
# the 'gray' L panel in the following figure.
panels = (
    (L, 'gray', 'L Channel'),
    (a, 'Reds', 'a Channel'),
    (b, 'Blues', 'b Channel'),
)
for axis, (channel, cmap_name, panel_title) in zip(ax, panels):
    axis.imshow(channel, cmap=cmap_name)
    axis.set_title(panel_title)
    axis.axis('off')

plt.show();

#### *LAB Color Space:* Lightness (L), Green-Red-Blue-Yellow (ab)

In [None]:
# Left: lightness only. Right: chroma (a, b) only.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

ax[0].imshow(L, cmap='gray')
ax[0].set_title('L Channel')
ax[0].axis('off')

# Passing the raw [0, a, b] planes to imshow would have matplotlib interpret Lab
# bytes as RGB (and ignore any cmap, since the input is 3-channel). Instead pin
# lightness at mid-range so the chroma is visible, then convert Lab -> RGB.
ab_preview = cv.cvtColor(cv.merge([np.full_like(L, 128), a, b]), cv.COLOR_LAB2RGB)
ax[1].imshow(ab_preview)
ax[1].set_title('ab Channels')
ax[1].axis('off')

plt.show();

In [None]:
# Round trip: stack the three LAB planes back together and convert to RGB for display.
lab_planes = cv.merge([L, a, b])
merged_img = cv.cvtColor(lab_planes, cv.COLOR_LAB2RGB)

plt.imshow(merged_img)
plt.axis('off')
plt.title('Original (LAB to RGB) Image')
plt.show();

In [None]:
# Cell output: the raw L, a and b channel arrays produced by cv.split above.
L, a, b

## **Bin Classification:** AB Color Space 

#### System Variables

In [None]:
# Number of nearest bins passed to gaussian_encoding in the demo cell below.
k = 5
# Step between neighbouring bin centres along each of the a and b axes.
grid_size = 16

#### Libraries Import

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

#### Bin Classification: a and b channels

In [None]:
def create_ab_bins(grid_size=16):
    """
    Build the regular grid of (a, b) bin centres covering [-128, 128).

    Parameters:
    ----------
    grid_size : int, optional
        Step between neighbouring bin centres on each axis (default is 16).

    Returns:
    -------
    torch.Tensor
        Tensor of shape (n * n, 2) where n is the number of steps per axis.
        Rows are ordered with `a` varying slowest and `b` varying fastest.
    """
    a_range = torch.arange(-128, 128, grid_size)
    b_range = torch.arange(-128, 128, grid_size)

    # cartesian_prod enumerates the pairs in the same order as a nested
    # `for a ... for b ...` loop, without building a Python list of 0-d tensors.
    ab_bins = torch.cartesian_prod(a_range, b_range)

    return ab_bins

In [None]:
# Build the bin grid once; later cells reuse `ab_bins`.
ab_bins = create_ab_bins(grid_size)

n_bins = ab_bins.shape[0]
print(f'No. of bins: {n_bins}')
print(f'ab_bins shape: {ab_bins.shape}')

In [None]:
# Cell output: the full tensor of (a, b) bin centres.
ab_bins

#### *AB Bin Classification:* Visualization

In [None]:
# Draw one marker per bin centre in the a-b plane.
a_centres = ab_bins[:, 0]
b_centres = ab_bins[:, 1]

plt.plot(a_centres, b_centres, 'o')
plt.title('a-b bins')
plt.xlabel('a channel')
plt.ylabel('b channel')
plt.show();

#### *K-nearest bins*: Implementation

In [None]:
def k_nearest_neighbours(ab_target, ab_reference, k=5):
    """
    Find the k reference points closest (L2 distance) to a target point in ab space.

    Parameters:
    ----------
    ab_target : torch.Tensor
        Target point in the form of (a, b).

    ab_reference : torch.Tensor
        Reference points, shape (n, 2).

    k : int, optional
        Number of nearest neighbors to find (default is 5). If k exceeds n,
        all n reference points are returned.

    Returns:
    -------
    torch.Tensor
        Indices into ab_reference of the k nearest points, closest first.

    torch.Tensor
        Distances of those k points from the target, in the same order.

    Raises:
    ------
    ValueError
        If k is negative (a negative slice bound would silently drop the
        farthest points instead of selecting the nearest ones).
    """

    if k < 0:
        raise ValueError('k must be non-negative.')

    ab_target = ab_target.float()
    ab_reference = ab_reference.float()

    distances = torch.linalg.norm(ab_reference - ab_target, dim=1)  # (n, ) L2 distances
    # A stable sort makes equidistant points resolve to the lowest index,
    # so repeated runs select the same neighbours.
    sorted_distances, sorted_indices = torch.sort(distances, stable=True)
    return sorted_indices[:k], sorted_distances[:k]

def gaussian_encoding(ab_target, ab_reference, k=5, std_dev=5.0):
    """
    Soft-encode a target ab point as Gaussian weights over its k nearest reference points.

    Parameters:
    ----------
    ab_target : torch.Tensor
        Target point in the form of (a, b).

    ab_reference : torch.Tensor
        Reference points, shape (n, 2).

    k : int, optional
        Number of nearest reference points that receive a non-zero weight
        (default is 5).

    std_dev : float, optional
        Standard deviation for the Gaussian distribution (default is 5.0).

    Returns:
    -------
    torch.Tensor
        Tensor of shape (n, ): normalized Gaussian weights at the positions of
        the k nearest reference points, zero everywhere else.

    Raises:
    ------
    ValueError
        If std_dev is not positive.
    """

    if std_dev <= 0:
        raise ValueError('std_dev must be positive.')

    k_distances_indices, k_distance_values = k_nearest_neighbours(ab_target, ab_reference, k)

    # softmax(-d^2 / (2 * sigma^2)) equals exp(.) / sum(exp(.)), but it subtracts the
    # maximum first, so the result stays finite even when every exp(.) would underflow
    # to zero (which previously produced 0 / 0 = NaN weights).
    log_weights = -0.5 * (k_distance_values / std_dev) ** 2         # (k, ) gaussian log-weights
    soft_weights = torch.softmax(log_weights, dim=0)                # (k, ) normalized weights

    n = ab_reference.shape[0]
    # Allocate on the same device/dtype as the weights so the indexed assignment is valid.
    soft_weights_encoded = torch.zeros(n, dtype=soft_weights.dtype, device=soft_weights.device)
    soft_weights_encoded[k_distances_indices] = soft_weights
    return soft_weights_encoded

In [None]:
# Draw a random integer-valued (a, b) target in [-128, 128) and soft-encode it over the bins.
ab_target = torch.randint(-128, 128, (2,), dtype=torch.float32)
soft_encoded_weights = gaussian_encoding(ab_target, ab_bins, k)
# Cell output: one weight per bin; only the k nearest bins are non-zero.
soft_encoded_weights

In [None]:
# Summary statistics of the soft encoding computed in the previous cell.
print(f'Sum of weights: {soft_encoded_weights.sum():.4f}')
print(f'Max weight: {soft_encoded_weights.max():.4f}')
print(f'Min weight: {soft_encoded_weights.min():.4f}')
print(f'Soft Encoded Weights\' shape: {soft_encoded_weights.shape}')

#### *K-nearest bins*: Visualization

In [None]:
# Plot every bin centre (blue), the target (red) and dashed links to the selected bins.
plt.figure(figsize=(8, 16))
plt.scatter(ab_bins[:, 0], ab_bins[:, 1], color='blue', s=50)
plt.scatter(ab_target[0], ab_target[1], color='red', s=200)

# Visit only the bins that received a non-zero soft weight, instead of scanning
# every index and using a conditional expression for its side effect.
for bin_idx in torch.nonzero(soft_encoded_weights > 0).flatten().tolist():
    plt.plot(
        [ab_target[0], ab_bins[bin_idx, 0]],
        [ab_target[1], ab_bins[bin_idx, 1]],
        color='gray',
        linestyle='--',
    )

plt.title('Distance from Target Point to ab_bins')
plt.xlabel('a Channel')
plt.ylabel('b Channel')
plt.gca().set_aspect('equal', adjustable='box')
plt.show();

## **ImageNet**: Downscaled Versions

#### System Variables

In [None]:
# Side length (pixels) of the downscaled ImageNet variant being processed.
img_size = 8
# Folder holding the pickled training batches; '#' is replaced with img_size where the path is used.
path = f'./datasets/ImageNet {img_size}X{img_size}/Imagenet#_train'
# Template for the extracted .npy files: 'LABEL' is replaced with the array name
# ('X_train', 'Y_train' or 'mean') and '#' with img_size; the batch number is appended.
extraction_path = f'datasets/extracted/LABEL/imagenet#_data_batch_'

#### Libraries Import

In [None]:
import os
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt

#### ImageNet Extraction

In [None]:
def unpickle(file):
    """
    Load and return the object stored in a pickle file.

    Parameters:
    ----------
    file : str
        Path of the pickle file to read.

    Returns:
    -------
    object
        The unpickled object.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only use
    # this on dataset files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # Returned directly: the former local name `dict` shadowed the builtin.
        return pickle.load(fo)

In [None]:
def load_databatch(data_folder, idx, img_size=32):
    """
    Load one pickled training batch and return it as mean-centred image arrays.

    Reads `<data_folder>/train_data_batch_<idx>`, scales pixels and the stored
    mean image by 1/255, subtracts the mean, reshapes the flat rows into
    (N, 3, img_size, img_size) images and appends a horizontally mirrored copy
    of every image (labels are duplicated to match).

    Parameters:
    ----------
    data_folder : str
        Folder containing the `train_data_batch_*` pickle files.

    idx : int
        Batch number appended to the file name.

    img_size : int, optional
        Side length of the square images (default is 32).

    Returns:
    -------
    dict
        `X_train` (float32 images, originals followed by mirrored copies),
        `Y_train` (int32 labels shifted to start at 0) and `mean`
        (the mean image scaled by 1/255).
    """
    batch = unpickle(os.path.join(data_folder, 'train_data_batch_') + str(idx))

    scale = np.float32(255)
    pixels = batch['data'] / scale
    mean_image = batch['mean'] / scale

    # Stored labels start at 1; shift them so they start at 0.
    labels = [label - 1 for label in batch['labels']]

    n_images = pixels.shape[0]
    pixels -= mean_image

    # Each flat row holds three consecutive colour planes of img_size * img_size values.
    plane = img_size * img_size
    planes = (pixels[:, :plane], pixels[:, plane:2 * plane], pixels[:, 2 * plane:])
    images = np.dstack(planes)
    images = images.reshape((images.shape[0], img_size, img_size, 3)).transpose(0, 3, 1, 2)

    # Augment with mirrored images (last axis reversed); labels repeat unchanged.
    originals = images[0:n_images, :, :, :]
    original_labels = labels[0:n_images]
    mirrored = originals[:, :, :, ::-1]
    X_train = np.concatenate((originals, mirrored), axis=0)
    Y_train = np.concatenate((original_labels, original_labels), axis=0)

    return {
        'X_train': X_train.astype('float32'),
        'Y_train': Y_train.astype('int32'),
        'mean': mean_image,
    }

#### *ImageNet Local*: Loading & Extraction

In [None]:
# Convert each of the ten pickled training batches into three .npy files
# (images, labels, mean image) and report the total number of images.
count = 0

# Resolve the '#' placeholder once. Building the paths outside the f-string also
# avoids reusing the enclosing quote character inside an f-string expression,
# which is a SyntaxError on Python versions before 3.12.
source_dir = path.replace('#', str(img_size))
output_template = extraction_path.replace('#', str(img_size))

for i in range(10):
    print(f'Loading batch {i+1}...')
    data = load_databatch(source_dir, i+1, img_size)

    image_count = data['X_train'].shape[0]
    print(f'Loaded batch: {i+1} with {image_count} images')

    print('Saving data...')
    # X_train: (N, C, H, W) | Y_train: (N, ) | mean: (C * H * W, )
    for label in ('X_train', 'Y_train', 'mean'):
        target = f"{output_template.replace('LABEL', label)}{i+1}.npy"
        # np.save does not create missing parent folders.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        np.save(target, data[label])

    count += image_count

print(f'Total images: {count}')

In [None]:
def extract_arrays(idx, img_size=img_size):
    """
    Load the saved .npy arrays of one extracted batch.

    Parameters:
    ----------
    idx : int
        Batch number appended to the file name.

    img_size : int, optional
        Image side length substituted for '#' in `extraction_path`
        (defaults to the module-level `img_size` at definition time).

    Returns:
    -------
    dict
        Arrays keyed by 'X_train', 'Y_train' and 'mean'.
    """
    # Resolve the template outside the f-string: nesting the same quote character
    # inside an f-string expression is a SyntaxError before Python 3.12.
    template = extraction_path.replace('#', str(img_size))

    data = {}
    for label in ('X_train', 'Y_train', 'mean'):
        data[label] = np.load(f"{template.replace('LABEL', label)}{idx}.npy")

    return data

#### ImageNet Testing

In [None]:
# Pick one of the ten extracted batches at random and load its arrays from disk.
idx = random.randint(1, 10)
data = extract_arrays(idx, img_size)
print('Data extracted from batch:', idx)
# Cell output: shapes of the image, label and mean arrays.
data['X_train'].shape, data['Y_train'].shape, data['mean'].shape

In [None]:
# Number of images in the loaded batch, and a random position within it.
num_images = data['X_train'].shape[0]
index = random.randint(0, num_images-1)
# Cell output: shape of a single image.
data['X_train'][index].shape

In [None]:
# NOTE(review): this draws a fresh random position rather than reusing `index`
# from the previous cell; confirm that is intended.
# Channels-first -> channels-last so matplotlib can display the image.
img = data['X_train'][random.randint(0, num_images-1)].transpose(1, 2, 0)
# Approximate inverse of the loader's normalization: the subtracted mean image is
# stood in for by a flat 128; clipping keeps the result inside the uint8 range.
img = np.clip(img * 255 + 128, 0, 255).astype(np.uint8) # un-normalize
plt.imshow(img);

## **Data Analysis:** Bin Classification

#### RGB to LAB Color Space Conversion

In [None]:
# NOTE: this cell rewrites data['X_train'] in place, so run it once per loaded batch.
data['X_train'] = data['X_train'].transpose(0, 2, 3, 1)             # (N, C, H, W) -> (N, H, W, C)
# Clip before the uint8 cast: values pushed outside 0..255 by the flat +128 offset
# would otherwise wrap around instead of saturating (the single-image preview clips too).
data['X_train'] = np.clip(data['X_train'] * 255 + 128, 0, 255).astype(np.uint8)    # un-normalize
data['X_train'].shape

In [None]:
def rgb_to_lab(img_batch):
    """
    Convert a batch of RGB images to separate L, a and b channel batches.

    Each image is converted with OpenCV's RGB -> LAB conversion. For uint8 input
    OpenCV returns 8-bit Lab: L rescaled to 0..255 and a/b shifted by +128, so
    callers needing signed a/b values must subtract 128 themselves.

    Parameters:
    ----------
    img_batch : numpy.ndarray
        RGB images, shape (N, H, W, 3).

    Returns:
    -------
    numpy.ndarray
        L channel batch, shape (N, H, W).

    numpy.ndarray
        a channel batch, shape (N, H, W).

    numpy.ndarray
        b channel batch, shape (N, H, W).

    Raises:
    ------
    ValueError
        If the input is not 4-dimensional or its last dimension is not 3.
    """
    N, H, W, C = img_batch.shape
    if C != 3:
        raise ValueError('Expected the last dimension to be 3 (RGB channels).')

    # Pre-allocate the outputs so an empty batch still yields (0, H, W) arrays;
    # cvtColor keeps the input depth, so the input dtype is the output dtype.
    L_batch = np.empty((N, H, W), dtype=img_batch.dtype)
    a_batch = np.empty_like(L_batch)
    b_batch = np.empty_like(L_batch)

    for i in range(N):
        lab_img = cv.cvtColor(img_batch[i], cv.COLOR_RGB2LAB)
        L_batch[i], a_batch[i], b_batch[i] = cv.split(lab_img)

    return L_batch, a_batch, b_batch
    
# Convert the un-normalized batch and show the per-channel shapes as the cell output.
L_batch, a_batch, b_batch = rgb_to_lab(data['X_train'])
L_batch.shape, a_batch.shape, b_batch.shape

In [None]:
def merge_ab(L_batch, a_batch, b_batch):
    """
    Stack the a and b channel batches into a single two-channel batch.

    Parameters:
    ----------
    L_batch : numpy.ndarray
        Lightness channel batch, shape (N, H, W). Returned unchanged.

    a_batch : numpy.ndarray
        a channel batch, same shape as L_batch.

    b_batch : numpy.ndarray
        b channel batch, same shape as L_batch.

    Returns:
    -------
    numpy.ndarray
        The L_batch that was passed in.

    numpy.ndarray
        ab batch of shape (N, H, W, 2), a in channel 0 and b in channel 1.

    Raises:
    ------
    ValueError
        If L_batch is not 3-dimensional or the channel shapes disagree.
    """
    # Explicit checks replace the unused `N, H, W = L_batch.shape` unpacking,
    # which only validated the rank of L_batch as a side effect.
    if L_batch.ndim != 3:
        raise ValueError('Expected L_batch to have shape (N, H, W).')
    if a_batch.shape != L_batch.shape or b_batch.shape != L_batch.shape:
        raise ValueError('a_batch and b_batch must have the same shape as L_batch.')

    ab_batch = np.stack((a_batch, b_batch), axis=-1)

    return L_batch, ab_batch

# Combine the a and b planes and show the resulting shapes as the cell output.
L_batch, ab_batch = merge_ab(L_batch, a_batch, b_batch)
L_batch.shape, ab_batch.shape

#### Pixel-Level Analysis

In [None]:
# Drop the image structure: one row per pixel, columns are (a, b).
ab_batch_flat = ab_batch.reshape(-1, 2)
ab_batch_flat.shape    # (N * H * W, 2)

#### **Cumulative Run:** Aggregation

In [None]:
# Count, over all ten extracted batches, how many pixels fall nearest to each ab bin.
# bin_counts maps (a_centre, b_centre) -> number of pixels assigned to that bin.
pixel_count = 0
bin_counts = {}

bin_centres = ab_bins.float()
n_bins = bin_centres.shape[0]
# Pixels handled per distance matrix; bounds memory at chunk_size * n_bins floats.
chunk_size = 100_000

for ix in range(10):
    print('Extracting data from batch:', ix+1)
    print('---------------------------------')
    data = extract_arrays(ix+1, img_size)
    data['X_train'] = data['X_train'].transpose(0, 2, 3, 1)
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    data['X_train'] = np.clip(data['X_train'] * 255 + 128, 0, 255).astype(np.uint8)

    L_batch, a_batch, b_batch = rgb_to_lab(data['X_train'])
    L_batch, ab_batch = merge_ab(L_batch, a_batch, b_batch)
    # OpenCV's 8-bit Lab stores a and b shifted by +128 (range 0..255). The bin
    # centres live on the signed axis starting at -128, so remove the offset;
    # otherwise every value above the largest centre collapses into the edge bins.
    ab_batch_flat = ab_batch.reshape(-1, 2).astype(np.int16) - 128

    num_pixels = ab_batch_flat.shape[0]
    print(f'Number of pixels: {num_pixels} in {L_batch.shape[0]} images at {img_size}x{img_size} resolution.')

    # Vectorized nearest-bin search: one distance matrix per chunk instead of one
    # torch call per pixel. argmin returns the first minimum, so ties go to the
    # lowest bin index.
    batch_hist = torch.zeros(n_bins, dtype=torch.long)
    for start in range(0, num_pixels, chunk_size):
        print(f'Processing pixel {start:8d}/{num_pixels:8d}... in batch {ix+1}')
        chunk = torch.from_numpy(ab_batch_flat[start:start + chunk_size]).float()
        nearest = torch.cdist(chunk, bin_centres).argmin(dim=1)
        batch_hist += torch.bincount(nearest, minlength=n_bins)

    # Fold this batch's histogram into the running dictionary (occupied bins only).
    for bin_idx in torch.nonzero(batch_hist).flatten().tolist():
        nearest_bin = (ab_bins[bin_idx][0].item(), ab_bins[bin_idx][1].item())
        bin_counts[nearest_bin] = bin_counts.get(nearest_bin, 0) + batch_hist[bin_idx].item()

    pixel_count += num_pixels
    print('Extracted data from batch:', ix+1)
    print('---------------------------------')

print(f'Total pixels processed: {pixel_count}')

In [None]:
# Cell output: mapping of (a, b) bin centre -> number of pixels assigned to it.
bin_counts

#### **Cumulative Run:** Visualization

In [None]:
# Arrange the per-bin pixel counts on the (a, b) grid and plot them on a log scale.
a_range = np.arange(-128, 128, grid_size)
b_range = np.arange(-128, 128, grid_size)
heatmap = np.zeros((len(a_range), len(b_range)))

# Loop names are deliberately not `a`, `b` or `count`, so the channel arrays and the
# image total defined in earlier cells are not overwritten in the notebook namespace.
for (a_centre, b_centre), n_pixels in bin_counts.items():
    a_index = (a_centre + 128) // grid_size
    b_index = (b_centre + 128) // grid_size
    heatmap[a_index, b_index] += n_pixels

# log transform to make the heatmap more interpretable (ln(1+x))
log_heatmap = np.log1p(heatmap)

plt.figure(figsize=(10, 8))
# A sequential colormap is used because the data are ordered magnitudes; the cyclic
# 'hsv' map assigns nearly the same colour to the lowest and the highest values.
plt.imshow(log_heatmap, cmap='viridis', origin='lower', aspect='auto', interpolation='nearest')

cbar = plt.colorbar()
# The plotted values are log1p(count), so label the colorbar accordingly.
cbar.set_label('log(1 + Pixel Count)')

plt.title('Pixel Distribution in AB Color Space Bins')
plt.xlabel('B bins')
plt.ylabel('A bins')
plt.xticks(ticks=np.arange(len(b_range)), labels=b_range)
plt.yticks(ticks=np.arange(len(a_range)), labels=a_range)

plt.show();
