In [34]:
test_signal = np.array([2, 1, 3, 1, 3, 4, 5, 6, 2, 3, 4, 5])

# Fetch all four db4 filters with one call instead of rebuilding the
# pywt.Wavelet object once per filter.
dec_lo, dec_hi, rec_lo, rec_hi = get_filters('db4')

# One level of analysis with our own filter bank.
first_level_approx_our = convolve_and_downsample(test_signal, dec_lo)
first_level_detail_our = convolve_and_downsample(test_signal, dec_hi)

# Reference coefficients: a single pywt.dwt call returns (cA, cD).
first_level_approx_pywt, first_level_detail_pywt = pywt.dwt(test_signal, 'db4', mode='periodization')

# Synthesis: lowpass and highpass branches, each upsampled back to the
# input length, summed together.
recon_our = (
    upsample_and_convolve(first_level_approx_our, rec_lo, len(test_signal))
    + upsample_and_convolve(first_level_detail_our, rec_hi, len(test_signal))
)
recon_pywt = pywt.idwt(first_level_approx_pywt, first_level_detail_pywt, wavelet='db4', mode='periodization')

print(recon_our)
print(recon_pywt)
print(first_level_approx_our)
print(first_level_detail_our)
print(first_level_approx_pywt)
print(first_level_detail_pywt)









[2. 1. 3. 1. 3. 4. 5. 6. 2. 3. 4. 5.]
[2. 1. 3. 1. 3. 4. 5. 6. 2. 3. 4. 5.]
[6.55553616 3.11254604 2.50401791 3.75425507 7.93054252 3.72026676]
[-1.53196066  0.26006393  1.04468426  0.74265746  0.8583161  -0.66665429]
[6.55553616 3.11254604 2.50401791 3.75425507 7.93054252 3.72026676]
[-1.53196066  0.26006393  1.04468426  0.74265746  0.8583161  -0.66665429]


In [49]:
import numpy as np
import pywt

# =============================================================================
# CONFIGURATION - Change wavelet type here
# =============================================================================
# Options: "haar", "db2", "db4", "sym4", etc.
# NOTE(review): no code visible in this section reads WAVELET_TYPE; the test
# suite iterates over its own wavelet list. Confirm whether it is still needed.
WAVELET_TYPE = "db3"

# =============================================================================
# GET FILTER COEFFICIENTS FROM PYWT
# =============================================================================
def get_filters(wavelet_name: str):
    """
    Look up a wavelet's four filter banks in PyWavelets.

    Args:
        wavelet_name: Any name accepted by ``pywt.Wavelet`` (e.g. 'db4').

    Returns:
        Tuple ``(dec_lo, dec_hi, rec_lo, rec_hi)`` of NumPy arrays, in order:
        decomposition (analysis) lowpass, decomposition highpass,
        reconstruction (synthesis) lowpass, reconstruction highpass.
    """
    wavelet = pywt.Wavelet(wavelet_name)
    filter_names = ("dec_lo", "dec_hi", "rec_lo", "rec_hi")
    return tuple(np.array(getattr(wavelet, name)) for name in filter_names)


# =============================================================================
# CORE OPERATIONS: PERIODIC EXTENSION + CONVOLUTION + DOWNSAMPLING
# =============================================================================
def periodic_convolve(signal: np.ndarray, filter_coeffs: np.ndarray) -> np.ndarray:
    """
    Perform periodic (circular) convolution.

    Computes ``result[n] = sum_k filter_coeffs[k] * signal[(n - k) % N]`` for
    ``n = 0 .. N-1`` with ``N = len(signal)``. A filter longer than the signal
    wraps around more than once, exactly as the modulo index implies.

    Args:
        signal: Input signal (1-D).
        filter_coeffs: Filter coefficients (1-D).

    Returns:
        Convolution result with the same length as ``signal`` (periodic
        boundary). The dtype is float64, or complex128 when either input is
        complex.
    """
    signal = np.asarray(signal)
    filter_coeffs = np.asarray(filter_coeffs)
    # At least float64 (as before); widen to complex when needed instead of
    # failing on assignment into a float buffer.
    result = np.zeros(len(signal), dtype=np.result_type(signal, filter_coeffs, np.float64))
    if len(signal) == 0:
        return result

    # np.roll(signal, k)[n] == signal[(n - k) % N], so each filter tap is one
    # vectorised shift-and-accumulate rather than an inner Python loop over n.
    # Taps are still added in order k = 0..M-1, so the floating-point
    # summation order per output sample is unchanged.
    for k, coeff in enumerate(filter_coeffs):
        result += coeff * np.roll(signal, k)

    return result


