In [7]:
import datetime
import gc
import os
import pickle

import matplotlib.pyplot as plt
import numba as nb
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used as PIL.Image.open
import scipy
import scipy.integrate  # scipy.integrate.quad is used by argmax_summand
import sympy as sp
from tqdm.auto import tqdm, trange

In [8]:
def clamp(x, x_min, x_max):
    """Limit ``x`` to the closed range [x_min, x_max].

    The upper bound is applied first, so if the bounds cross (x_min > x_max)
    the result is x_min.
    """
    capped = min(x_max, x)
    return max(x_min, capped)
    
def generate_y_quantisation_table_given_quality(f):
    """Return the 8x8 Y (luminance) quantisation table for quality factor ``f``.

    The standard luminance table is scaled by ``5000 // f`` when f < 50 and by
    ``200 - 2 * f`` otherwise. Each entry is rounded to nearest (add 50, then
    integer-divide by 100) and clamped to [1, 255].

    f: quality in range 1..100 (technically cjpeg accepts 0 too but that's converted to 1).
    Returns an (8, 8) integer ndarray.
    """
    assert f in range(1, 101)
    standard_y_quantisation_table = np.array([
        [16, 11, 10, 16, 24, 40, 51, 61],
        [12, 12, 14, 19, 26, 58, 60, 55],
        [14, 13, 16, 24, 40, 57, 69, 56],
        [14, 17, 22, 29, 51, 87, 80, 62],
        [18, 22, 37, 56, 68, 109, 103, 77],
        [24, 35, 55, 64, 81, 104, 113, 92],
        [49, 64, 78, 87, 103, 121, 120, 101],
        [72, 92, 95, 98, 112, 100, 103, 99]
    ])
    scaling_factor = 5000 // f if f < 50 else 200 - 2 * f
    # np.clip clamps element-wise in one vectorised call; np.vectorize(clamp) was a
    # Python-level loop that also depended on a separately defined helper.
    return np.clip((standard_y_quantisation_table * scaling_factor + 50) // 100, 1, 255)

# def generate_cbcr_quantisation_table_given_quality(f):
#     # f is quality in range 1..100 (technically cjpeg accepts 0 too but that's converted to 1)
#     assert f in range(1, 101)
#     standard_cbcr_quantisation_table = np.array([
#         [17, 18, 24, 47, 99, 99, 99, 99],
#         [18, 21, 26, 66, 99, 99, 99, 99],
#         [24, 26, 56, 99, 99, 99, 99, 99],
#         [47, 66, 99, 99, 99, 99, 99, 99],
#         [99, 99, 99, 99, 99, 99, 99, 99],
#         [99, 99, 99, 99, 99, 99, 99, 99],
#         [99, 99, 99, 99, 99, 99, 99, 99],
#         [99, 99, 99, 99, 99, 99, 99, 99]
#     ])
#     scaling_factor = 5000 // f if f < 50 else 200 - 2 * f
#     return np.vectorize(clamp)((standard_cbcr_quantisation_table * scaling_factor + 50) // 100, 1, 255)

In [9]:
# Bias added to every row of X @ T before the >> 11 of the first pass. Every row is the same:
#   column 0: -2**23 == -(128 * 8 * 8192). T's first column is all 8192, so this removes a
#             +128 offset from each of the 8 samples in the row.
#   column 4: 0. T's column 4 entries are +/-8192 == 2**13, so the >> 11 is exact there.
#   others  : 2**10, half of the 2**11 divisor, so the shift rounds.
V = np.tile([-2 ** 23, 2 ** 10, 2 ** 10, 2 ** 10, 0, 2 ** 10, 2 ** 10, 2 ** 10], (8, 1))
# Fixed-point DCT basis: X @ T transforms along rows, T.T @ (.) along columns.
T = np.array([
    [8192, 11363, 10703, 9633, 8192, 6437, 4433, 2260],
    [8192, 9633, 4433, -2259, -8192, -11362, -10704, -6436],
    [8192, 6437, -4433, -11362, -8192, 2261, 10704, 9633],
    [8192, 2260, -10703, -6436, 8192, 9633, -4433, -11363],
    [8192, -2260, -10703, 6436, 8192, -9633, -4433, 11363],
    [8192, -6437, -4433, 11362, -8192, -2261, 10704, -9633],
    [8192, -9633, 4433, 2259, -8192, 11362, -10704, 6436],
    [8192, -11363, 10703, -9633, 8192, -6437, 4433, -2260]
])
def ijg_dct(X):
    """Integer forward DCT of one 8x8 block of samples.

    Pass 1 (rows):    ((X @ T) + V) >> 11
    Pass 2 (columns): ((T.T @ pass1) + 2**14) >> 15
    Then a final >> 3. That is the division by 8, which an encoder would normally fold
    into quantisation; no quantisation is done here.
    numpy's >> on signed integers shifts arithmetically, i.e. it rounds toward
    negative infinity.
    """
    assert X.shape == (8, 8)
    row_pass = ((X @ T) + V) >> 11
    column_pass = ((T.T @ row_pass) + (2 ** 14)) >> 15
    return column_pass >> 3
    # Equivalent floor formulation:
    # np.floor(np.floor(((T.T @ np.floor(((X @ T) + V) / (2 ** 11))) + (2 ** 14)) / (2 ** 15)) / 8)

# MLE Estimation of Quantisation Table
Unlike in the Exact Recompression section, no arbitrary-precision arithmetic is needed here, so standard NumPy arrays can be used instead of SymPy number types.

In [10]:
def factors_of(x):
    """All positive divisors of ``x`` in ascending order (empty list when x <= 0)."""
    divisors = []
    for candidate in range(1, x + 1):
        if x % candidate == 0:
            divisors.append(candidate)
    return divisors

def generate_D(x):
    """Per-axis factor for DCT index ``x`` (0..7) used to build the peak widths B(m, n).

    Returns 2 for x in {0, 4}, sqrt(2) for x in {2, 6}, and sqrt(2) * cos(pi / 8) otherwise.
    """
    if x == 0 or x == 4:
        return 2
    if x == 2 or x == 6:
        return np.sqrt(2)  # mathematically equal to 2 * cos(pi / 4)
    return np.sqrt(2) * np.cos(np.pi / 8)  # mathematically equal to 2 * cos(pi / 4) * cos(pi / 8)


def generate_B(m, n):
    # width of the Gaussians around each peak in the model
    return generate_D(m) * generate_D(n)


# all the B(m, n), indexed B[m, n]
B = np.array([[generate_B(m, n) for n in range(8)] for m in range(8)])

def pdf_G(evalat, centre, B_mn):
    """Truncated, unnormalised Gaussian kernel exp(-6 * d**2) at d = evalat - centre.

    Returns 0 when |d| > B_mn. exp(-6 d^2) is the shape of a Gaussian with variance 1/12.
    The normalisation Z is omitted because it is constant for a particular (m, n).
    """
    distance = evalat - centre
    return 0 if np.abs(distance) > B_mn else np.exp(-6 * (distance ** 2))

def pdf_Laplace(evalat, b_MLE):
    """Density of a zero-location Laplace distribution with scale ``b_MLE`` at ``evalat``."""
    normaliser = 2 * b_MLE
    decay = np.exp(-np.abs(evalat) / b_MLE)
    return decay / normaliser

def argmax_summand(yprime_mn_s, q_candidate, B_mn, b_MLE):
    """Log-likelihood contribution of one observed DCT coefficient for a candidate step.

    yprime_mn_s: one DCT coefficient obtained from ijg_dct(); sets the integration
        window [yprime_mn_s - 0.5, yprime_mn_s + 0.5].
    q_candidate: the candidate quantisation step being scored.
    B_mn: Gaussian width for this (m, n); bounds the peak offsets k
        (integers with |k| <= 2 * B_mn / q_candidate) and truncates pdf_G.
    b_MLE: Laplace scale passed to pdf_Laplace.

    Over the window, the integrand is the sum of pdf_G kernels centred at
    (r + k) * q_candidate, multiplied by pdf_Laplace(r * q_candidate), where
    r = round(y / q_candidate). The integral is computed with scipy.integrate.quad
    and its natural log is returned (-inf, with a numpy warning, if it is 0).

    NOTE(review): the Laplace weight is evaluated at r * q only, not at each
    (r + k) * q inside the sum, and carries no bin-width factor. Confirm that this
    matches the intended model before comparing likelihoods across different q.
    """
    k_limit = 2 * B_mn / q_candidate
    k_offsets = range(int(np.trunc(-k_limit)), int(np.trunc(k_limit)) + 1)  # integers -2B/q <= k <= 2B/q

    def integrand(y):
        r = np.round(y / q_candidate)
        kernel_total = np.sum([pdf_G(y, (r + k) * q_candidate, B_mn) for k in k_offsets])
        return kernel_total * pdf_Laplace(r * q_candidate, b_MLE)

    window_lo = yprime_mn_s - 0.5
    window_hi = yprime_mn_s + 0.5
    integral, _abs_error = scipy.integrate.quad(integrand, window_lo, window_hi)
    return np.log(integral)

## `test2.png`

In [11]:
file_name = "test2"  # test2 or test3
# The context manager closes the PNG file; the pixel array is materialised inside it.
with PIL.Image.open(f"{file_name}.png") as image:
    y_channel = np.asarray(image)
assert y_channel.ndim == 2, "expected a single-channel (greyscale) image"

H, W = y_channel.shape
# Only whole 8x8 blocks are handled.
assert H % 8 == 0 and W % 8 == 0
count_blocks_vertical = H // 8
count_blocks_horizontal = W // 8

# Non-overlapping 8x8 blocks in row-major order.
blocks = [
    y_channel[(8 * block_i):(8 + 8 * block_i), (8 * block_j):(8 + 8 * block_j)]
    for block_i in range(count_blocks_vertical)
    for block_j in range(count_blocks_horizontal)
]
print(f"Total {len(blocks)} blocks before filtering")
# Drop blocks containing a clipped pixel (0 or 255), and flat blocks.
blocks = [b for b in blocks if np.min(b) != 0 and np.max(b) != 255 and np.min(b) != np.max(b)]
print(f"Total {len(blocks)} blocks after filtering")

# perform DCT on each block
print("Applying IJG DCT on each block...")
yprime_blocks = [ijg_dct(block) for block in tqdm(blocks)]

Total 190512 blocks before filtering
Total 170428 blocks after filtering
Applying IJG DCT on each block...


  0%|          | 0/170428 [00:00<?, ?it/s]

In [12]:
# This cell's values were obtained manually from inspecting the histograms
# of each DCT coefficient (m, n) over all blocks of test2.png.
# Each value is the position of the highest peak outside the main lobe.
# Candidate steps are taken from the factors of that value, and of the value - 1 and + 1 ("Q, Q + 1 and Q - 1").
# 0 marks a coefficient with no usable peak; the candidate comprehension below skips it.
# Entries are in row-major (m, n) order, which matters because the display below
# reshapes .values() to 8x8.
highest_peaks_outside_main_lobe = {
    (0, 0): 0,  # skip; does not conform to Normal distribution. Peak at 9
    (0, 1): 6,
    (0, 2): 6,
    (0, 3): 9,
    (0, 4): 13,
    (0, 5): 22,
    (0, 6): 29,
    (0, 7): 33,
    (1, 0): 7,
    (1, 1): 7,
    (1, 2): 8,
    (1, 3): 11,
    (1, 4): 15,
    (1, 5): 32,
    (1, 6): 34,
    (1, 7): 31,
    (2, 0): 8,
    (2, 1): 7,
    (2, 2): 9,
    (2, 3): 13,
    (2, 4): 22,
    (2, 5): 32,
    (2, 6): 39,
    (2, 7): 30,
    (3, 0): 8,
    (3, 1): 10,
    (3, 2): 12,
    (3, 3): 16,
    (3, 4): 28,
    (3, 5): 50,
    (3, 6): 46,
    (3, 7): 0,  # no peak detected outside main lobe
    (4, 0): 10,
    (4, 1): 12,
    (4, 2): 21,
    (4, 3): 31,
    (4, 4): 38,
    (4, 5): 0,  # no peak detected outside main lobe
    (4, 6): 0,  # no peak detected outside main lobe
    (4, 7): 0,  # no peak detected outside main lobe
    (5, 0): 13,
    (5, 1): 20,
    (5, 2): 31,
    (5, 3): 36,
    (5, 4): 0,  # no peak detected outside main lobe
    (5, 5): 0,  # no peak detected outside main lobe
    (5, 6): 0,  # no peak detected outside main lobe
    (5, 7): 0,  # no peak detected outside main lobe
    (6, 0): 27,
    (6, 1): 36,
    (6, 2): 44,
    (6, 3): 49,
    (6, 4): 0,  # no peak detected outside main lobe
    (6, 5): 0,  # no peak detected outside main lobe
    (6, 6): 0,  # no peak detected outside main lobe
    (6, 7): 0,  # no peak detected outside main lobe
    (7, 0): 40,
    (7, 1): 52,
    (7, 2): 53,
    (7, 3): 0,  # no peak detected outside main lobe
    (7, 4): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (7, 5): 0,  # no peak detected outside main lobe
    (7, 6): 0,  # no peak detected outside main lobe
    (7, 7): 0,  # no peak detected outside main lobe
}
# Candidate quantisation steps per coefficient: every divisor of P - 1, P and P + 1,
# where P is the recorded peak position. Presumably this allows for the peak sitting
# at a multiple of the true step and for a +/-1 misreading.
# Coefficients whose value is 0 get no entry and are skipped by the MLE loop.
# sorted() gives a deterministic candidate order; that order decides ties in the argmax.
argmax_q_candidates = {
    k: sorted(set(factors_of(v - 1)) | set(factors_of(v)) | set(factors_of(v + 1)))
    for k, v in highest_peaks_outside_main_lobe.items() if v != 0
}
# test2.png likely compressed with lowest setting f=72
# because this matrix looks very much like the Y quantisation table for f=72
np.array(list(highest_peaks_outside_main_lobe.values())).reshape(8, 8)

array([[ 0,  6,  6,  9, 13, 22, 29, 33],
       [ 7,  7,  8, 11, 15, 32, 34, 31],
       [ 8,  7,  9, 13, 22, 32, 39, 30],
       [ 8, 10, 12, 16, 28, 50, 46,  0],
       [10, 12, 21, 31, 38,  0,  0,  0],
       [13, 20, 31, 36,  0,  0,  0,  0],
       [27, 36, 44, 49,  0,  0,  0,  0],
       [40, 52, 53,  0,  0,  0,  0,  0]])

In [13]:
generate_y_quantisation_table_given_quality(f=72)  # compare with the peak matrix from the previous cell

array([[ 9,  6,  6,  9, 13, 22, 29, 34],
       [ 7,  7,  8, 11, 15, 32, 34, 31],
       [ 8,  7,  9, 13, 22, 32, 39, 31],
       [ 8, 10, 12, 16, 29, 49, 45, 35],
       [10, 12, 21, 31, 38, 61, 58, 43],
       [13, 20, 31, 36, 45, 58, 63, 52],
       [27, 36, 44, 49, 58, 68, 67, 57],
       [40, 52, 53, 55, 63, 56, 58, 55]])

In [None]:
# Per-coefficient MLE of the quantisation step: for every (m, n) with candidates,
# pick the q_candidate whose summed log-likelihood over all blocks is largest.
q_MLE = np.full((8, 8), None)
for m in trange(8):
    for n in trange(8, leave=False):
        if (m, n) not in argmax_q_candidates:
            continue  # no candidates: the peak table marks this coefficient with 0

        B_mn = B[m, n]
        # DCT coefficient (m, n) from every retained block
        yprime_mn_ss = [block[m, n] for block in yprime_blocks]
        # MLE of the Laplace scale parameter (zero location) for coefficient (m, n)
        b_MLE = np.mean(np.abs(yprime_mn_ss))

        argmax_q = None
        max_log_likelihood = None
        for q_candidate in tqdm(argmax_q_candidates[(m, n)], leave=False):
            # Not named `sum`: a module-level `sum` would shadow the builtin in every later cell.
            log_likelihood = 0
            for yprime_mn_s in tqdm(yprime_mn_ss, leave=False):
                log_likelihood += argmax_summand(yprime_mn_s, q_candidate, B_mn, b_MLE)
            if max_log_likelihood is None or log_likelihood > max_log_likelihood:
                max_log_likelihood = log_likelihood
                argmax_q = q_candidate

        q_MLE[m, n] = argmax_q
        print(f"q_MLE[{m}, {n}] = {argmax_q}")

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

q_MLE[0, 1] = 1


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

  0%|          | 0/170428 [00:00<?, ?it/s]

## `test3.png`

In [None]:
file_name = "test3"  # test2 or test3
# The context manager closes the PNG file; the pixel array is materialised inside it.
with PIL.Image.open(f"{file_name}.png") as image:
    y_channel = np.asarray(image)
assert y_channel.ndim == 2, "expected a single-channel (greyscale) image"

H, W = y_channel.shape
# Only whole 8x8 blocks are handled.
assert H % 8 == 0 and W % 8 == 0
count_blocks_vertical = H // 8
count_blocks_horizontal = W // 8

# Non-overlapping 8x8 blocks in row-major order.
blocks = [
    y_channel[(8 * block_i):(8 + 8 * block_i), (8 * block_j):(8 + 8 * block_j)]
    for block_i in range(count_blocks_vertical)
    for block_j in range(count_blocks_horizontal)
]
print(f"Total {len(blocks)} blocks before filtering")
# Drop blocks containing a clipped pixel (0 or 255), and flat blocks.
blocks = [b for b in blocks if np.min(b) != 0 and np.max(b) != 255 and np.min(b) != np.max(b)]
print(f"Total {len(blocks)} blocks after filtering")

# perform DCT on each block
print("Applying IJG DCT on each block...")
yprime_blocks = [ijg_dct(block) for block in tqdm(blocks)]

In [None]:
# This cell's values were obtained manually from inspecting the histograms
# of each DCT coefficient (m, n) over all blocks of test3.png.
# Each value is the position of the highest peak outside the main lobe.
# Candidate steps are taken from the factors of that value, and of the value - 1 and + 1 ("Q, Q + 1 and Q - 1").
# 0 marks a coefficient with no usable peak; the candidate comprehension below skips it.
# Entries are in row-major (m, n) order, which matters because the display below
# reshapes .values() to 8x8.
highest_peaks_outside_main_lobe = {
    (0, 0): 0,  # skip; does not conform to Normal distribution. peak at 10
    (0, 1): 7,
    (0, 2): 6,
    (0, 3): 10,
    (0, 4): 14,
    (0, 5): 24,
    (0, 6): 31,
    (0, 7): 37,
    (1, 0): 7,
    (1, 1): 7,
    (1, 2): 8,
    (1, 3): 11,
    (1, 4): 16,
    (1, 5): 35,
    (1, 6): 35,
    (1, 7): 0,  # no peak detected outside main lobe
    (2, 0): 8,
    (2, 1): 8,
    (2, 2): 10,
    (2, 3): 14,
    (2, 4): 24,
    (2, 5): 34,
    (2, 6): 41,
    (2, 7): 34,
    (3, 0): 8,
    (3, 1): 10,
    (3, 2): 13,
    (3, 3): 17,
    (3, 4): 31,
    (3, 5): 49,
    (3, 6): 0,  # no peak detected outside main lobe ? probably not a peak at 4
    (3, 7): 0,  # no peak detected outside main lobe
    (4, 0): 11,
    (4, 1): 13,
    (4, 2): 22,
    (4, 3): 34,
    (4, 4): 40,
    (4, 5): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (4, 6): 0,  # no peak detected outside main lobe
    (4, 7): 0,  # no peak detected outside main lobe
    (5, 0): 14,
    (5, 1): 21,
    (5, 2): 33,
    (5, 3): 35,
    (5, 4): 0,  # no peak detected outside main lobe
    (5, 5): 0,  # no peak detected outside main lobe
    (5, 6): 0,  # no peak detected outside main lobe
    (5, 7): 0,  # no peak detected outside main lobe
    (6, 0): 29,
    (6, 1): 38,
    (6, 2): 0,  # no peak detected outside main lobe
    (6, 3): 0,  # no peak detected outside main lobe
    (6, 4): 0,  # no peak detected outside main lobe
    (6, 5): 0,  # no peak detected outside main lobe
    (6, 6): 0,  # no peak detected outside main lobe
    (6, 7): 0,  # no peak detected outside main lobe  ? probably not a peak at 3
    (7, 0): 43,  # suspicious... non-symmetric. small peak at 35
    (7, 1): 0,  # no peak detected outside main lobe
    (7, 2): 0,  # no peak detected outside main lobe
    (7, 3): 0,  # no peak detected outside main lobe
    (7, 4): 0,  # no peak detected outside main lobe ? probably not a peak at -4
    (7, 5): 0,  # no peak detected outside main lobe
    (7, 6): 0,  # no peak detected outside main lobe
    (7, 7): 0,  # no peak detected outside main lobe
}
# Candidate quantisation steps per coefficient: every divisor of P - 1, P and P + 1,
# where P is the recorded peak position. Presumably this allows for the peak sitting
# at a multiple of the true step and for a +/-1 misreading.
# Coefficients whose value is 0 get no entry and are skipped by the MLE loop.
# sorted() gives a deterministic candidate order; that order decides ties in the argmax.
argmax_q_candidates = {
    k: sorted(set(factors_of(v - 1)) | set(factors_of(v)) | set(factors_of(v + 1)))
    for k, v in highest_peaks_outside_main_lobe.items() if v != 0
}
# test3.png likely compressed with lowest setting f=70
# because this matrix looks very much like the Y quantisation table for f=70
np.array(list(highest_peaks_outside_main_lobe.values())).reshape(8, 8)

In [None]:
generate_y_quantisation_table_given_quality(f=70)  # compare with the test3 peak matrix above

In [None]:
generate_y_quantisation_table_given_quality(f=76)  # alternative quality for comparison

In [None]:
# Per-coefficient MLE of the quantisation step: for every (m, n) with candidates,
# pick the q_candidate whose summed log-likelihood over all blocks is largest.
q_MLE = np.full((8, 8), None)
for m in trange(8):
    for n in trange(8, leave=False):
        if (m, n) not in argmax_q_candidates:
            continue  # no candidates: the peak table marks this coefficient with 0

        B_mn = B[m, n]
        # DCT coefficient (m, n) from every retained block
        yprime_mn_ss = [block[m, n] for block in yprime_blocks]
        # MLE of the Laplace scale parameter (zero location) for coefficient (m, n)
        b_MLE = np.mean(np.abs(yprime_mn_ss))

        argmax_q = None
        max_log_likelihood = None
        for q_candidate in tqdm(argmax_q_candidates[(m, n)], leave=False):
            # Not named `sum`: a module-level `sum` would shadow the builtin in every later cell.
            log_likelihood = 0
            for yprime_mn_s in tqdm(yprime_mn_ss, leave=False):
                log_likelihood += argmax_summand(yprime_mn_s, q_candidate, B_mn, b_MLE)
            if max_log_likelihood is None or log_likelihood > max_log_likelihood:
                max_log_likelihood = log_likelihood
                argmax_q = q_candidate

        q_MLE[m, n] = argmax_q
        print(f"q_MLE[{m}, {n}] = {argmax_q}")

The following was also tried: for each candidate quality factor $f$, maximising the joint probability over all coefficients, as described in $\S$IV of the paper. However, it is even slower than the optimisation code above.

In [None]:
file_name = "test3"  # test2 or test3
# The context manager closes the PNG file; the pixel array is materialised inside it.
with PIL.Image.open(f"{file_name}.png") as image:
    y_channel = np.asarray(image)
assert y_channel.ndim == 2, "expected a single-channel (greyscale) image"

H, W = y_channel.shape
assert H % 8 == 0 and W % 8 == 0
count_blocks_vertical = H // 8
count_blocks_horizontal = W // 8

blocks = [
    y_channel[(8 * block_i):(8 + 8 * block_i), (8 * block_j):(8 + 8 * block_j)]
    for block_i in range(count_blocks_vertical)
    for block_j in range(count_blocks_horizontal)
]
print(f"Total {len(blocks)} blocks before filtering")
# Drop blocks containing a clipped pixel (0 or 255), and flat blocks.
blocks = [b for b in blocks if np.min(b) != 0 and np.max(b) != 255 and np.min(b) != np.max(b)]
print(f"Total {len(blocks)} blocks after filtering")

# perform DCT on each block
print("Applying IJG DCT on each block...")
yprime_blocks = [ijg_dct(block) for block in tqdm(blocks)]
assert yprime_blocks, "no blocks left after filtering"

# The per-coefficient samples and Laplace scale estimates do not depend on the
# quality factor, so compute them once rather than once per f_candidate.
yprime_stack = np.stack(yprime_blocks)  # (num_blocks, 8, 8)
b_MLE_table = np.mean(np.abs(yprime_stack), axis=0)  # Laplace scale MLE per (m, n)

# Joint MLE over the quality factor: score each f by the log-likelihood of every
# coefficient of every block under its Y quantisation table.
f_min = 70
argmax_f = None
max_f = None
for f_candidate in trange(f_min, 101):
    q = generate_y_quantisation_table_given_quality(f_candidate)
    # Not named `sum`: a module-level `sum` would shadow the builtin in every later cell.
    log_likelihood = 0
    for m in trange(8, leave=False):
        for n in trange(8, leave=False):
            q_mn = q[m, n]
            B_mn = B[m, n]
            b_MLE = b_MLE_table[m, n]
            for yprime_mn_s in tqdm(yprime_stack[:, m, n], leave=False):
                log_likelihood += argmax_summand(yprime_mn_s, q_mn, B_mn, b_MLE)
    if max_f is None or log_likelihood > max_f:
        max_f = log_likelihood
        argmax_f = f_candidate
        # Report the best quality factor found so far.
        print(f"argmax_f = {argmax_f}")

## `test1c.png`
The same code as for `test2.png` and `test3.png` can be used here. The only extra steps are to read `test1c.png`, map it from $RGB$ to $YCbCr$, and then take the $Y$ channel.

In [None]:
# This cell's values were obtained manually from inspecting the histograms
# of each Y-channel DCT coefficient (m, n) over all blocks of test1c.png.
# Each value is the position of the highest peak outside the main lobe.
# Candidate steps are taken from the factors of that value, and of the value - 1 and + 1 ("Q, Q + 1 and Q - 1").
# 0 marks a coefficient with no usable peak; the candidate comprehension below skips it.
# Entries are in row-major (m, n) order, which matters because the next cell
# reshapes .values() to 8x8.
highest_peaks_outside_main_lobe = {
    (0, 0): 0,  # skip; does not conform to Normal distribution. Peak at 9.
    (0, 1): 6,
    (0, 2): 6,
    (0, 3): 9,
    (0, 4): 14,  # suspicious... non-symmetric? small peak at -67, -53, -36, 
    (0, 5): 23,  # suspicious... non-symmetric. small peak at 12&13
    (0, 6): 30,
    (0, 7): 35,
    (1, 0): 7,
    (1, 1): 7,
    (1, 2): 8,
    (1, 3): 11,  # semi-suspicious... non-symmetric? small peaking out at 3&7 hmm...
    (1, 4): 15,
    (1, 5): 34,
    (1, 6): 35,
    (1, 7): 32,
    (2, 0): 8,
    (2, 1): 8,
    (2, 2): 9,
    (2, 3): 14,
    (2, 4): 23,
    (2, 5): 32,
    (2, 6): 39,
    (2, 7): 31,  # funny gap 31 33
    (3, 0): 8,
    (3, 1): 10,
    (3, 2): 13,
    (3, 3): 17,
    (3, 4): 30,  # suspicious... non-symmetric. small peak at 22
    (3, 5): 50,
    (3, 6): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (3, 7): 35,
    (4, 0): 10,
    (4, 1): 13,
    (4, 2): 21,
    (4, 3): 32,  # suspicious... non-symmetric. small peak at -7
    (4, 4): 39,
    (4, 5): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (4, 6): 0,  # no peak detected outside main lobe
    (4, 7): 0,  # no peak detected outside main lobe
    (5, 0): 14,
    (5, 1): 20,
    (5, 2): 32,
    (5, 3): 37,
    (5, 4): 0,  # no peak detected outside main lobe ? probably not a peak at 5
    (5, 5): 0,  # no peak detected outside main lobe
    (5, 6): 0,  # no peak detected outside main lobe
    (5, 7): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (6, 0): 28,
    (6, 1): 37,
    (6, 2): 45,
    (6, 3): 0,  # no peak detected outside main lobe
    (6, 4): 0,  # no peak detected outside main lobe
    (6, 5): 0,  # no peak detected outside main lobe
    (6, 6): 0,  # no peak detected outside main lobe
    (6, 7): 0,  # no peak detected outside main lobe 
    (7, 0): 42,
    (7, 1): 53, 
    (7, 2): 0,  # no peak detected outside main lobe
    (7, 3): 0,  # no peak detected outside main lobe
    (7, 4): 0,  # no peak detected outside main lobe ? probably not a peak at -4
    (7, 5): 0,  # no peak detected outside main lobe
    (7, 6): 0,  # no peak detected outside main lobe
    (7, 7): 0,  # no peak detected outside main lobe ? probably not a peak at -4
}
# Candidate quantisation steps per coefficient: every divisor of P - 1, P and P + 1,
# where P is the recorded peak position. Coefficients whose value is 0 get no entry.
# sorted() gives a deterministic candidate order; that order decides ties in the argmax.
argmax_q_candidates = {
    k: sorted(set(factors_of(v - 1)) | set(factors_of(v)) | set(factors_of(v + 1)))
    for k, v in highest_peaks_outside_main_lobe.items() if v != 0
}

In [None]:
# test1c.png likely compressed with lowest setting f=71
# because this matrix looks very much like the Y quantisation table for f=71
np.fromiter(highest_peaks_outside_main_lobe.values(), dtype=int).reshape(8, 8)

In [None]:
generate_y_quantisation_table_given_quality(f=71)  # compare with the test1c peak matrix above

In [None]:
generate_y_quantisation_table_given_quality(f=78)  # alternative quality for comparison

## `test2c.png`

In [None]:
# This cell's values were obtained manually from inspecting the histograms
# of each Y-channel DCT coefficient (m, n) over all blocks of test2c.png.
# Each value is the position of the highest peak outside the main lobe.
# Candidate steps are taken from the factors of that value, and of the value - 1 and + 1 ("Q, Q + 1 and Q - 1").
# 0 marks a coefficient with no usable peak; the candidate comprehension below skips it.
# Entries are in row-major (m, n) order, which matters because the next cell
# reshapes .values() to 8x8.
highest_peaks_outside_main_lobe = {
    (0, 0): 0,  # skip; does not conform to Normal distribution. peak at 9
    (0, 1): 6,
    (0, 2): 6,
    (0, 3): 9,
    (0, 4): 13,
    (0, 5): 22,
    (0, 6): 29,
    (0, 7): 34,  # semi-suspicious... non-symmetric? small peaking out at 28 hmm...
    (1, 0): 7,
    (1, 1): 7,
    (1, 2): 8,
    (1, 3): 11,
    (1, 4): 15,
    (1, 5): 32,
    (1, 6): 34,
    (1, 7): 31,
    (2, 0): 8,
    (2, 1): 7,
    (2, 2): 9,
    (2, 3): 13,
    (2, 4): 22,
    (2, 5): 32,
    (2, 6): 39,
    (2, 7): 31,
    (3, 0): 8,
    (3, 1): 10,
    (3, 2): 12,
    (3, 3): 16,
    (3, 4): 29,
    (3, 5): 49,
    (3, 6): 46,
    (3, 7): 35,
    (4, 0): 10,
    (4, 1): 12,
    (4, 2): 21,
    (4, 3): 31,
    (4, 4): 37,  # or 39. (edit: quantisation table f=72 is 38)
    (4, 5): 0,  # no peak detected outside main lobe
    (4, 6): 0,  # no peak detected outside main lobe
    (4, 7): 0,  # no peak detected outside main lobe
    (5, 0): 13,
    (5, 1): 20,
    (5, 2): 31,
    (5, 3): 36,
    (5, 4): 0,  # no peak detected outside main lobe
    (5, 5): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (5, 6): 0,  # no peak detected outside main lobe
    (5, 7): 0,  # no peak detected outside main lobe
    (6, 0): 27,
    (6, 1): 36,
    (6, 2): 44,
    (6, 3): 49,  # or 46. Around 47/48? (edit: quantisation table f=72 is 49)
    (6, 4): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (6, 5): 0,  # no peak detected outside main lobe
    (6, 6): 0,  # no peak detected outside main lobe
    (6, 7): 0,  # no peak detected outside main lobe ? probably not a peak at 3
    (7, 0): 41,
    (7, 1): 52,
    (7, 2): 53,
    (7, 3): 0,  # no peak detected outside main lobe
    (7, 4): 0,  # no peak detected outside main lobe
    (7, 5): 0,  # no peak detected outside main lobe
    (7, 6): 0,  # no peak detected outside main lobe
    (7, 7): 0,  # no peak detected outside main lobe
}
# Candidate quantisation steps per coefficient: every divisor of P - 1, P and P + 1,
# where P is the recorded peak position. Coefficients whose value is 0 get no entry.
# sorted() gives a deterministic candidate order; that order decides ties in the argmax.
argmax_q_candidates = {
    k: sorted(set(factors_of(v - 1)) | set(factors_of(v)) | set(factors_of(v + 1)))
    for k, v in highest_peaks_outside_main_lobe.items() if v != 0
}

In [None]:
# test2c.png likely compressed with lowest setting f=72
# because this matrix looks very much like the Y quantisation table for f=72
np.fromiter(highest_peaks_outside_main_lobe.values(), dtype=int).reshape(8, 8)

In [None]:
generate_y_quantisation_table_given_quality(f=72)  # compare with the test2c peak matrix above

# Exact Recompression

## 2.1 Colour-space conversion

### (a) Greyscale images entry point

In [None]:
# Greyscale image to run Exact Recompression on; the next cell reads f"{file_name}.png".
file_name = "test3"  # test2 or test3

In [None]:
# Greyscale entry point: each pixel's Y value is known exactly, so each entry of
# bar_y_channel becomes the one-element set {Y}.
with PIL.Image.open(f"{file_name}.png") as image:
    bar_y_channel = np.asarray(image)
# Vectorise over bar_y_channel itself (the pixels just loaded); `Y_channel` is undefined.
# sp.Interval(v, v) simplifies to sp.FiniteSet(v) and the latter is faster.
# otypes=[object] fixes the output dtype, so np.vectorize does not evaluate element 0
# an extra time to infer it.
bar_y_channel = np.vectorize(lambda v: sp.FiniteSet(v), otypes=[object])(bar_y_channel)

### (b) Coloured images entry point

In [None]:
# RGB values are represented with (24-bit) integers from 0 to 2^24 - 1.
def clamp(x, x_min, x_max):
    """Limit ``x`` to [x_min, x_max]; if the bounds cross, x_min wins."""
    capped = min(x_max, x)
    return max(x_min, capped)


# We want to map each RGB value u_xy to the set of all YCbCr values ddot_v_xy that
# decode to u_xy via ycbcr_int_to_rgb_int below.
def ycbcr_int_to_rgb_int(ycbcr_int):
    """Decode a packed YCbCr integer (Y << 16 | Cb << 8 | Cr) to a packed RGB integer.

    Uses 16-bit fixed-point coefficients. Adding 32768 before >> 16 rounds each
    chroma term; Cb and Cr are centred on 128. Each channel is clamped to [0, 255].
    Assumes ycbcr_int is an integer in [0..2^24-1].
    """
    y, cb, cr = ycbcr_int >> 16, (ycbcr_int >> 8) & 0xff, ycbcr_int & 0xff
    cb_centred = cb - 128
    cr_centred = cr - 128
    half = 32768
    red = clamp(y + ((91881 * cr_centred + half) >> 16), 0, 255)
    green = clamp(y + ((-22553 * cb_centred - 46802 * cr_centred + half) >> 16), 0, 255)
    blue = clamp(y + ((116130 * cb_centred + half) >> 16), 0, 255)
    # each channel is within [0, 255], so OR-packing equals the sum of the shifted parts
    return (red << 16) | (green << 8) | blue

In [None]:
# Reverse of ycbcr_int_to_rgb_int: maps every 24-bit RGB integer to the tuple of all
# 24-bit YCbCr integers that decode to it (empty tuple if none do). It is cached on
# disk because building it requires decoding all 2^24 YCbCr values.
if os.path.isfile("rgb_int_to_ycbcr_ints.pkl"):
    with open("rgb_int_to_ycbcr_ints.pkl", "rb") as f:
        # Occupies a few GB of RAM
        # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
        rgb_int_to_ycbcr_ints = pickle.load(f)
else:
    print("Allocation of reverse map...")
    rgb_int_to_ycbcr_ints = {i: [] for i in trange(2 ** 24)}
    print("Calculation of reverse mapping...")
    for ycbcr_int in trange(2 ** 24):
        rgb_int_to_ycbcr_ints[ycbcr_int_to_rgb_int(ycbcr_int)].append(ycbcr_int)
    print("Conversion of list to tuple...")
    # Convert each RGB key exactly once. This needs no further decoding, and RGB values
    # with no preimage become () instead of being left as lists.
    for rgb_int in trange(2 ** 24):
        rgb_int_to_ycbcr_ints[rgb_int] = tuple(rgb_int_to_ycbcr_ints[rgb_int])
    with open("rgb_int_to_ycbcr_ints.pkl", "wb") as f:
        pickle.dump(rgb_int_to_ycbcr_ints, f)

In [None]:
def rgb_int_to_y_interval(rgb_int, pbar):
    """Integer interval from the smallest to the largest Y among all YCbCr values decoding to ``rgb_int``.

    Looks up the global reverse map rgb_int_to_ycbcr_ints and advances ``pbar`` by one.
    """
    packed_candidates = rgb_int_to_ycbcr_ints[rgb_int]
    y_values = [packed >> 16 for packed in packed_candidates]
    pbar.update()
    return sp.Interval(min(y_values), max(y_values))

def rgb_int_to_cb_set(rgb_int, pbar):
    """Set of every Cb among the YCbCr values decoding to ``rgb_int``; advances ``pbar`` by one."""
    packed_candidates = rgb_int_to_ycbcr_ints[rgb_int]
    cb_values = {(packed >> 8) & 0xff for packed in packed_candidates}
    pbar.update()
    return cb_values

def rgb_int_to_cr_set(rgb_int, pbar):
    """Set of every Cr among the YCbCr values decoding to ``rgb_int``; advances ``pbar`` by one."""
    packed_candidates = rgb_int_to_ycbcr_ints[rgb_int]
    cr_values = {packed & 0xff for packed in packed_candidates}
    pbar.update()
    return cr_values

In [None]:
# Colour image to process; the next cell reads f"{file_name}.png" and writes f"{file_name}_*.pkl".
file_name = "test1c"  # test1c or test2c

In [None]:
# Colour entry point. For every pixel, use the reverse map to derive the interval of
# possible Y values and the sets of possible Cb / Cr values. Each channel is pickled
# to disk and deleted before the next one is built.
with PIL.Image.open(f"{file_name}.png") as image:
    # convert("RGB") leaves RGB input unchanged and also accepts RGBA / palette / greyscale PNGs
    rgb_channels = np.asarray(image.convert("RGB")).astype(int)
rgb_int_image = (rgb_channels[:, :, 0] << 16) + (rgb_channels[:, :, 1] << 8) + (rgb_channels[:, :, 2])

# otypes=[object]: without it np.vectorize evaluates the first element an extra time
# to infer the output dtype, which advances each progress bar past its total.
with tqdm(total=np.size(rgb_int_image)) as pbar:
    bar_y_channel = np.vectorize(lambda v: rgb_int_to_y_interval(v, pbar), otypes=[object])(rgb_int_image)
    with open(f"{file_name}_bar_y_channel.pkl", "wb") as f:
        pickle.dump(bar_y_channel, f)
del bar_y_channel

with tqdm(total=np.size(rgb_int_image)) as pbar:
    ddot_cb_channel = np.vectorize(lambda v: rgb_int_to_cb_set(v, pbar), otypes=[object])(rgb_int_image)
    with open(f"{file_name}_ddot_cb_channel.pkl", "wb") as f:
        pickle.dump(ddot_cb_channel, f)
del ddot_cb_channel

with tqdm(total=np.size(rgb_int_image)) as pbar:
    ddot_cr_channel = np.vectorize(lambda v: rgb_int_to_cr_set(v, pbar), otypes=[object])(rgb_int_image)
    with open(f"{file_name}_ddot_cr_channel.pkl", "wb") as f:
        pickle.dump(ddot_cr_channel, f)
del ddot_cr_channel

# Release the reverse map (a few GB) and its helpers. Re-running this cell requires
# re-running the cells that define them.
del rgb_int_to_y_interval
del rgb_int_to_cb_set
del rgb_int_to_cr_set
del rgb_int_to_ycbcr_ints
gc.collect()

In [None]:
# Reload the per-pixel Y intervals pickled to f"{file_name}_bar_y_channel.pkl" by the colour entry point.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open(f"{file_name}_bar_y_channel.pkl", "rb") as f:
    bar_y_channel = pickle.load(f)

## 2.2 Chroma downsampling (broken)

In [None]:
# c stands in for cb or cr.
#
# Each function below inverts one full-resolution pixel's chroma constraint
#     16*c - 8 <= sum(weight * w) <= 16*c + 7
# over the four downsampled samples w_{i-1,j-1}, w_{i,j-1}, w_{i-1,j} and w_{ij}.
# The weights are 9/3/3/1, arranged according to the pixel's (x, y) parity.
# Solving for one sample gives the integers it can take for SOME choice of the
# other three within their current intervals:
#     lower = ceil((16*c - 8 - sum(weight * other.sup)) / own_weight)
#     upper = floor((16*c + 7 - sum(weight * other.inf)) / own_weight)
# The lower bound must subtract the neighbours' largest values (sup) and the upper
# bound their smallest (inf). Pairing them the other way round returns EmptySet
# whenever the neighbours are unconstrained, e.g. all sp.Interval(0, 255).


def _bar_w_given_c(c, own_weight, weighted_others):
    """Integer interval for one downsampled sample given chroma ``c`` and its three neighbours.

    c: chroma value of the full-resolution pixel (int or sympy Integer).
    own_weight: weight (1, 3 or 9) of the sample being solved for.
    weighted_others: (weight, sympy Interval) pairs for the other three samples.
    Returns sp.Interval(lower, upper); sympy yields EmptySet when lower > upper.
    """
    # Explicit accumulation rather than sum(), so this keeps working even if a
    # notebook cell has rebound the name `sum`.
    largest_rest = 0
    smallest_rest = 0
    for weight, interval in weighted_others:
        largest_rest += weight * interval.sup
        smallest_rest += weight * interval.inf
    return sp.Interval(
        sp.ceiling(sp.Rational(16 * c - 8 - largest_rest, own_weight)),
        sp.floor(sp.Rational(16 * c + 7 - smallest_rest, own_weight)),
    )


# x even, y even: weights w_{i-1,j-1}=1, w_{i,j-1}=3, w_{i-1,j}=3, w_{ij}=9
def bar_wij_given_c_xevenyeven(c, bar_wim1jm1, bar_wijm1, bar_wim1j):
    return _bar_w_given_c(c, 9, [(1, bar_wim1jm1), (3, bar_wijm1), (3, bar_wim1j)])
def bar_wim1j_given_c_xevenyeven(c, bar_wim1jm1, bar_wijm1, bar_wij):
    return _bar_w_given_c(c, 3, [(1, bar_wim1jm1), (3, bar_wijm1), (9, bar_wij)])
def bar_wijm1_given_c_xevenyeven(c, bar_wim1jm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 3, [(1, bar_wim1jm1), (3, bar_wim1j), (9, bar_wij)])
def bar_wim1jm1_given_c_xevenyeven(c, bar_wijm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 1, [(3, bar_wijm1), (3, bar_wim1j), (9, bar_wij)])

# x odd, y even: weights w_{i-1,j-1}=3, w_{i,j-1}=1, w_{i-1,j}=9, w_{ij}=3
def bar_wij_given_c_xoddyeven(c, bar_wim1jm1, bar_wijm1, bar_wim1j):
    return _bar_w_given_c(c, 3, [(3, bar_wim1jm1), (1, bar_wijm1), (9, bar_wim1j)])
def bar_wim1j_given_c_xoddyeven(c, bar_wim1jm1, bar_wijm1, bar_wij):
    return _bar_w_given_c(c, 9, [(3, bar_wim1jm1), (1, bar_wijm1), (3, bar_wij)])
def bar_wijm1_given_c_xoddyeven(c, bar_wim1jm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 1, [(3, bar_wim1jm1), (9, bar_wim1j), (3, bar_wij)])
def bar_wim1jm1_given_c_xoddyeven(c, bar_wijm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 3, [(1, bar_wijm1), (9, bar_wim1j), (3, bar_wij)])

# x even, y odd: weights w_{i-1,j-1}=3, w_{i,j-1}=9, w_{i-1,j}=1, w_{ij}=3
def bar_wij_given_c_xevenyodd(c, bar_wim1jm1, bar_wijm1, bar_wim1j):
    return _bar_w_given_c(c, 3, [(3, bar_wim1jm1), (9, bar_wijm1), (1, bar_wim1j)])
def bar_wim1j_given_c_xevenyodd(c, bar_wim1jm1, bar_wijm1, bar_wij):
    return _bar_w_given_c(c, 1, [(3, bar_wim1jm1), (9, bar_wijm1), (3, bar_wij)])
def bar_wijm1_given_c_xevenyodd(c, bar_wim1jm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 9, [(3, bar_wim1jm1), (1, bar_wim1j), (3, bar_wij)])
def bar_wim1jm1_given_c_xevenyodd(c, bar_wijm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 3, [(9, bar_wijm1), (1, bar_wim1j), (3, bar_wij)])

# x odd, y odd: weights w_{i-1,j-1}=9, w_{i,j-1}=3, w_{i-1,j}=3, w_{ij}=1
def bar_wij_given_c_xoddyodd(c, bar_wim1jm1, bar_wijm1, bar_wim1j):
    return _bar_w_given_c(c, 1, [(9, bar_wim1jm1), (3, bar_wijm1), (3, bar_wim1j)])
def bar_wim1j_given_c_xoddyodd(c, bar_wim1jm1, bar_wijm1, bar_wij):
    return _bar_w_given_c(c, 3, [(9, bar_wim1jm1), (3, bar_wijm1), (1, bar_wij)])
def bar_wijm1_given_c_xoddyodd(c, bar_wim1jm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 3, [(9, bar_wim1jm1), (3, bar_wim1j), (1, bar_wij)])
def bar_wim1jm1_given_c_xoddyodd(c, bar_wijm1, bar_wim1j, bar_wij):
    return _bar_w_given_c(c, 9, [(3, bar_wijm1), (3, bar_wim1j), (1, bar_wij)])

In [None]:
def algorithm_1(channel_setmat):  # either ddot_cb_channel or ddot_cr_channel
    """Bound each 2x2-downsampled chroma sample given the possible upsampled values.

    channel_setmat[x, y] (x = row, y = column) is iterated as the collection of candidate
    upsampled values c for that output pixel. Each output pixel depends on the four
    downsampled samples w[i-1, j-1], w[i-1, j], w[i, j-1], w[i, j]. For every one of them, the
    current interval is intersected with the union, over candidates c, of the interval that the
    parity-specific bar_*_given_c_* solver allows given the other three. Sweeps alternate
    between forward and reverse order until a full sweep leaves every interval unchanged.

    The working grid is padded by one sample on each side (mathematical index r is stored at
    r + 1). The returned object array has the padding stripped, i.e. shape (H // 2, W // 2).
    NOTE(review): assumes H and W are even; with odd sizes the last output row/column would
    land in the stripped padding.
    """
    # Solvers per (x is even, y is even), ordered as targets w[i,j], w[i-1,j], w[i,j-1], w[i-1,j-1].
    solvers = {
        (True, True): (bar_wij_given_c_xevenyeven, bar_wim1j_given_c_xevenyeven,
                       bar_wijm1_given_c_xevenyeven, bar_wim1jm1_given_c_xevenyeven),
        (False, True): (bar_wij_given_c_xoddyeven, bar_wim1j_given_c_xoddyeven,
                        bar_wijm1_given_c_xoddyeven, bar_wim1jm1_given_c_xoddyeven),
        (True, False): (bar_wij_given_c_xevenyodd, bar_wim1j_given_c_xevenyodd,
                        bar_wijm1_given_c_xevenyodd, bar_wim1jm1_given_c_xevenyodd),
        (False, False): (bar_wij_given_c_xoddyodd, bar_wim1j_given_c_xoddyodd,
                         bar_wijm1_given_c_xoddyodd, bar_wim1jm1_given_c_xoddyodd),
    }

    def padded_index(t):
        # Output rows 2k and 2k+1 both involve downsampled row k (padded index k + 1):
        # row 2k pairs it with row k - 1 (padded indices k, k + 1), row 2k + 1 with row k + 1
        # (padded indices k + 1, k + 2). The larger index of each pair is (t + 1) // 2 + 1 in
        # both cases; columns work the same way. Using t // 2 for even t would be one too small
        # and would make t = 0 wrap around to the opposite padding via index -1.
        return (t + 1) // 2 + 1

    H, W = channel_setmat.shape
    grid = np.full((H // 2 + 2, W // 2 + 2), sp.Interval(0, 255))  # padded interval grid
    reverse = False
    iter_n = 0
    with tqdm() as pbar:
        while True:
            iter_n += 1
            grid_old = np.copy(grid)
            for x in (reversed(range(H)) if reverse else range(H)):
                i = padded_index(x)
                for y in (reversed(range(W)) if reverse else range(W)):
                    j = padded_index(y)
                    solve_ij, solve_im1j, solve_ijm1, solve_im1jm1 = solvers[(x % 2 == 0, y % 2 == 0)]
                    candidates = channel_setmat[x, y]  # c stands for either cb or cr
                    # All four unions are computed from the same snapshot before any intersection.
                    diij_to_intersect = sp.Union(*[
                        solve_ij(c, bar_wim1jm1=grid[i - 1, j - 1], bar_wijm1=grid[i, j - 1], bar_wim1j=grid[i - 1, j])
                        for c in candidates
                    ])
                    dim1j_to_intersect = sp.Union(*[
                        solve_im1j(c, bar_wim1jm1=grid[i - 1, j - 1], bar_wijm1=grid[i, j - 1], bar_wij=grid[i, j])
                        for c in candidates
                    ])
                    dijm1_to_intersect = sp.Union(*[
                        solve_ijm1(c, bar_wim1jm1=grid[i - 1, j - 1], bar_wim1j=grid[i - 1, j], bar_wij=grid[i, j])
                        for c in candidates
                    ])
                    dim1jm1_to_intersect = sp.Union(*[
                        solve_im1jm1(c, bar_wijm1=grid[i, j - 1], bar_wim1j=grid[i - 1, j], bar_wij=grid[i, j])
                        for c in candidates
                    ])
                    grid[i, j] = grid[i, j].intersect(diij_to_intersect)
                    grid[i - 1, j] = grid[i - 1, j].intersect(dim1j_to_intersect)
                    grid[i, j - 1] = grid[i, j - 1].intersect(dijm1_to_intersect)
                    grid[i - 1, j - 1] = grid[i - 1, j - 1].intersect(dim1jm1_to_intersect)
            if np.all(grid == grid_old):
                break
            reverse = not reverse
            pbar.update()
    print(f"Converged on the {iter_n}-th iteration")
    return grid[1:-1, 1:-1]  # remove padded regions

In [None]:
# Stage 2.2 for one image: load the candidate-value matrix, bound the downsampled Cb samples, save.
file_name = "test1c"
ddot_pkl_name = f"{file_name}_ddot_cb_channel.pkl"
bar_pkl_name = f"{file_name}_bar_cb_channel.pkl"

print(f"Loading {ddot_pkl_name} ...")
with open(ddot_pkl_name, "rb") as f:
    ddot_cb_channel = pickle.load(f)  # pickle can execute code on load: only open files this notebook wrote

print("Running algorithm_1...")
bar_cb_channel = algorithm_1(ddot_cb_channel)

print(f"Saving to {bar_pkl_name} ...")
with open(bar_pkl_name, "wb") as f:
    pickle.dump(bar_cb_channel, f)

# Release the large object arrays before the next image is processed.
del bar_cb_channel, ddot_cb_channel

In [None]:
# Same pipeline as the previous cell, for the second test image.
file_name = "test2c"
ddot_pkl_name = f"{file_name}_ddot_cb_channel.pkl"
bar_pkl_name = f"{file_name}_bar_cb_channel.pkl"

print(f"Loading {ddot_pkl_name} ...")
with open(ddot_pkl_name, "rb") as f:
    ddot_cb_channel = pickle.load(f)  # pickle can execute code on load: only open files this notebook wrote

print("Running algorithm_1...")
bar_cb_channel = algorithm_1(ddot_cb_channel)

print(f"Saving to {bar_pkl_name} ...")
with open(bar_pkl_name, "wb") as f:
    pickle.dump(bar_cb_channel, f)

# Release the large object arrays once the result is on disk.
del bar_cb_channel, ddot_cb_channel

## 2.3 Discrete cosine transform

In [None]:
def invert_interval_add(bar_p, alpha):
    """(5a) Given the interval bar_p of p = x + alpha, return the interval of x.

    Empty intervals are not treated specially.
    """
    shifted_low = bar_p.inf - alpha
    shifted_high = bar_p.sup - alpha
    return sp.Interval(shifted_low, shifted_high)

def invert_intervalmat_add(bar_p_mat, alpha):
    """Element-wise invert_interval_add over an object array of intervals (alpha broadcast by np.vectorize)."""
    elementwise = np.vectorize(invert_interval_add)
    return elementwise(bar_p_mat, alpha)
    
def invert_interval_mult(bar_p, alpha):
    """(5b) Given the interval bar_p of p = alpha * x with integer x, return the interval of x.

    Works for either sign of alpha: a negative factor swaps which end of bar_p bounds x
    from below. Empty intervals are not treated specially.

    Raises:
        ValueError: if alpha is zero (p then carries no information about x).
    """
    if alpha == 0:
        raise ValueError("alpha must be non-zero")
    lo, hi = (bar_p.inf, bar_p.sup) if alpha > 0 else (bar_p.sup, bar_p.inf)
    return sp.Interval(sp.ceiling(lo / alpha), sp.floor(hi / alpha))

def invert_intervalmat_mult(bar_p_mat, alpha):
    """Element-wise invert_interval_mult over an object array of intervals (alpha broadcast by np.vectorize)."""
    return np.vectorize(invert_interval_mult)(bar_p_mat, alpha)

def invert_interval_divfloor(bar_p, alpha):
    """(5c) Given the interval bar_p of p = floor(x / alpha), return the interval of x.

    Assumes alpha is a positive integer: every x in [p * alpha, p * alpha + alpha - 1]
    floors to p. Empty intervals are not treated specially.
    """
    smallest_x = bar_p.inf * alpha
    largest_x = bar_p.sup * alpha + (alpha - 1)
    return sp.Interval(smallest_x, largest_x)

def invert_intervalmat_divfloor(bar_p_mat, alpha):
    """Element-wise invert_interval_divfloor over an object array of intervals (alpha broadcast by np.vectorize)."""
    elementwise = np.vectorize(invert_interval_divfloor)
    return elementwise(bar_p_mat, alpha)

def invert_intervalmat_matrixpremult(bar_y_mat, Tinv):
    """(9a) Invert a matrix pre-multiplication on interval matrices.

    Models T bar_x_mat = bar_y_mat, i.e. bar_x_mat = Tinv bar_y_mat. Given the interval matrix
    bar_y_mat, return an interval matrix bounding the integer matrix X, where
    x[i, j] = sum_k Tinv[i, k] * y[k, j]. The lower bound takes, per term, whichever end of
    bar_y_mat[k, j] minimises the product; the upper bound takes the other end. Bounds are then
    rounded inward (ceil / floor) because X is integer.

    bar_y_mat: numpy object array of sets exposing .inf/.sup, shape (n, m).
    Tinv: exact matrix indexable as Tinv[i, k] with .shape (p, n) (e.g. a sympy Matrix).
    Returns a numpy object array of sympy Intervals of shape (p, m). Square inputs of equal
    size, as used in this notebook, behave exactly as before.
    """
    out_rows, inner = Tinv.shape
    y_rows, out_cols = bar_y_mat.shape
    assert inner == y_rows, "Tinv columns must match bar_y_mat rows"
    result = np.full((out_rows, out_cols), sp.EmptySet)
    for i in range(out_rows):
        for j in range(out_cols):
            lower = upper = 0
            for k in range(inner):
                coef = Tinv[i, k]
                if coef == 0:
                    continue  # contributes nothing; avoids 0 * oo = nan for unbounded intervals
                bar_y = bar_y_mat[k, j]
                if coef > 0:
                    lower += coef * bar_y.inf
                    upper += coef * bar_y.sup
                else:
                    lower += coef * bar_y.sup
                    upper += coef * bar_y.inf
            result[i, j] = sp.Interval(sp.ceiling(lower), sp.floor(upper))
    return result

def invert_intervalmat_matrixpostmult(bar_y_mat, Tinv):
    """(9b) Invert a matrix post-multiplication on interval matrices.

    Models bar_x_mat T = bar_y_mat, i.e. bar_x_mat = bar_y_mat Tinv. Given the interval matrix
    bar_y_mat, return an interval matrix bounding the integer matrix X, where
    x[i, j] = sum_k y[i, k] * Tinv[k, j]. Each term uses whichever end of bar_y_mat[i, k]
    minimises (lower bound) or maximises (upper bound) the product; bounds are then rounded
    inward (ceil / floor) because X is integer.

    bar_y_mat: numpy object array of sets exposing .inf/.sup, shape (m, n).
    Tinv: exact matrix indexable as Tinv[k, j] with .shape (n, p) (e.g. a sympy Matrix).
    Returns a numpy object array of sympy Intervals of shape (m, p). Square inputs of equal
    size, as used in this notebook, behave exactly as before.
    """
    out_rows, inner = bar_y_mat.shape
    t_rows, out_cols = Tinv.shape
    assert inner == t_rows, "bar_y_mat columns must match Tinv rows"
    result = np.full((out_rows, out_cols), sp.EmptySet)
    for i in range(out_rows):
        for j in range(out_cols):
            lower = upper = 0
            for k in range(inner):
                coef = Tinv[k, j]
                if coef == 0:
                    continue  # contributes nothing; avoids 0 * oo = nan for unbounded intervals
                bar_y = bar_y_mat[i, k]
                if coef > 0:
                    lower += coef * bar_y.inf
                    upper += coef * bar_y.sup
                else:
                    lower += coef * bar_y.sup
                    upper += coef * bar_y.inf
            result[i, j] = sp.Interval(sp.ceiling(lower), sp.floor(upper))
    return result
    
def invert_interval_max(bar_x, a):
    """(10a) Given the interval bar_x of max(x, a), return the interval of x.

    If the result can equal a (its lower end is at or below a), x may be anything at or
    below a, so the lower bound opens to -oo. Empty intervals are not treated specially.
    """
    lower = bar_x.inf if bar_x.inf > a else -sp.oo
    return sp.Interval(lower, bar_x.sup)

def invert_intervalmat_max(bar_x_mat, a):
    """Element-wise invert_interval_max over an object array of intervals (a broadcast by np.vectorize)."""
    elementwise = np.vectorize(invert_interval_max)
    return elementwise(bar_x_mat, a)
    
def invert_interval_min(bar_x, a):
    """(10b) Given the interval bar_x of min(x, a), return the interval of x.

    If the result can equal a (its upper end is at or above a), x may be anything at or
    above a, so the upper bound opens to +oo. Empty intervals are not treated specially.
    """
    upper = bar_x.sup if bar_x.sup < a else sp.oo
    return sp.Interval(bar_x.inf, upper)

def invert_intervalmat_min(bar_x_mat, a):
    """Element-wise invert_interval_min over an object array of intervals (a broadcast by np.vectorize)."""
    elementwise = np.vectorize(invert_interval_min)
    return elementwise(bar_x_mat, a)

def IDCT_intervalmat(bar_y_mat):
    """Invert the two-pass integer IDCT on an 8x8 interval matrix.

    bar_y_mat = IDCT(bar_x_mat): given the interval matrix bar_y_mat of IDCT output samples,
    return an interval matrix bar_x_mat bounding the dequantised DCT coefficients.
    The forward steps undone here, last step first, are:
        X -> T X -> descale(11) -> W T^T -> descale(18) -> clamp to [0, 255]
    where descale(s) maps u to floor((u + 2**(s-1)) / 2**s).
    Relies on the module-level exact matrices TTinv = (T^T)^-1 and Tinv = T^-1.
    """
    def undo_descale(lower, upper, shift):
        # v = floor((u + 2**(shift-1)) / 2**shift)
        #   <=>  v * 2**shift - 2**(shift-1) <= u <= v * 2**shift + 2**(shift-1) - 1
        scale, half = 2 ** shift, 2 ** (shift - 1)
        return lower * scale - half, upper * scale + half - 1

    def undo_clamp_and_pass2_descale(bar_y):
        # A sample sitting at a clamp limit may have come from any value beyond that limit.
        lower, upper = undo_descale(bar_y.inf, bar_y.sup, 18)  # bounds offset -131072 / +131071
        return sp.Interval(
            -sp.oo if bar_y.inf <= 0 else lower,
            sp.oo if bar_y.sup >= 255 else upper,
        )

    def undo_pass1_descale(bar_w):
        # Offsets -1024 / +1023 (= 2**11 - 1 - 2**10); the previous +511 excluded valid values.
        return sp.Interval(*undo_descale(bar_w.inf, bar_w.sup, 11))

    result = np.vectorize(undo_clamp_and_pass2_descale)(bar_y_mat)
    result = invert_intervalmat_matrixpostmult(result, Tinv=TTinv)
    result = np.vectorize(undo_pass1_descale)(result)
    result = invert_intervalmat_matrixpremult(result, Tinv=Tinv)
    return result

In [None]:
# Fixed-point 8x8 transform matrix used by IDCT_intervalmat (first column is 8192 = 2**13).
# Its exact rational inverses are precomputed once here.
T = sp.Matrix([
    [8192, 11363, 10703, 9633, 8192, 6437, 4433, 2260],
    [8192, 9633, 4433, -2259, -8192, -11362, -10704, -6436],
    [8192, 6437, -4433, -11362, -8192, 2261, 10704, 9633],
    [8192, 2260, -10703, -6436, 8192, 9633, -4433, -11363],
    [8192, -2260, -10703, 6436, 8192, -9633, -4433, 11363],
    [8192, -6437, -4433, 11362, -8192, -2261, 10704, -9633],
    [8192, -9633, 4433, 2259, -8192, 11362, -10704, 6436],
    [8192, -11363, 10703, -9633, 8192, -6437, 4433, -2260]
])
Tinv = T.inv()  # exact form, arbitrary precision
TTinv = Tinv.T  # (T^T)^-1 == (T^-1)^T, so a second inversion is unnecessary

In [None]:
def perform_dct_blockwise(channel, file_name):  # bar_y_channel, bar_cb_channel, or bar_cr_channel
    """Apply IDCT_intervalmat to every 8x8 tile of an interval channel and pickle the result.

    channel: 2-D array whose height and width are multiples of 8; each 8x8 tile is passed to
    IDCT_intervalmat. The assembled coefficient-interval array is written to
    f"{file_name}_dequantised_dct_coefficient_intervalchannel.pkl". Returns None.
    """
    height, width = channel.shape
    assert height % 8 == 0 and width % 8 == 0
    n_block_rows, n_block_cols = height // 8, width // 8

    # Serial on purpose: earlier multiprocessing / concurrent.futures variants did not work.
    coefficient_intervals = np.full(channel.shape, None)
    for block_i in trange(n_block_rows):
        rows = slice(8 * block_i, 8 * block_i + 8)
        for block_j in trange(n_block_cols, leave=False):
            cols = slice(8 * block_j, 8 * block_j + 8)
            coefficient_intervals[rows, cols] = IDCT_intervalmat(channel[rows, cols])

    with open(f"{file_name}_dequantised_dct_coefficient_intervalchannel.pkl", "wb") as f:
        pickle.dump(coefficient_intervals, f)

In [None]:
# bar_y_channel and file_name come from earlier cells (file_name is whichever value was assigned last).
perform_dct_blockwise(channel=bar_y_channel, file_name=file_name)

## 2.4 Determining possible quality factors

In [None]:
# Reload the coefficient intervals written by perform_dct_blockwise (pickle: trusted local file only).
dequantised_pkl_name = f"{file_name}_dequantised_dct_coefficient_intervalchannel.pkl"
print(f"[{datetime.datetime.now().isoformat()}] Opening {dequantised_pkl_name} ...")
with open(dequantised_pkl_name, "rb") as f:
    print(f"[{datetime.datetime.now().isoformat()}] Unpickling {dequantised_pkl_name} ...")
    dequantised_dct_coefficient_intervalchannel = pickle.load(f)
    print(f"[{datetime.datetime.now().isoformat()}] Done unpickling")

In [None]:
def quality_is_possible(candidate_quality, minimal_infimum, maximal_supremum, dequantised_dct_coefficient_intervalchannel):
    """Check whether a JPEG quality factor is consistent with the coefficient intervals.

    For every 8x8 block and every position (i, j), the dequantised coefficient must be an
    integer multiple of the luma quantisation step q = table[i, j] for candidate_quality.
    The quality is rejected as soon as an interval contains no multiple of q, i.e.
    ceil(inf / q) > floor(sup / q). An unbounded end (+-oo) always admits a multiple; an empty
    interval (inf = oo, sup = -oo) admits none.

    minimal_infimum and maximal_supremum are currently unused (kept for the call signature).
    Returns True if every interval admits a multiple of its step, else False.
    """
    y_quantisation_table = generate_y_quantisation_table_given_quality(candidate_quality)

    H, W = dequantised_dct_coefficient_intervalchannel.shape
    assert H % 8 == 0 and W % 8 == 0
    count_blocks_vertical = H // 8
    count_blocks_horizontal = W // 8
    for block_i in trange(count_blocks_vertical, leave=False):
        for block_j in range(count_blocks_horizontal):
            bar_block = dequantised_dct_coefficient_intervalchannel[(8 * block_i):(8 + 8 * block_i), (8 * block_j):(8 + 8 * block_j)]
            # This is an interval(bar) matrix; scan positions from (7, 7) back to (0, 0).
            for i in reversed(range(8)):
                for j in reversed(range(8)):
                    q = int(y_quantisation_table[i, j])
                    inf = sp.S(bar_block[i, j].inf)
                    sup = sp.S(bar_block[i, j].sup)
                    # Smallest multiple index not below inf vs largest not above sup.
                    if sp.ceiling(inf / q) > sp.floor(sup / q):
                        return False
    return True

In [None]:
# Try every quality factor 1..100 and keep those consistent with the coefficient intervals.
possible_qualities = []
print(f"[{datetime.datetime.now().isoformat()}] Starting main loop...")
for candidate_quality in trange(1, 101):
    quality_ok = quality_is_possible(
        candidate_quality, minimal_infimum, maximal_supremum, dequantised_dct_coefficient_intervalchannel
    )
    verdict = "is possible" if quality_ok else "disqualified"
    print(f"[{datetime.datetime.now().isoformat()}] quality {candidate_quality} {verdict}")
    if quality_ok:
        possible_qualities.append(candidate_quality)