Our own image compression schemes!

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import IPython.display

from cued_sf2_lab.familiarisation import load_mat_img, plot_image
from compression_schemes.dwt_funcs import *
from cued_sf2_lab.familiarisation import load_mat_img, plot_image
from compression_schemes.subjective_quality import ssim
from cued_sf2_lab.jpeg_dwt import *

In [None]:
def step_from_target_bits_DWT(X: np.ndarray,
                              n: int,
                              target_bits: float,
                              rise_ratio=1.0, 
                              lo: float = 1.0,
                              hi: float = 50.0,
                              tol_bits: float = 500.0,
                              dcbits: int = 9,
                              max_iter: int = 100):
    """Bisect the DWT step multiplier until the encoded size is close to ``target_bits``.

    Each iteration encodes ``X`` with ``jpegenc_dwt`` (``opthuff=True``,
    ``log=False``) at the midpoint of ``[lo, hi]`` and reads the third
    returned value as the total bit count. The bracket is narrowed on the
    assumption that a larger step never yields more bits.

    Args:
        X: Image passed unchanged to ``jpegenc_dwt``.
        n: Second positional argument of ``jpegenc_dwt`` (presumably the
            number of DWT levels - confirm against the encoder).
        target_bits: Desired total number of bits.
        rise_ratio: Forwarded to ``jpegenc_dwt``.
        lo: Lower end of the step bracket.
        hi: Upper end of the step bracket.
        tol_bits: Accepted absolute deviation from ``target_bits``.
        dcbits: Forwarded to ``jpegenc_dwt``.
        max_iter: Maximum number of encoder evaluations; must be at least 1.

    Returns:
        The last step multiplier that was tried. A warning is printed when
        the tolerance was not met.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1 (no step would ever be
            evaluated, so there is nothing to return).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # binary search for Δ so that bits ≈ target_bits
    for _ in range(max_iter):
        step_multiplier = 0.5 * (lo + hi)
        _vlc, _hufftab, totalbits = jpegenc_dwt(X, n,  step_multiplier, rise_ratio, dcbits=dcbits, opthuff=True, log=False)

        if abs(totalbits - target_bits) <= tol_bits:
            break

        if totalbits > target_bits:     # too many bits ⇒ need larger Δ
            lo = step_multiplier
        else:                       # too few bits ⇒ need smaller Δ
            hi = step_multiplier

        if 0.5 * (lo + hi) == step_multiplier:
            # The bracket can no longer shrink in floating point: every
            # further iteration would re-encode at exactly this step and
            # reach the same verdict, so stop instead of burning max_iter.
            print("Warning: search interval collapsed without hitting target_bits")
            break
    else:
        print("Warning: max_iter reached without hitting target_bits")

    return step_multiplier

In [None]:
# Load the flamingo test image and shift the pixel values down by 128.
flamingo_pixels, _ = load_mat_img(img='flamingo.mat', img_info='X')
X = flamingo_pixels - 128.0

# Settings for the bisection search on the DWT step multiplier.
dwt_search_settings = dict(
    n=5,
    target_bits=36000,
    rise_ratio=1.0,
    lo=1.0,
    hi=50.0,
    tol_bits=100.0,
    dcbits=9,
    max_iter=100,
)
step_multiplier = step_from_target_bits_DWT(X=X, **dwt_search_settings)
print(f"optimal step size {step_multiplier}")

# Encode once more at the chosen step, this time with logging switched on,
# and report the resulting bit count.
vlc, dhufftab, totalbits = jpegenc_dwt(
    X,
    n=5,
    step_multiplier=step_multiplier,
    rise_ratio=1.0,
    dcbits=9,
    opthuff=True,
    log=True,
)
print(totalbits)

In [None]:
# Decode the DWT bitstream using the step multiplier found by the search.
Z = jpegdec_dwt(vlc, 5, step_multiplier, rise_ratio=1.0, hufftab=dhufftab, dcbits=9)

# Show the input image and its reconstruction next to each other.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
for panel, picture, caption in ((ax1, X, "X"), (ax2, Z, "Z")):
    plot_image(picture, ax=panel)
    panel.set(title=caption)

print(ssim(X, Z))

In [None]:
# Images to sweep, and the RMS-error targets tried for each of them
image_list = ['flamingo.mat', 'lighthouse.mat', 'bridge.mat']
rms_targets = np.linspace(6, 14, 10)

# A single figure collects one bits-versus-RMS curve per image
plt.figure()

for img_name in image_list:
    # Load the image and shift the pixel values down by 128
    X, _ = load_mat_img(img=img_name, img_info='X')
    X = X - 128.0

    bit_counts = []
    measured_rms = []

    for target_rms in rms_targets:
        # Step multiplier that diff_step_sizes returns for this RMS target
        opt_step, scaled, _, _, _, _ = diff_step_sizes(
            X, 256, n=5, target_rms=target_rms, rise_ratio=1.0
        )

        # Encode and decode with identical codec settings
        codec_settings = dict(step_multiplier=opt_step, rise_ratio=1.0, dcbits=9, log=False)
        vlc, dhufftab, totalbits = jpegenc_dwt(X, 5, opthuff=True, **codec_settings)
        Z = jpegdec_dwt(vlc, 5, hufftab=dhufftab, **codec_settings)

        # np.std of the reconstruction error is the RMS figure used here
        actual_rms = np.std(X - Z)

        bit_counts.append(totalbits)
        measured_rms.append(actual_rms)
        print(f"actual rms:{actual_rms},target rms: {target_rms}")
        # NOTE: an entropy-based bit estimate (nlevdwt -> quantdwt ->
        # compression_ratio_for_DWT) used to be sketched here as commented-out
        # code for comparison against totalbits.

    # One curve per image, labelled with the file name minus its extension
    curve_label = img_name.replace('.mat', '')
    plt.plot(measured_rms, bit_counts, 'o-', label=curve_label)

# Reference line marking the 5 kB budget
plt.axhline(y=40960, color='red', linestyle='--', label='5 kB = 40960 bits')

# Axis labels, title, grid and legend
plt.xlabel('RMS error')
plt.ylabel('Total bits')
plt.title('Rate-Distortion Curve (DWT Compression)')
plt.grid(True)
plt.legend()
plt.show()


Testing an LBT scheme!

In [None]:
from compression_schemes.lbt_functions import *
from cued_sf2_lab.jpeg_lbt import *

In [None]:
# Load the lighthouse test image and shift the pixel values down by 128.
lighthouse_pixels, _ = load_mat_img(img='lighthouse.mat', img_info='X')
X = lighthouse_pixels - 128.0

# Encode with a fixed quantisation step and logging enabled, then report
# the bit count.
lbt_encode_settings = dict(
    qstep=45,
    rise_ratio=1.0,
    N=16,
    M=16,
    dcbits=9,
    opthuff=True,
    log=True,
)
vlc, dhufftab, totalbits = jpegenc_lbt(X, **lbt_encode_settings)
print(totalbits)

In [None]:
# Decode the LBT bitstream with qstep=45, N=M=16 and dcbits=9.
Zl = jpegdec_lbt(vlc, qstep=45, rise_ratio=1.0, N=16, M=16, hufftab=dhufftab, dcbits=9)

# Show the input image and its reconstruction next to each other.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
for panel, picture, caption in ((ax1, X, "X"), (ax2, Zl, "Z")):
    plot_image(picture, ax=panel)
    panel.set(title=caption)

# Quality figures: SSIM first, then np.std of the reconstruction error.
print(ssim(X, Zl))
lbt_error = X - Zl
print(np.std(lbt_error))

In [None]:
# Shared settings for the sweep: every N is matched to the same RMS error.
target_err = 9.0
s = np.sqrt(2)
rise_ratio = 1.0
dcbits = 8

# Collected as (N, M, totalbits, ssim_score, rms_err) tuples.
results = []

Ns = [4, 8, 16]

for N in Ns:
    # Candidate values of M for this N: N, 2N and 4N.
    Ms = [N * factor for factor in (1, 2, 4)]

    # Quantisation step that find_step_LBT matches to target_err for this N.
    qstep, actual_rms = find_step_LBT(X, target_err, s, N, rise_ratio)
    print(f"\n>>> N={N}: matched qstep={qstep:.3f}, RMS={actual_rms:.3f}")

    for M in Ms:
        try:
            # Encoder and decoder receive exactly the same codec settings.
            codec_settings = dict(
                qstep=qstep, rise_ratio=rise_ratio, N=N, M=M,
                dcbits=dcbits, log=False,
            )
            vlc, dhufftab, totalbits = jpegenc_lbt(X, opthuff=True, **codec_settings)
            Z = jpegdec_lbt(vlc, hufftab=dhufftab, **codec_settings)

            # Quality of the reconstruction.
            ssim_score = ssim(X, Z)
            rms_err = np.std(X - Z)

            print(f"N={N}, M={M} → bits={totalbits}, SSIM={ssim_score:.4f}")

            results.append((N, M, totalbits, ssim_score, rms_err))

        except Exception as e:
            # Best effort: report the failing (N, M) pair and carry on.
            print(f"Error at N={N}, M={M}: {e}")


In [None]:
# (N, qstep) pairs to compare visually, all with M=16 and dcbits=8.
# NOTE(review): the step sizes are presumably the matched values printed by
# the RMS-matching sweep - confirm before reusing them on another image.
lbt_settings = [(4, 45.2), (8, 42.3), (16, 36.7)]

# Reconstruction for each block size, keyed by N.
reconstructions = {}
for block_size, step in lbt_settings:
    vlc, dhufftab, totalbits = jpegenc_lbt(
        X, qstep=step, rise_ratio=1.0, N=block_size, M=16,
        dcbits=8, opthuff=True, log=False)
    reconstructions[block_size] = jpegdec_lbt(
        vlc, qstep=step, rise_ratio=1.0, N=block_size, M=16,
        hufftab=dhufftab, dcbits=8)

# Keep the individual names available for later inspection.
Zl_4, Zl_8, Zl_16 = (reconstructions[size] for size in (4, 8, 16))

# One panel per block size; every title states both N and M (the last panel
# used to read just "M=16").
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8,4))
for panel, block_size in zip((ax1, ax2, ax3), (4, 8, 16)):
    plot_image(reconstructions[block_size], ax=panel)
    panel.set(title=f"N={block_size}, M=16")

Comparing the LBT with the DWT at the same number of bits

In [None]:
def step_from_target_bits_LBT(X: np.ndarray,
                              target_bits: float,
                              rise_ratio=1.0, 
                              lo: float = 1.0,
                              hi: float = 50.0,
                              tol_bits: float = 500.0,
                              dcbits: int = 8,
                              N: int=8,
                              M: int=8,
                              max_iter: int = 100):
    """Bisect the LBT quantisation step until the encoded size is close to ``target_bits``.

    Each iteration encodes ``X`` with ``jpegenc_lbt`` (``opthuff=True``,
    ``log=False``) at the midpoint of ``[lo, hi]`` and reads the third
    returned value as the total bit count. The bracket is narrowed on the
    assumption that a larger step never yields more bits.

    Args:
        X: Image passed unchanged to ``jpegenc_lbt``.
        target_bits: Desired total number of bits.
        rise_ratio: Forwarded to ``jpegenc_lbt``.
        lo: Lower end of the step bracket.
        hi: Upper end of the step bracket.
        tol_bits: Accepted absolute deviation from ``target_bits``.
        dcbits: Forwarded to ``jpegenc_lbt``.
        N: Forwarded to ``jpegenc_lbt``.
        M: Forwarded to ``jpegenc_lbt``.
        max_iter: Maximum number of encoder evaluations; must be at least 1.

    Returns:
        The last quantisation step that was tried. A warning is printed when
        the tolerance was not met.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1 (no step would ever be
            evaluated, so there is nothing to return).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # binary search for Δ so that bits ≈ target_bits
    for _ in range(max_iter):
        qstep = 0.5 * (lo + hi)
        _vlc, _hufftab, totalbits = jpegenc_lbt(X, qstep = qstep, rise_ratio=rise_ratio, N=N, M=M, dcbits=dcbits, opthuff=True, log=False)

        if abs(totalbits - target_bits) <= tol_bits:
            break

        if totalbits > target_bits:     # too many bits ⇒ need larger Δ
            lo = qstep
        else:                       # too few bits ⇒ need smaller Δ
            hi = qstep

        if 0.5 * (lo + hi) == qstep:
            # The bracket can no longer shrink in floating point: every
            # further iteration would re-encode at exactly this step and
            # reach the same verdict, so stop instead of burning max_iter.
            print("Warning: search interval collapsed without hitting target_bits")
            break
    else:
        print("Warning: max_iter reached without hitting target_bits")

    return qstep

