# Problem Statement

In [None]:
import numpy as np
import pywt
from scipy.special import psi, gamma # Needed for the M-step
from scipy.optimize import brentq
from scipy.integrate import quad # The key for numerical integration in the E-step



import numpy as np
import pywt
from scipy.stats import median_abs_deviation

def get_dHH1_subband(noisy_image, wavelet_name='db4'):
    """Return the last detail array of a one-level 2-D wavelet decomposition.

    Args:
        noisy_image: 2-D array handed straight to ``pywt.wavedec2``.
        wavelet_name: Wavelet identifier understood by PyWavelets.

    Returns:
        The third array of the level-1 detail tuple (called ``dHH_1`` in this
        file; PyWavelets documents that slot as the diagonal detail).
    """
    decomposition = pywt.wavedec2(noisy_image, wavelet=wavelet_name, level=1)
    # With level=1 the result is [approximation, (detail, detail, detail)].
    level1_details = decomposition[1]
    return level1_details[2]


def estimate_noise_std(noisy_image, wavelet_name='db4'):
    """Estimate the noise standard deviation from the level-1 dHH subband.

    The estimate is ``median(|dHH_1|) / 0.6745``. The coefficient median is
    not subtracted first: the original author assumes additive white Gaussian
    noise, for which that median is approximately zero.

    Args:
        noisy_image: 2-D image array.
        wavelet_name: Wavelet identifier forwarded to ``get_dHH1_subband``.

    Returns:
        Scalar noise standard deviation estimate (sigma_epsilon).
    """
    hh_coeffs = get_dHH1_subband(noisy_image, wavelet_name=wavelet_name)
    median_magnitude = np.median(np.abs(hh_coeffs))
    return median_magnitude / 0.6745

def estimate_alpha_c(noisy_image, sigma_eps, wavelet_name='db4'):
    """Cumulant-based initial guess for the (alpha, c) parameters.

    The second and fourth sample cumulants of the level-1 dHH subband are
    computed; the noise variance ``sigma_eps**2`` is subtracted from the
    second cumulant (clamped at zero), and then
    ``alpha = 3 * k2_adjusted**2 / k4`` and ``c = k2_adjusted / alpha``.

    Args:
        noisy_image: 2-D image array.
        sigma_eps: Noise standard deviation.
        wavelet_name: Wavelet identifier forwarded to ``get_dHH1_subband``.

    Returns:
        ``[alpha, c]`` as a two-element list.
    """
    coeffs = get_dHH1_subband(noisy_image, wavelet_name=wavelet_name).flatten()
    N = len(coeffs)

    # Central sample moments of order two and four.
    centered = coeffs - np.mean(coeffs)
    m_2 = np.mean(centered**2)
    m_4 = np.mean(centered**4)

    # Sample cumulants built from the central moments.
    k_2 = m_2*(N/(N-1))
    k_4 = N**2*((N+1)*m_4-3*(N-1)*m_2**2)/((N-1)*(N-2)*(N-3))

    # Remove the noise contribution; never let the signal variance go negative.
    signal_k2 = max(k_2-sigma_eps**2, 0)

    # NOTE(review): nothing guards against k_4 <= 0 or alpha == 0 here, so
    # alpha can come out non-positive and c can come out as inf/NaN.
    alpha = 3*(signal_k2)**2/k_4
    c = signal_k2/alpha

    return [alpha, c]
def initial_estimate(noisy_image, wavelet_name='db4'):
    """Bundle the subband and the starting parameter estimates for the EM fit.

    Args:
        noisy_image: 2-D image array.
        wavelet_name: Wavelet identifier forwarded to the helper functions.

    Returns:
        ``[d, alpha, c, sigma_eps]``: the level-1 dHH subband, the initial
        (alpha, c) guess, and the estimated noise standard deviation.
    """
    subband = get_dHH1_subband(noisy_image, wavelet_name)
    noise_std = estimate_noise_std(noisy_image, wavelet_name)
    alpha0, c0 = estimate_alpha_c(noisy_image, noise_std, wavelet_name)
    return [subband, alpha0, c0, noise_std]


In [None]:
# --- 2. E-Step Integrands (Numerator and Denominator) ---
from multiprocessing import Pool
import os 