def convolve_and_downsample(signal: np.ndarray, filter_coeffs: np.ndarray) -> np.ndarray:
    """
    Analysis step: circularly filter ``signal``, then keep every other sample.

    Before decimation, the circular convolution is rotated left by
    ``len(filter_coeffs) // 2``. This alignment offset is what makes the kept
    samples agree with PyWavelets' ``mode='periodization'`` coefficients in
    the comparisons run in this notebook.

    Args:
        signal: Input signal.
        filter_coeffs: Analysis filter coefficients.

    Returns:
        The even-indexed samples of the aligned convolution (half the length,
        rounded up).
    """
    alignment = len(filter_coeffs) // 2
    aligned = np.roll(periodic_convolve(signal, filter_coeffs), -alignment)
    return aligned[::2]


def upsample_and_convolve(signal: np.ndarray, filter_coeffs: np.ndarray, target_length: int = None) -> np.ndarray:
    """
    Synthesis step: zero-stuff ``signal`` by 2, then filter circularly.

    Args:
        signal: Input coefficients.
        filter_coeffs: Synthesis filter coefficients.
        target_length: Output length; ``None`` means ``2 * len(signal)``.
            Its even-indexed slots must be exactly ``len(signal)`` in number.

    Returns:
        The filtered, upsampled signal of length ``target_length``, rotated
        left by ``len(filter_coeffs) // 2 - 1`` samples for alignment.
    """
    out_len = 2 * len(signal) if target_length is None else target_length

    # Coefficients occupy the even slots; the odd slots stay zero.
    stuffed = np.zeros(out_len)
    stuffed[::2] = signal

    # Rotating by (1 - L // 2) equals a left roll of (L // 2 - 1).
    filtered = periodic_convolve(stuffed, filter_coeffs)
    return np.roll(filtered, 1 - len(filter_coeffs) // 2)


# =============================================================================
# FAST WAVELET TRANSFORM (FWT) - ANALYSIS
# =============================================================================
def fwt(signal: np.ndarray, wavelet_name: str, level: int, verbose: bool = True) -> dict:
    """
    Fast Wavelet Transform using filter banks with periodic extension.

    Any input length is accepted. Each level's filter bank needs an
    even-length input. Whenever the approximation entering a level has odd
    length, it is extended by one copy of its last sample. This mirrors how
    PyWavelets' ``mode='periodization'`` extends odd-length data, so the
    coefficients are comparable with
    ``pywt.wavedec(signal, wavelet_name, mode='periodization', level=level)``.
    It also lets ``ifwt`` invert every level.

    Args:
        signal: Input signal (1-D).
        wavelet_name: Name of a PyWavelets wavelet, e.g. 'db4'.
        level: Number of decomposition levels.
        verbose: Print intermediate steps.

    Returns:
        Dictionary with
            'approx': approximation coefficients of the coarsest level,
            'details': list of detail coefficients; details[0] is the finest
                level and details[-1] the coarsest,
            'level': number of levels computed,
            'wavelet': wavelet_name,
            'original_length': length of ``signal`` before any padding.
    """
    # Get decomposition filters
    dec_lo, dec_hi, _, _ = get_filters(wavelet_name)
    signal = np.asarray(signal)
    orig_len = len(signal)

    # Periodized filter banks need an even input length. Repeat the last
    # sample (instead of appending a zero) to match PyWavelets' periodization.
    if orig_len % 2 == 1:
        signal = np.pad(signal, (0, 1), mode='edge')
        if verbose:
            print(f"Input signal padded from length {orig_len} to {len(signal)}.")

    if verbose:
        print(f"\n{'='*60}")
        print(f"FWT DECOMPOSITION - {wavelet_name.upper()}")
        print(f"{'='*60}")
        print(f"Signal length: {len(signal)}")
        print(f"Decomposition levels: {level}")
        print(f"Filter length: {len(dec_lo)}")
        print(f"dec_lo: {dec_lo}")
        print(f"dec_hi: {dec_hi}")

    approx = signal.copy()
    details = []

    # Iterative decomposition
    for j in range(level):
        if verbose:
            print(f"\n--- Level {j+1} ---")

        # Coarser levels can have odd length even for an even input
        # (e.g. 12 -> 6 -> 3). Pad them the same way so every level stays
        # invertible; previously ifwt failed on such inputs.
        if len(approx) % 2 == 1:
            unpadded_len = len(approx)
            approx = np.pad(approx, (0, 1), mode='edge')
            if verbose:
                print(f"  Approx padded from length {unpadded_len} to {len(approx)}.")

        if verbose:
            print(f"  Input (approx from level {j}): {approx}")
            print(f"  Input length: {len(approx)}")

        # Analysis filter bank: the lowpass branch gives the approximation,
        # the highpass branch gives the detail coefficients.
        new_approx = convolve_and_downsample(approx, dec_lo)
        detail = convolve_and_downsample(approx, dec_hi)

        if verbose:
            print("  After convolution & downsample:")
            print(f"    New approx (len={len(new_approx)}): {new_approx}")
            print(f"    Detail (len={len(detail)}): {detail}")

        details.append(detail)
        approx = new_approx

    result = {
        'approx': approx,
        'details': details,  # details[0] = finest, details[-1] = coarsest
        'level': level,
        'wavelet': wavelet_name,
        'original_length': orig_len,  # before padding, so ifwt can trim back
    }

    if verbose:
        print("\n--- Final Coefficients ---")
        print(f"  Approximation (level {level}): {approx}")
        for i, d in enumerate(details):
            print(f"  Detail level {i+1}: {d}")

    return result


# =============================================================================
# INVERSE FAST WAVELET TRANSFORM (IFWT) - SYNTHESIS
# =============================================================================
def ifwt(coeffs: dict, verbose: bool = True) -> np.ndarray:
    """
    Inverse Fast Wavelet Transform - reconstruct a signal from fwt() output.

    Levels are rebuilt from coarsest to finest. If a rebuilt approximation is
    exactly one sample longer than the next finer detail array, that level
    was extended from an odd length during analysis. The extra sample is
    dropped, the same rule ``pywt.waverec`` applies.

    Finally, if ``coeffs['original_length']`` is one shorter than the result,
    the padding added to the input signal is removed as well.

    Args:
        coeffs: Dictionary from fwt() with keys 'approx', 'details', 'level',
            'wavelet' and optionally 'original_length'.
        verbose: Print intermediate steps.

    Returns:
        Reconstructed signal.
    """
    # Get reconstruction filters
    _, _, rec_lo, rec_hi = get_filters(coeffs['wavelet'])

    level = coeffs['level']
    approx = coeffs['approx'].copy()
    details = coeffs['details']

    if verbose:
        print(f"\n{'='*60}")
        print(f"IFWT RECONSTRUCTION - {coeffs['wavelet'].upper()}")
        print(f"{'='*60}")
        print(f"Reconstruction levels: {level}")
        print(f"Filter length: {len(rec_lo)}")
        print(f"rec_lo: {rec_lo}")
        print(f"rec_hi: {rec_hi}")

    # Iterative reconstruction (from coarsest to finest)
    for j in range(level - 1, -1, -1):
        detail = details[j]

        if verbose:
            print(f"\n--- Reconstructing level {j+1} ---")
            print(f"  Approx input (len={len(approx)}): {approx}")
            print(f"  Detail input (len={len(detail)}): {detail}")

        # Each analysis step halved its (even) input, so this level's
        # output is twice the detail length.
        target_length = len(detail) * 2

        # Synthesis filter bank: lowpass for the approximation, highpass
        # for the detail, then sum the two contributions.
        approx_up = upsample_and_convolve(approx, rec_lo, target_length)
        detail_up = upsample_and_convolve(detail, rec_hi, target_length)

        if verbose:
            print("  After upsample & convolve:")
            print(f"    Approx contribution (len={len(approx_up)}): {approx_up}")
            print(f"    Detail contribution (len={len(detail_up)}): {detail_up}")

        approx = approx_up + detail_up

        if verbose:
            print(f"  Combined (len={len(approx)}): {approx}")

        # One sample longer than the next finer detail means this level was
        # padded from an odd length during analysis: drop the padding so the
        # next upsample step gets the expected length.
        if j > 0 and len(approx) == len(details[j - 1]) + 1:
            approx = approx[:-1]
            if verbose:
                print(f"  Trimmed padding sample -> len={len(approx)}")

    # Undo the padding fwt may have added to the original input.
    original_length = coeffs.get('original_length')
    if original_length is not None and len(approx) == original_length + 1:
        approx = approx[:original_length]

    return approx


# =============================================================================
# TEST FUNCTION
# =============================================================================
import itertools

def _compare_with_pywt(signal, wavelet_name, lvl):
    """Run one fwt/ifwt round trip and compare it with PyWavelets.

    Returns:
        (fwt_error, pywt_error, approx_match, details_match)
    """
    coeffs = fwt(signal, wavelet_name, lvl, verbose=False)
    reconstructed = ifwt(coeffs, verbose=False)
    # fwt pads odd-length input with one sample; drop it before comparing.
    if len(signal) % 2 == 1:
        reconstructed = reconstructed[:len(signal)]
    fwt_error = np.max(np.abs(signal - reconstructed))

    pywt_coeffs = pywt.wavedec(signal, wavelet_name, mode='periodization', level=lvl)
    pywt_approx, pywt_details = pywt_coeffs[0], pywt_coeffs[1:]

    # Require equal lengths: slicing ours down to pywt's length would hide
    # extra (wrong) coefficients.
    approx_match = (
        len(coeffs['approx']) == len(pywt_approx)
        and np.allclose(coeffs['approx'], pywt_approx, atol=1e-10)
    )
    # pywt lists details coarsest-first; fwt stores them finest-first.
    ours_coarse_to_fine = coeffs['details'][::-1]
    details_match = len(ours_coarse_to_fine) == len(pywt_details) and all(
        len(ours) == len(ref) and np.allclose(ours, ref, atol=1e-10)
        for ours, ref in zip(ours_coarse_to_fine, pywt_details)
    )

    pywt_recon = pywt.waverec(pywt_coeffs, wavelet_name, mode='periodization')
    # waverec can return one extra sample for odd-length input.
    pywt_error = np.max(np.abs(signal - pywt_recon[:len(signal)]))
    return fwt_error, pywt_error, approx_match, details_match


def test_fwt_ifwt_suite():
    """
    Parametrized round-trip test of fwt/ifwt against PyWavelets.

    Covers every (signal, wavelet, level) combination whose level does not
    exceed ``pywt.dwt_max_level`` for that signal and wavelet. Each case checks
    three things:
      * ifwt(fwt(x)) reconstructs x (max abs error < 1e-7);
      * fwt's coefficients equal pywt.wavedec(..., mode='periodization');
      * pywt.waverec also reconstructs x.

    Prints one [PASS] or [FAIL] line per case, or a 'Failed on ...' line when
    a case raises, followed by a 'passed/total' summary.
    """
    signals = [
        np.array([1.0, 4.0, -3.0, 0.0, 24, 5, 0, 0]),
        np.linspace(-10, 10, 16),  # Even length, increasing
        np.sin(np.linspace(0, 6*np.pi, 31)),  # Odd length, sinusoid
        np.random.RandomState(42).randn(32),   # Random noise
        np.concatenate([np.ones(8), np.zeros(8), np.arange(8)])  # Piecewise
    ]
    wavelets = [
        "db1", "db2", "db4", "sym2", "sym4", "coif1", "bior1.3", "haar"
    ]
    levels = [1, 2, 3]

    passed = 0
    total = 0
    for sig_idx, signal in enumerate(signals):
        for wv in wavelets:
            try:
                max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wv).dec_len)
            except ValueError as e:
                # Unknown wavelet or unusable length: report instead of silently skipping.
                print(f"[SKIP] sig#{sig_idx+1} | wavelet={wv} | {e}")
                continue
            for lvl in levels:
                if not 1 <= lvl <= max_level:
                    continue
                total += 1
                label = f"sig#{sig_idx+1} | wavelet={wv}, level={lvl}"
                try:
                    fwt_error, pywt_error, approx_match, details_match = _compare_with_pywt(signal, wv, lvl)
                except Exception as e:  # harness boundary: report and continue with the next case
                    print(f"Failed on {label} due to {e}")
                    continue
                if fwt_error < 1e-7 and approx_match and details_match and pywt_error < 1e-7:
                    print(f"[PASS] {label} | max_error={fwt_error:.2e}")
                    passed += 1
                else:
                    print(f"[FAIL] {label} | "
                          f"fwt_error={fwt_error:.2e}, pywt_error={pywt_error:.2e}, "
                          f"approx_match={approx_match}, details_match={details_match}")
    print(f"\n{passed}/{total} test cases passed.")