In [None]:
# Reload the flamingo test image and shift the pixel values down by 128.
X, _ = load_mat_img(img='flamingo.mat', img_info='X')
X = X - 128.0

# Codec settings are spelled out once and shared by the search and the final
# encode, so this cell no longer reads `rise_ratio` / `dcbits` globals that
# only exist if an earlier cell happened to define them.
lbt_codec_settings = dict(rise_ratio=1.0, N=8, M=16, dcbits=8)

# Bisection search for the quantisation step that meets the bit budget.
qstep = step_from_target_bits_LBT(X,
                              target_bits=36000,
                              lo = 1.0,
                              hi = 50.0,
                              tol_bits = 500.0,
                              max_iter = 100,
                              **lbt_codec_settings)
print(f"optimal step size {qstep}")

# Encode at the chosen step with logging enabled and report the bit count.
vlc, dhufftab, totalbits = jpegenc_lbt(X, qstep = qstep, opthuff=True, log=True, **lbt_codec_settings)
print(totalbits)

In [None]:
# Decode the LBT bitstream at the step found by the search (N=8, M=16, dcbits=8).
Zl = jpegdec_lbt(vlc, qstep=qstep, rise_ratio=1.0, N=8, M=16, hufftab=dhufftab, dcbits=8)

# Show the input image and its reconstruction next to each other.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
for panel, picture, caption in ((ax1, X, "X"), (ax2, Zl, "Zl")):
    plot_image(picture, ax=panel)
    panel.set(title=caption)