def _gsm_weight(u, d, c, sigma_eps):
    """Factor shared by all three E-step integrands.

    Returns ``exp(-d**2 / (2*(u + sigma_eps**2)) - u/c) / sqrt(u + sigma_eps**2)``:
    the Gaussian likelihood of ``d`` with variance ``u + sigma_eps**2`` (up to
    a constant) multiplied by the exponential decay ``exp(-u/c)``.
    """
    variance = u + sigma_eps**2
    # Both terms are negative: a Gaussian likelihood decays with d**2. (The
    # previous code added d**2 / (2*variance), which made the weight grow with |d|.)
    exponent = -(d**2) / (2 * variance) - (u / c)
    return np.exp(exponent) / np.sqrt(variance)


def integrand_den(u, d, alpha, c, sigma_eps):
    """
    Integrand for the denominator of E[U|d] and E[log(U)|d] (Eq. 25 & 26).

    Evaluates ``u**(alpha - 1) * exp(-d**2/(2*(u + sigma_eps**2)) - u/c)
    / sqrt(u + sigma_eps**2)``.

    Note (from the original author): the paper writes the prior with K_nu(z),
    but Eq. 25 & 26 reduce to this exponential/power form for the BKF
    distribution in its Gaussian-scale-mixture form; that form is used here.

    Args:
        u: Integration variable (latent scale variable).
        d: Noisy coefficient.
        alpha: Exponent parameter of the power term.
        c: Scale parameter of the exponential decay.
        sigma_eps: Noise standard deviation (it is squared inside).

    Returns:
        Value of the integrand at ``u``.
    """
    return u**(alpha - 1) * _gsm_weight(u, d, c, sigma_eps)


def integrand_num_mu(u, d, alpha, c, sigma_eps):
    """Numerator integrand for E[U|d]: same as ``integrand_den`` with ``u**alpha``.

    Args:
        u: Integration variable (latent scale variable).
        d: Noisy coefficient.
        alpha: Exponent parameter of the power term.
        c: Scale parameter of the exponential decay.
        sigma_eps: Noise standard deviation (it is squared inside).

    Returns:
        Value of the integrand at ``u``.
    """
    return u**(alpha) * _gsm_weight(u, d, c, sigma_eps)


def integrand_num_phi(u, d, alpha, c, sigma_sq):
    """Numerator integrand for E[log(U)|d]: ``log(u) * integrand_den(u, ...)``.

    Args:
        u: Integration variable (latent scale variable).
        d: Noisy coefficient.
        alpha: Exponent parameter of the power term.
        c: Scale parameter of the exponential decay.
        sigma_sq: Despite its name this is forwarded unchanged as the
            ``sigma_eps`` argument of ``integrand_den``, i.e. it is treated
            as a standard deviation. The name is kept for compatibility.

    Returns:
        Value of the integrand at ``u``; ``0.0`` at ``u == 0`` so that
        ``log(0)`` is never evaluated.
    """
    if u == 0:
        return 0.0
    return np.log(u) * integrand_den(u, d, alpha, c, sigma_sq)


def e_step_integral(d, alpha, c, sigma_epsilon):
    """Evaluate the two E-step ratios for a single coefficient with ``quad``.

    Three integrals over ``[0, inf)`` are computed: the shared denominator
    (``integrand_den``) and the numerators ``integrand_num_mu`` and
    ``integrand_num_phi``.

    Args:
        d: Noisy coefficient.
        alpha: Current alpha parameter.
        c: Current c parameter.
        sigma_epsilon: Noise standard deviation.

    Returns:
        ``(M_U_1, phi_U_1)``: each numerator integral divided by the
        denominator integral. ``(0.0, 0.0)`` is returned when the
        denominator is at most 1e-15, to avoid dividing by (nearly) zero.
    """
    shared_args = (d, alpha, c, sigma_epsilon)

    normaliser, _ = quad(integrand_den, 0, np.inf, args=shared_args)
    if normaliser <= 1e-15:
        return 0.0, 0.0

    # Numerator for M_U(1).
    mean_numerator, _ = quad(integrand_num_mu, 0, np.inf, args=shared_args)
    # Numerator for phi_U(1).
    log_numerator, _ = quad(integrand_num_phi, 0, np.inf, args=shared_args)

    return mean_numerator / normaliser, log_numerator / normaliser


# Element-wise version of e_step_integral, for applying it to every
# coefficient of a subband in one call.
v_e_step_integral = np.vectorize(e_step_integral)


