# Problem Statement

In [None]:
import numpy as np
import pywt
from scipy.special import psi, gamma # Needed for the M-step
from scipy.optimize import brentq
from scipy.integrate import quad # The key for numerical integration in the E-step



import numpy as np
import pywt
from scipy.stats import median_abs_deviation

def get_dHH1_subband(noisy_image, wavelet_name='db4'):
    """Return the finest-scale diagonal (HH) detail subband of a one-level 2-D DWT.

    Args:
        noisy_image: image array handed to ``pywt.wavedec2``.
        wavelet_name: name of the pywt wavelet to use (default ``'db4'``).

    Returns:
        The diagonal detail coefficients of the level-1 decomposition.
    """
    # With level=1, wavedec2 yields [approx, (horizontal, vertical, diagonal)];
    # only the diagonal detail is needed for the noise estimate.
    _approx, (_horizontal, _vertical, diagonal) = pywt.wavedec2(
        noisy_image, wavelet=wavelet_name, level=1
    )
    return diagonal


def estimate_noise_std(noisy_image, wavelet_name='db4'):
    """Estimate the additive white Gaussian noise standard deviation of an image.

    Uses the median-based estimator on the finest diagonal wavelet subband:
    ``sigma = median(|dHH_1|) / 0.6745``.

    Args:
        noisy_image: image array handed to ``get_dHH1_subband``.
        wavelet_name: pywt wavelet name (default ``'db4'``).

    Returns:
        The estimated noise standard deviation sigma_epsilon.
    """
    finest_diag = get_dHH1_subband(noisy_image, wavelet_name=wavelet_name)
    # The subband is assumed to be dominated by zero-mean noise, so its median is
    # taken as ~0 and the MAD reduces to median(|x|). 0.6745 is approximately the
    # 0.75 quantile of N(0, 1), which turns a Gaussian MAD into a standard deviation.
    return np.median(np.abs(finest_diag)) / 0.6745

In [None]:
# --- 2. E-Step Integrands (Numerator and Denominator) ---

def integrand_den(u, d, alpha, c, sigma_eps):
    """
    Unnormalised posterior density of the latent scale U at u, given a noisy coefficient d.

    This is the integrand of the common denominator of E[U|d] and E[log(U)|d]
    (Eq. 25 & 26):

        u**(alpha - 1) * exp(-u / c)                       (prior term)
        * (u + sigma_eps**2)**(-1/2)
        * exp(-d**2 / (2 * (u + sigma_eps**2)))            (Gaussian likelihood)

    Factors that do not depend on u are dropped; they cancel in the posterior ratios.

    Args:
        u: integration variable (latent scale), integrated over [0, inf).
        d: current noisy wavelet coefficient.
        alpha, c: current BKF shape and scale parameters.
        sigma_eps: noise standard deviation (it is squared here, not a variance).

    Returns:
        The integrand value at u.
    """
    # Variance of d given u: latent (signal) variance u plus the noise variance.
    total_var = u + sigma_eps**2

    # The Gaussian likelihood contributes exp(-d^2 / (2 * total_var)). The minus sign
    # is essential: with a plus sign, large |d| would favour small u and overflow
    # near u = 0.
    exponent = -(d**2) / (2 * total_var) - u / c

    # Power term from the prior distribution.
    power_term = u**(alpha - 1)

    return (power_term / np.sqrt(total_var)) * np.exp(exponent)


def integrand_num_mu(u, d, alpha, c, sigma_eps):
    """
    Integrand of the numerator of E[U | d] (Eq. 25).

    Equal to u times the unnormalised posterior density of U:

        u**alpha * exp(-u / c) * (u + sigma_eps**2)**(-1/2)
        * exp(-d**2 / (2 * (u + sigma_eps**2)))

    Args:
        u: integration variable (latent scale), integrated over [0, inf).
        d: current noisy wavelet coefficient.
        alpha, c: current BKF shape and scale parameters.
        sigma_eps: noise standard deviation.

    Returns:
        The integrand value at u.
    """
    total_var = u + sigma_eps**2
    # Gaussian likelihood term carries a minus sign (see integrand_den).
    exponent = -(d**2) / (2 * total_var) - u / c
    # Prior power u**(alpha - 1), multiplied by the extra u of the first moment.
    power_term = u**alpha
    return (power_term / np.sqrt(total_var)) * np.exp(exponent)


