In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import IPython.display

from cued_sf2_lab.familiarisation import load_mat_img, plot_image
from compression_schemes.dwt_funcs import *
from cued_sf2_lab.familiarisation import load_mat_img, plot_image
from compression_schemes.subjective_quality import ssim,vif_index
from cued_sf2_lab.jpeg import huffdflt as huffdflt_ac, huffgen, runampl, diagscan, huffenc
from cued_sf2_lab.jpeg_dwt_diff_step import jpegenc_dwt, jpegdec_dwt

In [None]:
# Load the bridge test image, then level-shift it by -128 so the codec works on
# zero-mean data (the plotting cells add the 128 back for display).
X_raw, _ = load_mat_img(img='bridge.mat', img_info='X')
X = X_raw - 128.0

In [None]:
def step_from_target_bits_DWT(X: np.ndarray,
                              n: int,
                              target_bits: float,
                              k: float,
                              rise_ratio: float = 1.0,
                              lo: float = 1.0,
                              hi: float = 50.0,
                              tol_bits: float = 500.0,
                              dcbits: int = 9,
                              max_iter: int = 100) -> float:
    """Bisect for the DWT step multiplier that hits a target compressed size.

    Repeatedly encodes ``X`` with ``jpegenc_dwt`` (optimised Huffman tables,
    logging off) and bisects on the step multiplier. The search assumes the
    encoded size decreases as the step multiplier grows: coarser quantisation
    gives fewer bits.

    Parameters
    ----------
    X : np.ndarray
        Image to encode (the notebook passes the zero-mean image).
    n : int
        Forwarded to ``jpegenc_dwt`` as its second positional argument
        (the notebook uses it as the number of DWT levels).
    target_bits : float
        Desired total number of bits.
    k : float
        Frequency-weighting factor, forwarded to ``jpegenc_dwt``.
    rise_ratio : float, default 1.0
        Forwarded to ``jpegenc_dwt``.
    lo, hi : float
        Initial bracket for the step multiplier; the answer is assumed to lie
        inside ``[lo, hi]``.
    tol_bits : float
        The search stops as soon as ``|totalbits - target_bits| <= tol_bits``.
    dcbits : int
        Forwarded to ``jpegenc_dwt``.
    max_iter : int
        Maximum number of bisection steps; must be at least 1.

    Returns
    -------
    float
        The last step multiplier tried. This is within tolerance if the search
        converged; otherwise it is the final bracket midpoint, and a warning
        is printed.

    Raises
    ------
    ValueError
        If ``max_iter < 1`` (previously an ``UnboundLocalError``) or ``lo > hi``.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    if lo > hi:
        raise ValueError(f"lo ({lo}) must not exceed hi ({hi})")

    # Binary search for the step multiplier so that bits ~= target_bits.
    for _ in range(max_iter):
        step_multiplier = 0.5 * (lo + hi)
        # Only the bit count is needed here; the bitstream and tables are discarded.
        # rise_ratio is forwarded (it used to be hard-coded to 1.0 and silently ignored).
        _, _, totalbits = jpegenc_dwt(X, n, k=k, step_multiplier=step_multiplier,
                                      rise_ratio=rise_ratio, dcbits=dcbits,
                                      opthuff=True, log=False)

        if abs(totalbits - target_bits) <= tol_bits:
            print(totalbits)
            break

        if totalbits > target_bits:   # too many bits -> need a larger step
            lo = step_multiplier
        else:                         # too few bits -> need a smaller step
            hi = step_multiplier
    else:
        print("Warning: max_iter reached without hitting target_bits")

    return step_multiplier

In [None]:
# Find the step multiplier that brings the image to ~36 kbit at k = 0.6, then
# re-encode once at that step with logging switched on.
encoder_settings = dict(n=5, k=0.6, rise_ratio=1.0, dcbits=9)

step_multiplier = step_from_target_bits_DWT(
    X=X,
    target_bits=36000,
    lo=1.0,
    hi=50.0,
    tol_bits=100.0,
    max_iter=100,
    **encoder_settings,
)
print(f"optimal step size {step_multiplier}")

vlc, dhufftab, totalbits = jpegenc_dwt(
    X,
    step_multiplier=step_multiplier,
    opthuff=True,
    log=True,
    **encoder_settings,
)
print(totalbits)

In [None]:
# Decode the k = 0.6 bitstream produced above and compare it with the input.
Z = jpegdec_dwt(
    vlc, 5,
    k=0.6,
    step_multiplier=step_multiplier,
    rise_ratio=1.0,
    hufftab=dhufftab,
    dcbits=9,
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
# Add the 128 level shift back before display.
for panel, image, label in ((ax1, X, "X"), (ax2, Z, "Z")):
    plot_image(image + 128, ax=panel)
    panel.set(title=label)

print(ssim(X, Z))        # structural similarity between original and decoded
print(np.std(X - Z))     # std of the reconstruction error (reported as RMS error later)

In [None]:
# Sweep the frequency-weighting factor k at a fixed bit budget. Each k gets its
# own step multiplier, so every point on the curves costs (approximately) the
# same number of bits and the quality metrics are directly comparable.
k_values = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9]
target_bits = 36000


def _evaluate_k(k: float) -> dict:
    """Encode/decode X at `target_bits` for one value of k and score the result.

    All intermediates are local, so the sweep no longer overwrites the
    notebook-level `step_multiplier`, `vlc`, `dhufftab`, `totalbits` and `Z`
    produced by the earlier k = 0.6 cells. Previously, re-running the decode
    cell after this one silently decoded the last sweep bitstream with k = 0.6.
    """
    # Find the step multiplier that meets the target bit rate.
    step = step_from_target_bits_DWT(
        X=X, n=5, target_bits=target_bits, k=k, rise_ratio=1.0,
        lo=1.0, hi=50.0, tol_bits=100.0, dcbits=9, max_iter=100,
    )
    code, hufftab, bits = jpegenc_dwt(
        X, n=5, k=k, step_multiplier=step,
        rise_ratio=1.0, dcbits=9, opthuff=True, log=False,
    )
    decoded = jpegdec_dwt(
        code, 5, k=k, step_multiplier=step,
        rise_ratio=1.0, hufftab=hufftab, dcbits=9,
    )
    return {
        'k': k,
        'step_multiplier': step,
        'bits': bits,
        'rms': np.std(X - decoded),   # same error metric as the decode cell
        'ssim': ssim(X, decoded),
    }


results = []
for k in k_values:
    print(f"\nOptimising for k = {k}")
    try:
        row = _evaluate_k(k)
    except Exception as e:  # best-effort sweep: report the failure, try the next k
        print(f"  Failed for k = {k}: {e}")
        continue
    print(f"  Bits: {row['bits']}, RMS error: {row['rms']:.3f}, SSIM: {row['ssim']:.4f}")
    results.append(row)

# Extract results
ks = [r['k'] for r in results]
rms_vals = [r['rms'] for r in results]
ssim_vals = [r['ssim'] for r in results]

# One figure per metric, built with the explicit Axes API instead of
# copy-pasted pyplot state-machine calls.
for values, marker, ylabel, title in (
    (rms_vals, 'o', 'RMS error', 'RMS Error vs k (fixed bit rate)'),
    (ssim_vals, 's', 'SSIM', 'SSIM vs k (fixed bit rate)'),
):
    sweep_fig, sweep_ax = plt.subplots(figsize=(6, 4))
    sweep_ax.plot(ks, values, marker=marker)
    sweep_ax.set(xlabel='k (frequency weighting factor)', ylabel=ylabel, title=title)
    sweep_ax.grid(True)
    plt.show()