def e_step_wrapper(d_params):
    """Adapter for ``Pool.map``: expand one 4-tuple into ``e_step_integral`` arguments.

    Args:
        d_params: ``(d, alpha, c, sigma_epsilon)``.

    Returns:
        Whatever ``e_step_integral`` returns for those arguments.
    """
    coefficient, shape_param, scale_param, noise_std = d_params
    return e_step_integral(coefficient, shape_param, scale_param, noise_std)

def parallel_e_step(D, alpha, c, sigma_epsilon):
    """
    Parallelizes the E-step integral calculation over the array of coefficients D.

    Args:
        D: Iterable of noisy coefficients; each element becomes one task.
        alpha: Current alpha parameter (shared by all tasks).
        c: Current c parameter (shared by all tasks).
        sigma_epsilon: Noise standard deviation (shared by all tasks).

    Returns:
        ``(M_U_1_array, phi_U_1_array)``: two NumPy arrays holding the first
        and second value returned by ``e_step_wrapper`` for every element of
        ``D``, in input order.
    """
    # One parameter tuple per coefficient.
    params_list = [(d, alpha, c, sigma_epsilon) for d in D]

    # Use all but one core. os.cpu_count() returns None when the count cannot
    # be determined, so fall back to a single worker in that case.
    n_cores = max((os.cpu_count() or 1) - 1, 1)

    print(f"Starting parallel integration on {len(D)} coefficients using {n_cores} cores...")

    with Pool(n_cores) as pool:
        results = pool.map(e_step_wrapper, params_list)

    # Split the list of (M_U_1, phi_U_1) pairs into two arrays.
    M_U_1_array = np.array([res[0] for res in results])
    phi_U_1_array = np.array([res[1] for res in results])

    return M_U_1_array, phi_U_1_array




In [None]:
def e_step_integral_vectorized_cpu(d, alpha, c, sigma_eps, u_max=50, n_points=1000, u_min=1e-30):
    """Vectorised E-step: both integral ratios for every coefficient at once.

    For each coefficient the three integrals

        den = int u**(alpha-1) * w(u) du
        mu  = int u**alpha     * w(u) du
        phi = int log(u) * u**(alpha-1) * w(u) du

    with ``w(u) = exp(-d**2/(2*(u + sigma_eps**2)) - u/c) / sqrt(u + sigma_eps**2)``
    are approximated by the trapezoidal rule on ``[u_min, u_max]``.

    The rule is applied in ``t = log(u)`` on an evenly spaced ``t`` grid.
    Because ``du = u dt`` every integrand gains one power of ``u``, so no
    node is ever evaluated at ``u = 0``, where ``u**(alpha-1)`` is infinite
    for ``alpha < 1``.

    Args:
        d: Array of noisy coefficients (flattened internally).
        alpha: Current alpha parameter.
        c: Current c parameter.
        sigma_eps: Noise standard deviation.
        u_max: Upper integration limit; contributions above it are ignored.
        n_points: Number of grid nodes.
        u_min: Lower integration limit (> 0); contributions below it are
            ignored.

    Returns:
        ``(M_U, phi_U)``: 1-D arrays with one entry per coefficient, equal to
        ``mu / den`` and ``phi / den``.

    Note:
        Intermediate arrays have shape ``(n_points, d.size)``; for very large
        inputs call this on slices of ``d`` to bound memory use.
    """
    d = np.asarray(d, dtype=float).reshape(1, -1)          # (1, n_coeffs)

    log_u = np.linspace(np.log(u_min), np.log(u_max), n_points)[:, None]
    u = np.exp(log_u)                                       # (n_points, 1)

    variance = u + sigma_eps**2
    # Gaussian likelihood decays with d**2, hence the leading minus sign.
    weight = np.exp(-(d**2) / (2 * variance) - (u / c)) / np.sqrt(variance)

    # One extra power of u in each integrand is the Jacobian du = u dt.
    integrand_den = u**alpha * weight
    integrand_num_mu = u**(alpha + 1) * weight
    integrand_num_phi = log_u * integrand_den

    # np.trapz was renamed np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    t = log_u[:, 0]

    I_den = trapezoid(integrand_den, t, axis=0)
    I_mu_num = trapezoid(integrand_num_mu, t, axis=0)
    I_phi_num = trapezoid(integrand_num_phi, t, axis=0)

    M_U = I_mu_num / I_den
    phi_U = I_phi_num / I_den
    return M_U, phi_U


In [None]:
# M-Step and the Full EM Loop