def integrand_num_phi(u, d, alpha, c, sigma_sq):
    """
    Integrand of the numerator of E[log(U) | d] (Eq. 26).

    It is log(u) times the unnormalised posterior density returned by
    ``integrand_den``.

    Note: despite its name, ``sigma_sq`` is forwarded unchanged as the
    ``sigma_eps`` argument of ``integrand_den``. ``e_step_integral`` passes the
    noise standard deviation here, not the variance.
    """
    # log(0) is -inf; the integrand is defined as 0 at the origin instead.
    if u != 0:
        return np.log(u) * integrand_den(u, d, alpha, c, sigma_sq)
    return 0.0


def e_step_integral(d, alpha, c, sigma_epsilon):
    """
    Posterior moments of the latent scale U for one noisy coefficient d (E-step).

    Computes, by numerical integration over u in [0, inf):

        M_U(1)   = E[U | d]      = int u * f(u) du      / int f(u) du
        phi_U(1) = E[log U | d]  = int log(u) * f(u) du / int f(u) du

    where f is the unnormalised posterior density ``integrand_den``.

    Args:
        d: a single noisy wavelet coefficient (scalar).
        alpha, c: current BKF shape and scale parameters.
        sigma_epsilon: noise standard deviation.

    Returns:
        (M_U_1, phi_U_1). Returns (0.0, 0.0) when the normalising integral is zero
        or not finite, so the ratios cannot be formed.
    """
    args = (d, alpha, c, sigma_epsilon)
    # scipy's default relative accuracy.
    epsrel = 1.49e-8

    # The integrals are unnormalised, so their magnitude can be arbitrarily small.
    # An absolute tolerance (quad's default epsabs) or an absolute cutoff would leave
    # small integrals unresolved. The integrand is positive, so request purely
    # relative accuracy.
    I_den, _ = quad(integrand_den, 0, np.inf, args=args, epsabs=0.0, epsrel=epsrel)
    if not np.isfinite(I_den) or I_den <= 0.0:
        return 0.0, 0.0  # ratio undefined; avoid division by zero / inf

    # Only accuracy relative to I_den matters, since we divide by it. The phi
    # numerator can change sign and be ~0, so a purely relative tolerance on it
    # could be unreachable.
    num_tol = epsrel * I_den

    # integrate numerator for M_U(1)
    I_num_mu, _ = quad(integrand_num_mu, 0, np.inf, args=args,
                       epsabs=num_tol, epsrel=epsrel)
    # integrate the numerator for phi_U(1)
    I_num_phi, _ = quad(integrand_num_phi, 0, np.inf, args=args,
                        epsabs=num_tol, epsrel=epsrel)

    M_U_1 = I_num_mu / I_den
    phi_U_1 = I_num_phi / I_den
    return M_U_1, phi_U_1

# Vectorize the integral function to apply it to all coefficients in a subband.
# Fixed otypes: two float outputs (M_U, phi_U). This also lets size-0 inputs
# return empty arrays instead of raising, and avoids inferring the output type
# from the first element's result.
v_e_step_integral = np.vectorize(e_step_integral, otypes=[float, float])

In [None]:
# M-Step and the Full EM Loop