print(ssim(X, Zl))

Testing suppression now

In [None]:
def step_from_target_bits_LBT_with_suppression(X: np.ndarray,
                              target_bits: float,
                              rise_ratio=1.0, 
                              lo: float = 1.0,
                              hi: float = 50.0,
                              tol_bits: float = 500.0,
                              dcbits: int = 8,
                              N: int=8,
                              M: int=8,
                              keep_fraction: float=1.0,
                              max_iter: int = 100):
    """Bisect the quantisation step of the suppressing LBT encoder to meet ``target_bits``.

    Each iteration encodes ``X`` with ``jpegenc_lbt_with_suppression``
    (``opthuff=True``, ``log=False``) at the midpoint of ``[lo, hi]`` and
    reads the third returned value as the total bit count. The bracket is
    narrowed on the assumption that a larger step never yields more bits.

    Args:
        X: Image passed unchanged to the encoder.
        target_bits: Desired total number of bits.
        rise_ratio: Forwarded to the encoder.
        lo: Lower end of the step bracket.
        hi: Upper end of the step bracket.
        tol_bits: Accepted absolute deviation from ``target_bits``.
        dcbits: Forwarded to the encoder.
        N: Forwarded to the encoder.
        M: Forwarded to the encoder.
        keep_fraction: Forwarded to the encoder (presumably the fraction of
            coefficients kept - confirm against the encoder).
        max_iter: Maximum number of encoder evaluations; must be at least 1.

    Returns:
        The last quantisation step that was tried. A warning is printed when
        the tolerance was not met.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1 (no step would ever be
            evaluated, so there is nothing to return).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # binary search for Δ so that bits ≈ target_bits
    for _ in range(max_iter):
        qstep = 0.5 * (lo + hi)
        _vlc, _hufftab, totalbits = jpegenc_lbt_with_suppression(X, qstep = qstep, rise_ratio=rise_ratio, N=N, M=M, keep_fraction = keep_fraction, dcbits=dcbits, opthuff=True, log=False)

        if abs(totalbits - target_bits) <= tol_bits:
            break

        if totalbits > target_bits:     # too many bits ⇒ need larger Δ
            lo = qstep
        else:                       # too few bits ⇒ need smaller Δ
            hi = qstep

        if 0.5 * (lo + hi) == qstep:
            # The bracket can no longer shrink in floating point: every
            # further iteration would re-encode at exactly this step and
            # reach the same verdict, so stop instead of burning max_iter.
            print("Warning: search interval collapsed without hitting target_bits")
            break
    else:
        print("Warning: max_iter reached without hitting target_bits")

    return qstep

In [None]:
# Codec settings shared by the step search and the final encode.
suppression_settings = dict(
    rise_ratio=1.0,
    N=8,
    M=16,
    keep_fraction=0.70,
    dcbits=8,
)

# Bisection search for the quantisation step that meets the bit budget.
qstep = step_from_target_bits_LBT_with_suppression(
    X,
    target_bits=36000,
    lo=1.0,
    hi=50.0,
    tol_bits=500.0,
    max_iter=100,
    **suppression_settings,
)
print(f"optimal step size {qstep}")

# Encode at the chosen step with logging enabled and report the bit count.
vlc, dhufftab, totalbits = jpegenc_lbt_with_suppression(
    X, qstep=qstep, opthuff=True, log=True, **suppression_settings
)
print(totalbits)

In [None]:
# The suppressed bitstream is decoded with the ordinary LBT decoder.
Zl_sup = jpegdec_lbt(vlc, qstep=qstep, rise_ratio=1.0, N=8, M=16, hufftab=dhufftab, dcbits=8)

# Show the input image and its reconstruction next to each other.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
for panel, picture, caption in ((ax1, X, "X"), (ax2, Zl_sup, "Zl_sup")):
    plot_image(picture, ax=panel)
    panel.set(title=caption)

print(ssim(X, Zl_sup))

Testing the three schemes on the `2023.mat` image

In [None]:
# Load the image stored in 2023.mat and shift the pixel values down by 128.
pixels_2023, _ = load_mat_img(img='2023.mat', img_info='X')
Xb = pixels_2023 - 128.0

In [None]:
# Settings for the bisection search on the DWT step multiplier.
dwt_search_settings = dict(
    n=5,
    target_bits=36000,
    rise_ratio=1.0,
    lo=1.0,
    hi=50.0,
    tol_bits=100.0,
    dcbits=9,
    max_iter=100,
)
step_multiplier = step_from_target_bits_DWT(X=Xb, **dwt_search_settings)
print(f"optimal step size {step_multiplier}")

# Encode at the chosen step and report the bit count.
vlc, dhufftab, totalbits = jpegenc_dwt(
    Xb,
    n=5,
    step_multiplier=step_multiplier,
    rise_ratio=1.0,
    dcbits=9,
    opthuff=True,
    log=False,
)
print(totalbits)

# Za is the DWT reconstruction compared against the LBT variants later on.
Za = jpegdec_dwt(vlc, 5, step_multiplier, rise_ratio=1.0, hufftab=dhufftab, dcbits=9, log=False)

In [None]:
# Codec settings shared by the search, the encode and the decode.
# keep_fraction=1.0 is used for the panel labelled plain "LBT".
lbt_codec_settings = dict(rise_ratio=1.0, N=8, M=16, dcbits=8)
lbt_keep_fraction = 1.0

# Bisection search for the quantisation step that meets the bit budget.
qstep = step_from_target_bits_LBT_with_suppression(
    Xb,
    target_bits=36000,
    lo=1.0,
    hi=70.0,
    tol_bits=100.0,
    keep_fraction=lbt_keep_fraction,
    max_iter=100,
    **lbt_codec_settings,
)
print(f"optimal step size {qstep}")

# Encode, decode into Zb, then report the bit count.
vlc, dhufftab, totalbits = jpegenc_lbt_with_suppression(
    Xb, qstep=qstep, keep_fraction=lbt_keep_fraction,
    opthuff=True, log=False, **lbt_codec_settings
)
Zb = jpegdec_lbt(vlc, qstep=qstep, hufftab=dhufftab, log=False, **lbt_codec_settings)
print(totalbits)

In [None]:
# Codec settings shared by the search, the encode and the decode.
# keep_fraction=0.75 is the suppressed variant.
lbt_codec_settings = dict(rise_ratio=1.0, N=8, M=16, dcbits=8)
lbt_keep_fraction = 0.75

# Bisection search for the quantisation step that meets the bit budget.
qstep = step_from_target_bits_LBT_with_suppression(
    Xb,
    target_bits=36000,
    lo=1.0,
    hi=70.0,
    tol_bits=100.0,
    keep_fraction=lbt_keep_fraction,
    max_iter=100,
    **lbt_codec_settings,
)
print(f"optimal step size {qstep}")

# Encode, decode into Zc, then report the bit count.
vlc, dhufftab, totalbits = jpegenc_lbt_with_suppression(
    Xb, qstep=qstep, keep_fraction=lbt_keep_fraction,
    opthuff=True, log=False, **lbt_codec_settings
)
Zc = jpegdec_lbt(vlc, qstep=qstep, hufftab=dhufftab, log=False, **lbt_codec_settings)
print(totalbits)

In [None]:
# The three reconstructions side by side.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 4))
panels = (
    (ax1, Za, "DWT"),
    (ax2, Zb, "LBT"),
    (ax3, Zc, "LBT, 0.75 suppressed"),
)
for panel, picture, caption in panels:
    plot_image(picture, ax=panel)
    panel.set(title=caption)

# Error figures (np.std of the difference) and SSIM for each scheme.
rms_a, rms_b, rms_c = (np.std(Xb - recon) for recon in (Za, Zb, Zc))
ssim_a, ssim_b, ssim_c = (ssim(Xb, recon) for recon in (Za, Zb, Zc))

summary = (
    ("DWT based compression", rms_a, ssim_a),
    ("LBT based compression", rms_b, ssim_b),
    ("LBT with suppression", rms_c, ssim_c),
)
for scheme, rms_value, ssim_value in summary:
    print(f"The rms error for {scheme} was {rms_value:.2f}, SSIM:{ssim_value:.2f}")