def m_step_alpha_root_func(alpha, phi_U_avg, M_U_avg):
    """Residual of Eq. 28 whose root in ``alpha`` is the M-step update.

    Args:
        alpha: Candidate alpha value.
        phi_U_avg: Average of the per-coefficient phi_U values.
        M_U_avg: Average of the per-coefficient M_U values.

    Returns:
        ``(psi(alpha) - log(alpha)) - (phi_U_avg - log(M_U_avg))``, or
        ``np.inf`` for non-positive ``alpha``.
    """
    if alpha <= 0:
        return np.inf

    # Left-hand side minus right-hand side of Eq. 28.
    return (psi(alpha) - np.log(alpha)) - (phi_U_avg - np.log(M_U_avg))

def em_bessel_k_form_noisy(d_subband, alpha_init, c_init, sigma_epsilon, max_iter=100, tol=1e-5):
    """
    The main EM algorithm for a single wavelet subband (d).

    Each iteration runs the vectorised E-step on every coefficient, averages
    the resulting M_U and phi_U values, solves Eq. 28 for alpha by bracketed
    root finding, and then sets ``c = M_U_avg / alpha`` (Eq. 27).

    Args:
        d_subband: Noisy wavelet coefficients of one subband (any shape; it
            is flattened).
        alpha_init, c_init: Initial guesses for alpha and c.
        sigma_epsilon: Estimated noise standard deviation from the MAD formula.
        max_iter: Maximum number of EM iterations.
        tol: Stop once both |delta alpha| and |delta c| fall below this value.

    Returns:
        ``(alpha, c)`` after convergence or after ``max_iter`` iterations.

    Raises:
        ValueError: If ``d_subband`` contains no coefficients.
    """
    d_flat = np.asarray(d_subband).ravel()

    # Average over every coefficient. len() of a 2-D subband would only count
    # its rows and inflate the averages by the number of columns.
    m = d_flat.size
    if m == 0:
        raise ValueError("d_subband must contain at least one coefficient")

    alpha = alpha_init
    c = c_init

    for _ in range(max_iter):
        # E-step
        M_U_i, phi_U_i = e_step_integral_vectorized_cpu(d_flat, alpha, c, sigma_epsilon)
        M_U_avg = np.sum(M_U_i) / m
        phi_U_avg = np.sum(phi_U_i) / m

        # M-step: search for alpha^(t+1) within a decade either side of the
        # current value (the lower end is floored at 0.01).
        lower = max(alpha * 0.1, 0.01)
        upper = alpha * 10.0
        try:
            alpha_next = brentq(
                m_step_alpha_root_func,
                a=lower, b=upper,
                args=(phi_U_avg, M_U_avg)
            )
        except (ValueError, RuntimeError):
            # brentq could not bracket the root or did not converge:
            # keep the current alpha for this iteration.
            alpha_next = alpha

        # c^(t+1) from Eq. 27.
        c_next = (1 / alpha_next) * M_U_avg

        alpha_diff = np.abs(alpha_next - alpha)
        c_diff = np.abs(c_next - c)

        alpha, c = alpha_next, c_next

        if alpha_diff < tol and c_diff < tol:
            break

    return alpha, c

In [4]:
#closed form expressions

def A(d, c, sigma_eps):
    """Return ``|d| - sqrt(2/c) * sigma_eps**2``.

    Args:
        d: Noisy coefficient(s); scalar or array.
        c: The c parameter.
        sigma_eps: Noise standard deviation.
    """
    offset = np.sqrt(2/c)*(sigma_eps**2)
    return np.abs(d) - offset

def lambda_func(c, alpha,sigma_eps):
    """Return ``sqrt(2)*sigma_eps * (sqrt(2*(1 - alpha)) + sigma_eps/sqrt(c))``.

    The value is used as the threshold argument of ``map_estimator``.

    Args:
        c: The c parameter.
        alpha: The alpha parameter.
        sigma_eps: Noise standard deviation.
    """
    scale = np.sqrt(2) * sigma_eps
    return scale * (np.sqrt(2*(1-alpha)) + sigma_eps / np.sqrt(c))