def m_step_alpha_root_func(alpha, phi_U_avg, M_U_avg):
    """
    Residual of the alpha update equation (Eq. 28), used as a root-finding target.

    Returns ``(psi(alpha) - log(alpha)) - (phi_U_avg - log(M_U_avg))``; the M-step
    alpha is the value that makes this zero.

    Args:
        alpha: candidate shape parameter.
        phi_U_avg: mean over the subband of E[log U | d_i].
        M_U_avg: mean over the subband of E[U | d_i].

    Returns:
        The residual, or +inf for non-positive alpha (outside the domain).
    """
    # digamma and log are only evaluated on the positive axis.
    if alpha <= 0:
        return np.inf

    model_side = psi(alpha) - np.log(alpha)          # left-hand side of Eq. 28
    data_side = phi_U_avg - np.log(M_U_avg)          # right-hand side of Eq. 28
    return model_side - data_side

def _solve_alpha(alpha, phi_U_avg, M_U_avg, max_expand=8):
    """Solve Eq. 28 for alpha near the current value; return None if no root is bracketed."""
    # psi(a) - log(a) increases monotonically from -inf (a -> 0+) towards 0 (a -> inf).
    # A root therefore exists only when the target phi_U_avg - log(M_U_avg) is
    # negative. Jensen's inequality guarantees this for exact posterior moments.
    target = phi_U_avg - np.log(M_U_avg)
    if not target < 0:
        return None

    args = (phi_U_avg, M_U_avg)
    lo = max(alpha * 0.1, 0.01)
    hi = alpha * 10.0
    # Widen the bracket only when it does not already straddle the root, so the
    # usual case uses exactly the [0.1*alpha, 10*alpha] bracket. Expansion is capped
    # to stay well within double-precision resolution of psi(a) - log(a).
    for _ in range(max_expand):
        if m_step_alpha_root_func(lo, *args) < 0:
            break
        lo *= 0.1
    for _ in range(max_expand):
        if m_step_alpha_root_func(hi, *args) > 0:
            break
        hi *= 10.0

    try:
        return brentq(m_step_alpha_root_func, lo, hi, args=args)
    except (ValueError, RuntimeError):
        # ValueError: no sign change in [lo, hi]; RuntimeError: no convergence.
        return None


def em_bessel_k_form_noisy(d_subband, alpha_init, c_init, sigma_epsilon, max_iter=100, tol=1e-5):
    """
    Fit the BKF parameters (alpha, c) of a single noisy wavelet subband by EM.

    E-step: per-coefficient posterior moments E[U|d_i] and E[log U|d_i] via
    ``v_e_step_integral``.
    M-step: alpha solves Eq. 28 (``psi(alpha) - log(alpha) = mean E[log U] -
    log(mean E[U])``), then c = mean E[U] / alpha (Eq. 27).

    Args:
        d_subband: noisy wavelet coefficients of one subband (array-like of any
            shape; it is flattened).
        alpha_init, c_init: initial guesses for alpha and c; both must be positive.
        sigma_epsilon: noise standard deviation, e.g. from ``estimate_noise_std``.
        max_iter: maximum number of EM iterations.
        tol: absolute tolerance on the change of both alpha and c for convergence.

    Returns:
        (alpha, c): the final parameter estimates.

    Raises:
        ValueError: if d_subband is empty or alpha_init / c_init is not positive.
    """
    d_flat = np.asarray(d_subband, dtype=float).ravel()
    # Average over ALL coefficients; len() of a 2-D subband would count only its rows.
    m = d_flat.size
    if m == 0:
        raise ValueError("d_subband must contain at least one coefficient")
    if not (alpha_init > 0 and c_init > 0):
        raise ValueError("alpha_init and c_init must be positive")

    alpha = alpha_init
    c = c_init

    for _ in range(max_iter):
        # e step
        M_U_i, phi_U_i = v_e_step_integral(d_flat, alpha, c, sigma_epsilon)
        M_U_avg = np.sum(M_U_i) / m
        phi_U_avg = np.sum(phi_U_i) / m

        # Degenerate moments would make log(M_U_avg) meaningless. They would also
        # yield c = 0, which divides by zero in the next E-step, so keep the last
        # valid estimate.
        if not (np.isfinite(M_U_avg) and np.isfinite(phi_U_avg) and M_U_avg > 0):
            break

        # m step: alpha from Eq. 28; keep the current alpha if no root is found.
        alpha_next = _solve_alpha(alpha, phi_U_avg, M_U_avg)
        if alpha_next is None:
            alpha_next = alpha
        # c from Eq. 27
        c_next = M_U_avg / alpha_next

        converged = abs(alpha_next - alpha) < tol and abs(c_next - c) < tol
        alpha, c = alpha_next, c_next
        if converged:
            break

    return alpha, c