# =============================================================================
# RUN TESTS
# =============================================================================
# Stand-alone example signal; the suite below builds its own list of signals.
test_signal = np.array([1.0, 4.0, -3.0, 0.0, 24.0, 5.0, 0.0, 0.0])

# Run the full parametrized suite. It iterates over its own wavelet list,
# not WAVELET_TYPE.
test_fwt_ifwt_suite()


[PASS] sig#1 | wavelet=db1, level=1 | max_error=3.55e-15
[PASS] sig#1 | wavelet=db1, level=2 | max_error=3.55e-15
[PASS] sig#1 | wavelet=db1, level=3 | max_error=3.55e-15
[PASS] sig#1 | wavelet=db2, level=1 | max_error=1.78e-15
[PASS] sig#1 | wavelet=sym2, level=1 | max_error=1.46e-11
[PASS] sig#1 | wavelet=haar, level=1 | max_error=3.55e-15
[PASS] sig#1 | wavelet=haar, level=2 | max_error=3.55e-15
[PASS] sig#1 | wavelet=haar, level=3 | max_error=3.55e-15
[PASS] sig#2 | wavelet=db1, level=1 | max_error=3.55e-15
[PASS] sig#2 | wavelet=db1, level=2 | max_error=5.33e-15
[PASS] sig#2 | wavelet=db1, level=3 | max_error=5.33e-15
[PASS] sig#2 | wavelet=db2, level=1 | max_error=3.55e-15
[PASS] sig#2 | wavelet=db2, level=2 | max_error=3.55e-15
[PASS] sig#2 | wavelet=db4, level=1 | max_error=3.55e-15
[PASS] sig#2 | wavelet=sym2, level=1 | max_error=6.12e-12
[PASS] sig#2 | wavelet=sym2, level=2 | max_error=1.40e-11
[PASS] sig#2 | wavelet=sym4, level=1 | max_error=5.27e-12
[PASS] sig#2 | wavelet=c