def map_estimator(d,A,alpha,sigma_eps, lambda_func):
    """Thresholded closed-form estimate for a single coefficient.

    Args:
        d: Noisy coefficient (scalar).
        A: Precomputed value for this coefficient (see the function ``A``).
        alpha: The alpha parameter.
        sigma_eps: Noise standard deviation.
        lambda_func: Threshold value (a number, despite the name).

    Returns:
        ``0.0`` when ``|d| <= lambda_func``; otherwise
        ``sign(d)/2 * (A + sqrt(A**2 + 4*sigma_eps**2*(alpha - 1)))``.
    """
    if np.abs(d) <= lambda_func:
        return 0.0

    discriminant = A**2 + 4*sigma_eps**2*(alpha-1)
    return np.sign(d)/2 * (A+np.sqrt(discriminant))


# Element-wise version of map_estimator for array inputs.
v_map_estimator = np.vectorize(map_estimator)

In [5]:
# Adds Gaussian or Poisson noise to an image at a requested SNR (in dB).
def add_noise(img, SNR_dB, noise_type='gaussian'):
    """Return a noisy copy of ``img`` clipped to ``[0, 1]``.

    Signal power is ``mean(img**2)`` in both modes and the target noise power
    is ``signal_power / 10**(SNR_dB/10)`` (before clipping).

    * Gaussian (default, and any ``noise_type`` other than ``'poisson'``):
      zero-mean normal noise with that variance is added.
    * Poisson: counts are drawn from ``Poisson(img * lam)`` and divided by
      ``lam``. The noise variance at a pixel is then ``img / lam``, so the
      mean noise power is ``mean(img) / lam``; matching it to the target
      gives ``lam = snr * mean(img) / mean(img**2)``.

    Args:
        img: Image array; values are expected in ``[0, 1]`` (the output is
            clipped to that range). Poisson mode needs non-negative values.
        SNR_dB: Desired signal-to-noise ratio in decibels.
        noise_type: ``'poisson'`` or anything else for Gaussian.

    Returns:
        Noisy image with the same shape as ``img``.
    """
    signal_power = np.mean(img ** 2)
    snr_linear = 10 ** (SNR_dB / 10)

    if noise_type == 'poisson':
        if signal_power == 0:
            # An all-zero image has no counts to draw; the rate scale would
            # be 0/0, so return the (clipped) image unchanged.
            return np.clip(img, 0, 1)
        lamda = snr_linear * np.mean(img) / signal_power
        noisy_counts = np.random.poisson(img * lamda)
        noisy = noisy_counts / lamda
    else:
        noise_power = signal_power / snr_linear
        noise = np.random.normal(0, np.sqrt(noise_power), img.shape)
        noisy = img + noise
    return np.clip(noisy, 0, 1)

In [None]:
# Importing an image and adding noise
from skimage import io, util, img_as_float
from skimage.color import rgb2gray
import matplotlib.pyplot as plt
url = 'https://images.unsplash.com/photo-1503023345310-bd7c1de61c7d'
color_img = img_as_float(io.imread(url))
gray_img = rgb2gray(color_img)

SNR_dB = -10
noisy_img = add_noise(gray_img, SNR_dB, 'poisson')

# Noise standard deviation estimated on the clean and on the noisy image.
original_std = round(estimate_noise_std(gray_img),4)
noisy_std = round(estimate_noise_std(noisy_img),4)

# Initial parameter estimates for the clean (1) and the noisy (2) image.
[d1,alpha1,c1,sigma_eps1] = initial_estimate(gray_img)
[d2, alpha2, c2, sigma_eps2] = initial_estimate(noisy_img)

# The trailing 1 is max_iter: only a single EM iteration is run here.
[alpha_fin, c_fin] = em_bessel_k_form_noisy(d2, alpha2, c2, sigma_eps2,1)

# Closed-form estimate of the subband coefficients from the fitted parameters.
A_val = A(d2,c_fin,sigma_eps2)
lambda_val = lambda_func(c_fin,alpha_fin,sigma_eps2)
d_estim = v_map_estimator(d2,A_val, alpha_fin, sigma_eps2,lambda_val)


# The titles show estimate_noise_std output, which is a standard deviation,
# so the label is sigma (the previous label said sigma^2).
plt.figure(figsize=(20,10))
plt.subplot(1,2,1)
plt.imshow(gray_img,cmap='gray')
plt.title(fr"Original  Estimated $\sigma= ${original_std}")
plt.axis('off')

plt.subplot(1,2,2)
plt.imshow(noisy_img,cmap='gray')
plt.title(fr"Noisy (SNR={SNR_dB} dB)  Estimated $\sigma= ${noisy_std}")
plt.axis('off')
plt.show()



Starting parallel integration on 2576184 coefficients using 11 cores...