In [None]:
# Closed-form MAP estimator expressions (shrinkage term A, threshold lambda, estimator)

def A(d, c, sigma_eps):
    """Return the shifted magnitude |d| - sqrt(2/c) * sigma_eps**2 used by the MAP rule.

    Works elementwise on NumPy arrays as well as on scalars.
    """
    shift = np.sqrt(2 / c) * (sigma_eps**2)
    return np.abs(d) - shift

def lambda_func(c, alpha, sigma_eps):
    """
    Threshold of the BKF MAP rule: coefficients with |d| <= lambda are set to 0.

        lambda = sqrt(2) * sigma_eps * (sqrt(2 * (1 - alpha)) + sigma_eps / sqrt(c))
               = 2 * sigma_eps * sqrt(1 - alpha) + sqrt(2 / c) * sigma_eps**2

    With A = |d| - sqrt(2/c) * sigma_eps**2, |d| >= lambda is exactly
    A >= 2 * sigma_eps * sqrt(1 - alpha). In that case the discriminant
    A**2 + 4 * sigma_eps**2 * (alpha - 1) of the shrinkage formula is non-negative.

    For alpha > 1 that discriminant is always positive, so there is no threshold
    and 0 is returned; this avoids sqrt of a negative number (NaN + RuntimeWarning).

    Args:
        c: BKF scale parameter.
        alpha: BKF shape parameter (scalar or array).
        sigma_eps: noise standard deviation.

    Returns:
        The threshold; a NumPy scalar for scalar inputs, an array otherwise.
    """
    alpha_arr = np.asarray(alpha, dtype=float)
    factor = np.sqrt(2) * sigma_eps
    # Clamp so alpha > 1 never reaches sqrt of a negative number.
    first_term = np.sqrt(2 * np.maximum(1 - alpha_arr, 0.0))
    second_term = sigma_eps / np.sqrt(c)
    threshold = np.where(alpha_arr > 1, 0.0, factor * (first_term + second_term))
    # Unwrap 0-d results to a scalar; arrays are returned unchanged.
    return threshold[()]

def map_estimator(d, A, alpha, sigma_eps, lambda_func):
    """
    Closed-form MAP estimate of one clean coefficient under the BKF prior.

    Coefficients with |d| <= threshold are set to 0. Otherwise the estimate is
    the larger root of x**2 - A*x - sigma_eps**2 * (alpha - 1) = 0, carrying the
    sign of d:

        sign(d) / 2 * (A + sqrt(A**2 + 4 * sigma_eps**2 * (alpha - 1)))

    Args:
        d: noisy coefficient (scalar).
        A: precomputed shift value, i.e. the result of ``A(d, c, sigma_eps)``.
           The parameter shadows that function inside this body.
        alpha: BKF shape parameter.
        sigma_eps: noise standard deviation.
        lambda_func: precomputed threshold value, i.e. the result of
           ``lambda_func(c, alpha, sigma_eps)``. Despite the name it is a
           number, not the function.

    Returns:
        The estimated clean coefficient.
    """
    if np.abs(d) <= lambda_func:
        return 0.0
    discriminant = A**2 + 4 * sigma_eps**2 * (alpha - 1)
    return np.sign(d) / 2 * (A + np.sqrt(discriminant))


# Vectorize the MAP estimator for array inputs. The fixed float output type lets
# size-0 inputs return an empty array instead of raising, and does not depend on
# the type of the first element's result.
v_map_estimator = np.vectorize(map_estimator, otypes=[float])