In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import pickle
from collections import defaultdict, Counter

## Steps


Train dataset loop
1. Calculate possible ab pairs and their frequency in training dataset
2. Save output space
   
All dataset loop
2. Go through dataset and soft-encode ab values of each pixel
   1. For each pixel find 5 nearest neihboirs in output space and weight them proportionally to their distance from the ground truth using
a Gaussian kernel with σ = 5.


In [2]:
def get_quantized_ab_pairs(image_lab: np.ndarray, grid: int = 10):
    ab_mat = image_lab[:, :, 1:].reshape(-1, 2)
    ab_mat_biased = ab_mat.astype(int) - 128
    ab_mat_quant = ab_mat_biased // grid * grid
    return [tuple(pair) for pair in ab_mat_quant.tolist()]


In [3]:
train_df = pd.read_pickle('../../data/ILSVRC/Metadata/train.pkl')
N = len(train_df)

train_df = train_df.sample(N).reset_index(drop=True)
ab_pairs_counter = defaultdict(int)

for i in tqdm(range(len(train_df))):
    img_path = train_df.iloc[i]['image_path']
    image = cv2.imread(img_path)

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)

    ab_pairs_ = get_quantized_ab_pairs(image_lab)

    ab_pairs_counter_ = dict(Counter(ab_pairs_))
    for pair, count in ab_pairs_counter_.items():
        ab_pairs_counter[pair] += count


ab_pairs_counter = dict(ab_pairs_counter)

ab_pairs_set = list(ab_pairs_counter.keys())
a_values = [a for a, _ in ab_pairs_set]
b_values = [b for _, b in ab_pairs_set]
print(f"Number of unique ab pairs is: {len(ab_pairs_set)}")

# Create a scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(b_values, a_values, alpha=0.6)
plt.title('Distribution of ab value pairs')
plt.xlabel('Component b')
plt.ylabel('Component a')
plt.grid(True)
plt.xlim((-110, 110))
plt.ylim((110, -110))
plt.show()

with open('../../data/ILSVRC/Metadata/ab_pair_counts.pkl', 'wb') as f:
    pickle.dump(ab_pairs_counter, f)

  0%|          | 3457/1281167 [05:21<27:45:25, 12.79it/s] 

In [12]:
with open('../../data/ILSVRC/Metadata/ab_pair_counts.pkl', 'rb') as f:
    ab_pairs_counter = pickle.load(f)

print(ab_pairs_counter)

{(0, -10): 3088704, (0, 0): 8166602, (-10, 0): 2152956, (0, 10): 1596097, (10, 0): 500851, (10, -10): 160848, (10, -20): 57274, (10, -30): 40453, (0, -20): 413522, (-10, 10): 1022352, (-20, 10): 185682, (-10, -10): 1381942, (10, 10): 556125, (0, 20): 1204460, (-10, -20): 376321, (0, -30): 207128, (20, 10): 158539, (10, 20): 484357, (20, 0): 48022, (20, 20): 169951, (-10, 20): 448246, (-10, 30): 88769, (-20, 30): 106548, (-20, 20): 183218, (0, 30): 543646, (10, 30): 201573, (-20, 40): 23190, (20, -10): 15176, (20, 30): 99389, (30, 10): 33821, (30, 20): 47757, (-20, -10): 127977, (-10, 40): 22215, (-20, 0): 134483, (0, 40): 59206, (0, -60): 364, (0, -50): 3666, (-10, -50): 4546, (-20, -40): 71420, (-10, -40): 114508, (-20, -50): 740, (-10, -30): 161355, (-20, -30): 34225, (0, -40): 40538, (-20, -20): 107041, (20, -20): 9176, (10, -40): 25927, (20, -50): 11087, (-30, -20): 20688, (20, -40): 24190, (20, -30): 12175, (40, 20): 30158, (40, 10): 13035, (-30, 30): 60975, (-40, 20): 9958, (-30,